{
  # nixpkgs library and platform information
  lib,
  stdenv,

  # builder and source fetcher
  buildPythonPackage,
  fetchFromGitHub,

  # build-system: PEP 517 build backend
  setuptools,

  # dependencies: required at runtime
  jax,
  jaxlib,
  multipledispatch,
  numpy,
  tqdm,

  # tests: only used through nativeCheckInputs
  dm-haiku,
  equinox,
  flax,
  funsor,
  graphviz,
  optax,
  pyro-api,
  pytest-xdist,
  pytestCheckHook,
  scikit-learn,
  tensorflow-probability,
}:
30
buildPythonPackage rec {
  pname = "numpyro";
  version = "0.19.0";
  pyproject = true;

  # The release tag is exactly the version string.
  src = fetchFromGitHub {
    owner = "pyro-ppl";
    repo = "numpyro";
    tag = version;
    hash = "sha256-3kzaINsz1Mjk97ERQsQIYIBz7CVmXtVDn0edJFMHQWs=";
  };

  # Packaged with setuptools.
  build-system = [ setuptools ];

  # Runtime requirements; jaxlib is listed explicitly next to jax.
  dependencies = [
    jax
    jaxlib
    multipledispatch
    numpy
    tqdm
  ];
52
53 nativeCheckInputs = [
54 dm-haiku
55 equinox
56 flax
57 funsor
58 graphviz
59 optax
60 pyro-api
61 pytest-xdist
62 pytestCheckHook
63 scikit-learn
64 tensorflow-probability
65 ];
66
67 pythonImportsCheck = [ "numpyro" ];
68
69 pytestFlags = [
70 # Tests memory consumption grows significantly with the number of parallel processes (reaches ~200GB with 80 jobs)
71 "--maxprocesses=8"
72
73 # A few tests fail with:
74 # UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1.
75 # Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(2)` at the beginning of your program.
76 # You can double-check how many devices are available in your system using `jax.local_device_count()`.
77 "-Wignore::UserWarning"
78 ];
79
80 disabledTests = [
81 # AssertionError, assert GLOBAL["count"] == 4 (assert 5 == 4)
82 "test_mcmc_parallel_chain"
83
84 # AssertionError due to tolerance issues
85 "test_bijective_transforms"
86 "test_cpu"
87 "test_entropy_categorical"
88 "test_gaussian_model"
89
90 # > with pytest.warns(UserWarning, match="Hessian of log posterior"):
91 # E Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted.
92 # E Emitted warnings: [].
93 "test_laplace_approximation_warning"
94
95 # ValueError: compiling computation that requires 2 logical devices, but only 1 XLA devices are available (num_replicas=2)
96 "test_chain"
97 ]
98 ++ lib.optionals stdenv.hostPlatform.isDarwin [
99 # AssertionError: Not equal to tolerance rtol=0.06, atol=0
100 "test_functional_map"
101 ];
102
103 disabledTestPaths = [
104 # Require internet access
105 "test/test_example_utils.py"
106 ];
107
108 meta = {
109 description = "Library for probabilistic programming with NumPy";
110 homepage = "https://num.pyro.ai/";
111 changelog = "https://github.com/pyro-ppl/numpyro/releases/tag/${version}";
112 license = lib.licenses.asl20;
113 maintainers = with lib.maintainers; [ fab ];
114 };
115}