# at master · 1.8 kB · view raw (web-page header, not part of the Nix expression)
# nutpie: Python wrapper for the nuts-rs sampler. The package contains a Rust
# extension that is built with maturin, and it is fetched from the upstream
# GitHub release tag.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,
  rustPlatform,

  # build-system
  cargo,
  rustc,

  # dependencies
  arviz,
  pandas,
  pyarrow,
  xarray,

  # tests
  # bridgestan, (not packaged)
  equinox,
  flowjax,
  jax,
  jaxlib,
  numba,
  pymc,
  pytest-timeout,
  pytestCheckHook,
  setuptools,
  writableTmpDirAsHomeHook,
}:

buildPythonPackage rec {
  pname = "nutpie";
  version = "0.15.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pymc-devs";
    repo = "nutpie";
    tag = "v${version}";
    hash = "sha256-9rcQtEdaafMyuNb/ezcqUmrwXbQFa9hdajGAtANdHOw=";
  };

  # Vendored Cargo crates for the Rust extension. This is a fixed-output hash,
  # so it has to be refreshed together with `src.hash` on every version bump.
  cargoDeps = rustPlatform.fetchCargoVendor {
    inherit pname version src;
    hash = "sha256-6JWBJYGhSNUL8KYiEE2ZBW9xP4CmkCcwwhsO6aOvZyA=";
  };

  # Build the Rust extension through maturin, using the vendored crates above.
  build-system = [
    cargo
    rustPlatform.bindgenHook
    rustPlatform.cargoSetupHook
    rustPlatform.maturinBuildHook
    rustc
  ];

  # Loosen upstream's version constraint on xarray. Presumably the packaged
  # xarray falls outside the declared range; re-check whether this is still
  # needed on version bumps.
  pythonRelaxDeps = [
    "xarray"
  ];

  dependencies = [
    arviz
    pandas
    pyarrow
    xarray
  ];

  pythonImportsCheck = [ "nutpie" ];

  nativeCheckInputs = [
    # bridgestan (not packaged; its tests are excluded in disabledTestPaths)
    equinox
    flowjax
    jax
    jaxlib
    numba
    pymc
    pytest-timeout
    pytestCheckHook
    setuptools
    writableTmpDirAsHomeHook
  ];

  pytestFlags = [
    "-v"
  ];

  # Only skipped on aarch64-linux, where this test has been seen to fail
  # nondeterministically.
  disabledTests = lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
    # flaky (assert np.float64(0.0017554642626285276) > 0.01)
    "test_normalizing_flow"
  ];

  disabledTestPaths = [
    # Require unpackaged bridgestan
    "tests/test_stan.py"
  ];

  meta = {
    description = "Python wrapper for nuts-rs";
    homepage = "https://github.com/pymc-devs/nutpie";
    # Reuse the fetched tag so the changelog link always matches the source.
    changelog = "https://github.com/pymc-devs/nutpie/blob/${src.tag}/CHANGELOG.md";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}