at master 5.5 kB view raw
# JAX Python frontend. The compiled backend is provided by `jaxlib`; CUDA
# support is added through the `jax-cuda12-plugin` optional dependency.
{
  lib,
  config,
  stdenv,
  blas,
  lapack,
  buildPythonPackage,
  fetchFromGitHub,
  cudaSupport ? config.cudaSupport,

  # build-system
  setuptools,

  # dependencies
  jaxlib,
  ml-dtypes,
  numpy,
  opt-einsum,
  scipy,

  # optional-dependencies
  jax-cuda12-plugin,

  # tests
  cloudpickle,
  hypothesis,
  matplotlib,
  pytestCheckHook,
  pytest-xdist,

  # passthru
  callPackage,
  jax,
  # Only referenced by the commented-out `test_cuda_jaxlibSource` passthru test
  # below; kept as an argument so existing `.override { jaxlib-build = ...; }`
  # calls keep working.
  jaxlib-build,
  jaxlib-bin,
}:

let
  # True when either BLAS or LAPACK is backed by MKL; used below to disable
  # tests known to give non-reproducible results in that configuration.
  usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
in
buildPythonPackage rec {
  pname = "jax";
  version = "0.7.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google";
    repo = "jax";
    # google/jax contains tags for jax and jaxlib. Only use jax tags!
    tag = "jax-v${version}";
    hash = "sha256-GBpHFjvF7SvxJafu7aVlTp0jxSo4jAi9oPeMg2B/P24=";
  };

  build-system = [ setuptools ];

  env = {
    # The version is automatically set to ".dev" if this variable is not set.
    # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
    # Declared under `env` so it is still exported to the builder when
    # `__structuredAttrs` is enabled (plain derivation attributes are not).
    JAX_RELEASE = "1";
  };

  dependencies = [
    jaxlib
    ml-dtypes
    numpy
    opt-einsum
    scipy
  ]
  ++ lib.optionals cudaSupport optional-dependencies.cuda;

  # All CUDA-flavoured extras resolve to the same plugin package.
  optional-dependencies = rec {
    cuda = [ jax-cuda12-plugin ];
    cuda12 = cuda;
    cuda12_pip = cuda;
    cuda12_local = cuda;
  };

  nativeCheckInputs = [
    cloudpickle
    hypothesis
    matplotlib
    pytestCheckHook
    pytest-xdist
  ];

  # High parallelism will result in the tests getting stuck, so the
  # pytest-xdist hook's own worker setting is disabled and a fixed worker
  # count is passed explicitly in `pytestFlags` instead.
  dontUsePytestXdist = true;

  pytestFlags = [
    "--numprocesses=4"
    "-Wignore::DeprecationWarning"
  ];

  # NOTE: Don't run the tests in the experimental directory as they require flax
  # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
  # Not a big deal, this is how the JAX docs suggest running the test suite
  # anyhow.
  enabledTestPaths = [
    "tests/"
  ];

  disabledTestPaths = lib.optionals stdenv.hostPlatform.isDarwin [
    # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated!
    # reported at: https://github.com/jax-ml/jax/issues/26106
    "tests/pjit_test.py::PJitErrorTest::testAxisResourcesMismatch"
    "tests/shape_poly_test.py::ShapePolyTest"
    "tests/tree_util_test.py::TreeTest"

    # Mostly AssertionError on numerical tests failing since 0.7.0
    # https://github.com/jax-ml/jax/issues/31428
    "tests/export_back_compat_test.py"
    "tests/lax_numpy_test.py"
    "tests/lax_scipy_test.py"
    "tests/lax_test.py"
    "tests/linalg_test.py"
  ];

  # Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with
  # PermissionError: [Errno 13] Permission denied: '/tmp/back_compat_testdata/test_*.py'
  # See https://github.com/google/jax/blob/jaxlib-v0.4.27/jax/_src/internal_test_util/export_back_compat_test_util.py#L240-L241
  # NOTE: this doesn't seem to be an issue on linux
  # NOTE(review): `tests/export_back_compat_test.py` is currently listed in the
  # darwin `disabledTestPaths` above, so this has no effect on that file right
  # now; it is kept so the file can be re-enabled without rediscovering this.
  preCheck = lib.optionalString stdenv.hostPlatform.isDarwin ''
    export TEST_UNDECLARED_OUTPUTS_DIR=$(mktemp -d)
  '';

  disabledTests = [
    # Exceeds tolerance when the machine is busy
    "test_custom_linear_solve_aux"
  ]
  ++ lib.optionals usingMKL [
    # See
    # * https://github.com/google/jax/issues/9705
    # * https://discourse.nixos.org/t/getting-different-results-for-the-same-build-on-two-equally-configured-machines/17921
    # * https://github.com/NixOS/nixpkgs/issues/161960
    "test_custom_linear_solve_cholesky"
    "test_custom_root_with_aux"
    "testEigvalsGrad_shape"
  ]
  ++ lib.optionals stdenv.hostPlatform.isAarch64 [
    # Fails on some hardware due to some numerical error
    # See https://github.com/google/jax/issues/18535
    "testQdwhWithOnRankDeficientInput5"
  ]
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    # SystemError: nanobind::detail::nb_func_error_except(): exception could not be translated!
    # reported at: https://github.com/jax-ml/jax/issues/26106
    "testInAxesPyTreePrefixMismatchError"
    "testInAxesPyTreePrefixMismatchErrorKwargs"
    "testOutAxesPyTreePrefixMismatchError"
    "test_tree_map"
    "test_tree_prefix_error"
    "test_vjp_rule_inconsistent_pytree_structures_error"
    "test_vmap_in_axes_tree_prefix_error"
    "test_vmap_mismatched_axis_sizes_error_message_issue_705"
  ];

  pythonImportsCheck = [ "jax" ];

  passthru = {
    # Test CUDA-enabled jax and jaxlib. Running CUDA-enabled tests is not
    # currently feasible within the nix build environment so we have to maintain
    # this script separately. See https://github.com/NixOS/nixpkgs/pull/256230
    # for a possible remedy to this situation.
    #
    # Run these tests with eg
    #
    # NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.tests.test_cuda_jaxlibBin
    tests = {
      # jaxlib-build is broken as of 2024-12-20
      # test_cuda_jaxlibSource = callPackage ./test-cuda.nix {
      #   jax = jax.override { jaxlib = jaxlib-build; };
      # };
      test_cuda_jaxlibBin = callPackage ./test-cuda.nix {
        jax = jax.override { jaxlib = jaxlib-bin; };
      };
    };

    # updater fails to pick the correct branch
    skipBulkUpdate = true;
  };

  meta = {
    description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code";
    homepage = "https://github.com/jax-ml/jax";
    changelog = "https://github.com/jax-ml/jax/releases/tag/jax-v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [
      GaetanLepage
      samuela
    ];
  };
}