# Smoke test for JAX's CUDA backend.
#
# Produces an executable named `jax-test-cuda`. When run, the embedded Python
# script asserts that JAX's first device is a GPU and then performs one small
# operation per CUDA library (the library each step is meant to touch is named
# in the trailing comment). It prints "success!" only after every step has
# finished executing.
{
  jax,
  pkgs,
}:

pkgs.writers.writePython3Bin "jax-test-cuda"
  {
    # jax itself plus everything in its `cuda` optional-dependency group.
    libraries = [
      jax
    ]
    ++ jax.optional-dependencies.cuda;
  }
  ''
    import jax
    import jax.numpy as jnp
    from jax import random
    from jax.experimental import sparse

    # JAX dispatches work asynchronously: an expression returns before the
    # device has finished computing it, and an execution failure is only
    # raised when the result is awaited. Every result is therefore passed
    # through jax.block_until_ready so a broken CUDA library cannot slip by
    # before "success!" is printed.

    devices = jax.devices()
    assert devices[0].platform == "gpu", devices  # libcuda.so

    rng = random.key(0)  # libcudart.so, libcudnn.so
    x = jax.block_until_ready(random.normal(rng, (100, 100)))
    jax.block_until_ready(x @ x)  # libcublas.so
    jax.block_until_ready(jnp.fft.fft(x))  # libcufft.so
    jax.block_until_ready(jnp.linalg.inv(x))  # libcusolver.so
    jax.block_until_ready(sparse.CSR.fromdense(x) @ x)  # libcusparse.so

    print("success!")
  ''