{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,
  rustPlatform,

  # build-system
  cargo,
  rustc,

  # dependencies
  arviz,
  pandas,
  pyarrow,
  xarray,

  # tests
  # (bridgestan would also be needed, but it is not packaged)
  equinox,
  flowjax,
  jax,
  jaxlib,
  numba,
  pymc,
  pytest-timeout,
  pytestCheckHook,
  setuptools,
  writableTmpDirAsHomeHook,
}:

buildPythonPackage rec {
  pname = "nutpie";
  version = "0.15.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pymc-devs";
    repo = "nutpie";
    tag = "v${version}";
    hash = "sha256-9rcQtEdaafMyuNb/ezcqUmrwXbQFa9hdajGAtANdHOw=";
  };

  # Rust crates used by the native extension, vendored from the same
  # source tree and pinned by a fixed-output hash.
  cargoDeps = rustPlatform.fetchCargoVendor {
    inherit pname version src;
    hash = "sha256-6JWBJYGhSNUL8KYiEE2ZBW9xP4CmkCcwwhsO6aOvZyA=";
  };

  # The package is a Rust extension built with cargo/rustc and assembled
  # into a wheel by the maturin build hook.
  build-system = [
    cargo
    rustPlatform.bindgenHook
    rustPlatform.cargoSetupHook
    rustPlatform.maturinBuildHook
    rustc
  ];

  # Do not enforce upstream's version constraint on xarray.
  pythonRelaxDeps = [ "xarray" ];

  dependencies = [
    arviz
    pandas
    pyarrow
    xarray
  ];

  pythonImportsCheck = [ "nutpie" ];

  nativeCheckInputs = [
    # bridgestan (unpackaged)
    equinox
    flowjax
    numba
    jax
    jaxlib
    pymc
    pytest-timeout
    pytestCheckHook
    setuptools
    writableTmpDirAsHomeHook
  ];

  # Verbose pytest output.
  pytestFlags = [ "-v" ];

  disabledTests =
    let
      inherit (stdenv.hostPlatform) isLinux isAarch64;
    in
    lib.optionals (isLinux && isAarch64) [
      # Flaky on aarch64-linux: assert np.float64(0.0017554642626285276) > 0.01
      "test_normalizing_flow"
    ];

  # The Stan tests need bridgestan, which is not available in nixpkgs.
  disabledTestPaths = [ "tests/test_stan.py" ];

  meta = {
    description = "Python wrapper for nuts-rs";
    homepage = "https://github.com/pymc-devs/nutpie";
    changelog = "https://github.com/pymc-devs/nutpie/blob/${src.tag}/CHANGELOG.md";
    license = lib.licenses.mit;
    maintainers = [ lib.maintainers.GaetanLepage ];
  };
}