at master 2.7 kB view raw
# Package expression for NumPyro: fetched from the upstream GitHub tag and
# built as a pyproject package with setuptools.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build backend
  setuptools,

  # runtime requirements
  jax,
  jaxlib,
  multipledispatch,
  numpy,
  tqdm,

  # inputs needed only by the test suite
  dm-haiku,
  equinox,
  flax,
  funsor,
  graphviz,
  optax,
  pyro-api,
  pytest-xdist,
  pytestCheckHook,
  scikit-learn,
  tensorflow-probability,
}:

let
  # Single source of truth for the release; reused by the source tag and the
  # changelog URL below.
  version = "0.19.0";

  # Tests that are skipped only when the host platform is Darwin.
  darwinOnlyFailures = lib.optionals stdenv.hostPlatform.isDarwin [
    # AssertionError: Not equal to tolerance rtol=0.06, atol=0
    "test_functional_map"
  ];
in
buildPythonPackage {
  pname = "numpyro";
  inherit version;
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pyro-ppl";
    repo = "numpyro";
    tag = version;
    hash = "sha256-3kzaINsz1Mjk97ERQsQIYIBz7CVmXtVDn0edJFMHQWs=";
  };

  build-system = [ setuptools ];

  dependencies = [
    jax
    jaxlib
    multipledispatch
    numpy
    tqdm
  ];

  nativeCheckInputs = [
    dm-haiku
    equinox
    flax
    funsor
    graphviz
    optax
    pyro-api
    pytest-xdist
    pytestCheckHook
    scikit-learn
    tensorflow-probability
  ];

  pythonImportsCheck = [ "numpyro" ];

  pytestFlags = [
    # Cap the number of parallel workers: the test suite's memory consumption
    # grows significantly with the process count (reaches ~200GB with 80 jobs).
    "--maxprocesses=8"

    # Silence UserWarning, which otherwise makes a few tests fail with:
    #   UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1.
    #   Chains will be drawn sequentially. If you are running MCMC in CPU, consider using
    #   `numpyro.set_host_device_count(2)` at the beginning of your program.
    #   You can double-check how many devices are available in your system using
    #   `jax.local_device_count()`.
    "-Wignore::UserWarning"
  ];

  disabledTests = [
    # Fails with: AssertionError, assert GLOBAL["count"] == 4 (assert 5 == 4)
    "test_mcmc_parallel_chain"

    # Fail with AssertionError because of numerical tolerance issues
    "test_bijective_transforms"
    "test_cpu"
    "test_entropy_categorical"
    "test_gaussian_model"

    # The expected warning is never emitted:
    # >       with pytest.warns(UserWarning, match="Hessian of log posterior"):
    # E       Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted.
    # E        Emitted warnings: [].
    "test_laplace_approximation_warning"

    # ValueError: compiling computation that requires 2 logical devices, but only 1 XLA devices are available (num_replicas=2)
    "test_chain"
  ]
  ++ darwinOnlyFailures;

  disabledTestPaths = [
    # These tests need internet access
    "test/test_example_utils.py"
  ];

  meta = {
    description = "Library for probabilistic programming with NumPy";
    homepage = "https://num.pyro.ai/";
    changelog = "https://github.com/pyro-ppl/numpyro/releases/tag/${version}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.fab ];
  };
}