at master 2.8 kB view raw
# Builds the `flax` Python package ("Neural network library for JAX") from the
# google/flax GitHub tag `v${version}`.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,
  setuptools-scm,

  # dependencies
  flaxlib,
  jax,
  msgpack,
  numpy,
  optax,
  orbax-checkpoint,
  pyyaml,
  rich,
  tensorstore,
  treescope,
  typing-extensions,

  # optional-dependencies
  matplotlib,

  # tests
  cloudpickle,
  einops,
  keras,
  pytestCheckHook,
  pytest-xdist,
  sphinx,
  tensorflow,

  # passthru.updateScript
  writeScript,
  tomlq,
}:

buildPythonPackage rec {
  pname = "flax";
  version = "0.12.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    tag = "v${version}";
    hash = "sha256-ioMj8+TuOFX3t9p3oVaywaOQPFBgvNcy7b/2WX/yvXA=";
  };

  build-system = [
    setuptools
    setuptools-scm
  ];

  dependencies = [
    flaxlib
    jax
    msgpack
    numpy
    optax
    orbax-checkpoint
    pyyaml
    rich
    tensorstore
    treescope
    typing-extensions
  ];

  optional-dependencies = {
    all = [ matplotlib ];
  };

  pythonImportsCheck = [ "flax" ];

  nativeCheckInputs = [
    cloudpickle
    einops
    keras
    pytestCheckHook
    pytest-xdist
    sphinx
    tensorflow
  ];

  pytestFlags = [
    # DeprecationWarning: Triggering of __jax_array__() during abstractification is deprecated.
    # To avoid this error, either explicitly convert your object using jax.numpy.array(), or register your object as a pytree.
    "-Wignore::DeprecationWarning"
  ];

  disabledTestPaths = [
    # Docs test, needs extra deps + we're not interested in it.
    "docs/_ext/codediff_test.py"

    # The tests in `examples` are not designed to be executed from a single test
    # session and thus either have the modules that conflict with each other or
    # wrong import paths, depending on how they're invoked. Many tests also have
    # dependencies that are not packaged in `nixpkgs` (`clu`, `jgraph`,
    # `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
    # would be limited anyway.
    "examples/*"
  ];

  disabledTests = [
    # AssertionError: [Chex] Function 'add' is traced > 1 times!
    "PadShardUnpadTest"

    # AssertionError: nnx_model.kernel.value.sharding = NamedSharding(...
    "test_linen_to_nnx_metadata"
  ];

  passthru = {
    # Bumps flax, builds its source tree into ./result, then bumps flaxlib to
    # the version read from that tree's Cargo.toml and commits the change.
    # The shebang is required because the script is exec'd directly, and
    # `set -e` stops the flaxlib bump if the flax bump or source build fails.
    # NOTE(review): assumes the crate version is `[package] version` in the
    # top-level Cargo.toml of flax.src — confirm against upstream (a
    # `[workspace.package]` layout would need `.workspace.package.version`).
    updateScript = writeScript "update.sh" ''
      #!/usr/bin/env bash
      set -euo pipefail
      nix-update flax # does not --build by default
      nix-build . -A flax.src # src is essentially a passthru
      nix-update flaxlib --version="$(${lib.getExe tomlq} <result/Cargo.toml .package.version)" --commit
    '';
  };

  meta = {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
    changelog = "https://github.com/google/flax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}