# Package expression for the `blackjax` Python library.
#
# The expression is a function from its build inputs (listed in the argument
# set below) to the result of `buildPythonPackage`. The source is fetched from
# the `blackjax-devs/blackjax` GitHub repository at the tag equal to `version`.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools-scm,

  # dependencies
  fastprogress,
  jax,
  jaxlib,
  jaxopt,
  optax,
  typing-extensions,

  # checks
  pytestCheckHook,
  pytest-xdist,

  # NOTE(review): `stdenv` is not referenced anywhere in the body below. It is
  # kept so that callers passing it explicitly keep evaluating; consider
  # dropping it once no such caller exists.
  stdenv,
}:

# `rec` lets `src.tag` and `meta.changelog` refer to `version`.
buildPythonPackage rec {
  pname = "blackjax";
  version = "1.2.5";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "blackjax-devs";
    repo = "blackjax";
    tag = version;
    hash = "sha256-2GTjKjLIWFaluTjdWdUF9Iim973y81xv715xspghRZI=";
  };

  build-system = [ setuptools-scm ];

  dependencies = [
    fastprogress
    jax
    jaxlib
    jaxopt
    optax
    typing-extensions
  ];

  nativeCheckInputs = [
    pytestCheckHook
    pytest-xdist
  ];

  pytestFlags = [
    # Ignore deprecation warnings during the test run. The one recorded here:
    # DeprecationWarning: JAXopt is no longer maintained
    "-Wignore::DeprecationWarning"
  ];

  disabledTestPaths = [
    # NOTE(review): no reason is recorded for skipping the benchmarks;
    # presumably they are too slow for a package build - confirm.
    "tests/test_benchmarks.py"

    # Assertion errors on numerical values
    "tests/mcmc/test_integrators.py"
  ];

  disabledTests = [
    # too slow
    "test_adaptive_tempered_smc"

    # AssertionError on numerical values
    "test_barker"
    "test_mclmc"
    "test_mcse4"
    "test_normal_univariate"
    "test_nuts__with_device"
    "test_nuts__with_jit"
    "test_nuts__without_device"
    "test_nuts__without_jit"
    "test_smc_waste_free__with_jit"

    # Numerical test (AssertionError)
    # First report, when the failure was only happening on aarch64-linux:
    # https://github.com/blackjax-devs/blackjax/issues/668
    # Second report, when the failure started happening on x86_64-linux too,
    # after Jax was updated to 0.7.0:
    # https://github.com/blackjax-devs/blackjax/issues/795
    "test_chees_adaptation"
  ];

  # Import check: the top-level module must be importable.
  pythonImportsCheck = [ "blackjax" ];

  meta = {
    homepage = "https://blackjax-devs.github.io/blackjax";
    description = "Sampling library designed for ease of use, speed and modularity";
    changelog = "https://github.com/blackjax-devs/blackjax/releases/tag/${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}