at master 594 B view raw
# Builds an executable Python 3 script named `jax-test-cuda` using
# `pkgs.writers.writePython3Bin`.
#
# Arguments:
#   jax  - the jax Python package; its `optional-dependencies.cuda` list is
#          appended to the script's libraries.
#   pkgs - package set providing `writers.writePython3Bin`.
#
# The script asserts that the first JAX device is a GPU, then runs a handful
# of operations. Each one is annotated with the CUDA shared library it is
# meant to exercise, and the script prints "success!" once all have finished.
{
  jax,
  pkgs,
}:

pkgs.writers.writePython3Bin "jax-test-cuda"
  {
    libraries = [
      jax
    ]
    ++ jax.optional-dependencies.cuda;
  }
  ''
    import jax
    import jax.numpy as jnp
    from jax import random
    from jax.experimental import sparse

    assert jax.devices()[0].platform == "gpu"  # libcuda.so

    rng = random.key(0)  # libcudart.so, libcudnn.so
    x = random.normal(rng, (100, 100))

    # JAX dispatches work asynchronously; block on every result so a failure
    # in any of the kernels below is raised before "success!" is printed.
    jax.block_until_ready(x @ x)  # libcublas.so
    jax.block_until_ready(jnp.fft.fft(x))  # libcufft.so
    jax.block_until_ready(jnp.linalg.inv(x))  # libcusolver.so
    jax.block_until_ready(sparse.CSR.fromdense(x) @ x)  # libcusparse.so

    print("success!")
  ''