at master 2.1 kB view raw
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools-scm,

  # dependencies
  fastprogress,
  jax,
  jaxlib,
  jaxopt,
  optax,
  typing-extensions,

  # checks
  pytestCheckHook,
  pytest-xdist,

  # NOTE(review): stdenv is not referenced anywhere in this expression.
  # It is kept so explicit callers that pass it keep working. If this file is
  # only instantiated via callPackage (which passes just the requested
  # arguments), it can be dropped. TODO confirm.
  stdenv,
}:

# BlackJAX Python package, built from the upstream GitHub tag matching `version`.
buildPythonPackage rec {
  pname = "blackjax";
  version = "1.2.5";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "blackjax-devs";
    repo = "blackjax";
    tag = version;
    hash = "sha256-2GTjKjLIWFaluTjdWdUF9Iim973y81xv715xspghRZI=";
  };

  build-system = [ setuptools-scm ];

  dependencies = [
    fastprogress
    jax
    jaxlib
    jaxopt
    optax
    typing-extensions
  ];

  nativeCheckInputs = [
    pytestCheckHook
    pytest-xdist
  ];

  pytestFlags = [
    # jaxopt (a runtime dependency above) emits
    # "DeprecationWarning: JAXopt is no longer maintained".
    # The filter has an empty message field, so it ignores *every*
    # DeprecationWarning raised during the test run, not only this one.
    "-Wignore::DeprecationWarning"
  ];

  disabledTestPaths = [
    # Benchmark suite. NOTE(review): the reason for disabling it is not
    # recorded here; presumably runtime/sandbox cost. Confirm.
    "tests/test_benchmarks.py"

    # AssertionError on numerical values
    "tests/mcmc/test_integrators.py"
  ];

  disabledTests = [
    # Too slow
    "test_adaptive_tempered_smc"

    # AssertionError on numerical values
    "test_barker"
    "test_mclmc"
    "test_mcse4"
    "test_normal_univariate"
    "test_nuts__with_device"
    "test_nuts__with_jit"
    "test_nuts__without_device"
    "test_nuts__without_jit"
    "test_smc_waste_free__with_jit"

    # AssertionError on numerical values.
    # First report, when the failure only happened on aarch64-linux:
    # https://github.com/blackjax-devs/blackjax/issues/668
    # Second report, when the test also started failing on x86_64-linux after
    # JAX was updated to 0.7.0:
    # https://github.com/blackjax-devs/blackjax/issues/795
    "test_chees_adaptation"
  ];

  pythonImportsCheck = [ "blackjax" ];

  meta = {
    homepage = "https://blackjax-devs.github.io/blackjax";
    description = "Sampling library designed for ease of use, speed and modularity";
    changelog = "https://github.com/blackjax-devs/blackjax/releases/tag/${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}