# Function arguments, grouped by the attribute of the package definition that uses them.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,
  setuptools-scm,

  # dependencies
  flaxlib,
  jax,
  msgpack,
  numpy,
  optax,
  orbax-checkpoint,
  pyyaml,
  rich,
  tensorstore,
  treescope,
  typing-extensions,

  # optional-dependencies
  matplotlib,

  # tests
  cloudpickle,
  einops,
  keras,
  pytestCheckHook,
  pytest-xdist,
  sphinx,
  tensorflow,

  # passthru.updateScript
  writeScript,
  tomlq,
}:
38
# flax: neural network library for JAX, built from the upstream GitHub release tag.
buildPythonPackage rec {
  pname = "flax";
  version = "0.12.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google";
    repo = "flax";
    tag = "v${version}";
    hash = "sha256-ioMj8+TuOFX3t9p3oVaywaOQPFBgvNcy7b/2WX/yvXA=";
  };

  build-system = [
    setuptools
    setuptools-scm
  ];

  dependencies = [
    flaxlib
    jax
    msgpack
    numpy
    optax
    orbax-checkpoint
    pyyaml
    rich
    tensorstore
    treescope
    typing-extensions
  ];

  optional-dependencies = {
    all = [ matplotlib ];
  };

  pythonImportsCheck = [ "flax" ];

  nativeCheckInputs = [
    cloudpickle
    einops
    keras
    pytestCheckHook
    pytest-xdist
    sphinx
    tensorflow
  ];

  pytestFlags = [
    # DeprecationWarning: Triggering of __jax_array__() during abstractification is deprecated.
    # To avoid this error, either explicitly convert your object using jax.numpy.array(), or register your object as a pytree.
    "-Wignore::DeprecationWarning"
  ];

  disabledTestPaths = [
    # Docs test, needs extra deps + we're not interested in it.
    "docs/_ext/codediff_test.py"

    # The tests in `examples` are not designed to be executed from a single test
    # session and thus either have the modules that conflict with each other or
    # wrong import paths, depending on how they're invoked. Many tests also have
    # dependencies that are not packaged in `nixpkgs` (`clu`, `jgraph`,
    # `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
    # would be limited anyway.
    "examples/*"
  ];

  disabledTests = [
    # AssertionError: [Chex] Function 'add' is traced > 1 times!
    "PadShardUnpadTest"

    # AssertionError: nnx_model.kernel.value.sharding = NamedSharding(...
    "test_linen_to_nnx_metadata"
  ];

  passthru = {
    # Bumps flax, then bumps flaxlib to the version declared in the Cargo
    # manifest of the freshly fetched flax source (passing --commit to that bump).
    # The shebang is required: writeScript only marks the file executable, and a
    # script without one cannot be exec'd directly.
    # NOTE(review): confirm the manifest path; flaxlib's Cargo.toml may live in a
    # subdirectory of the flax source tree rather than at its root.
    updateScript = writeScript "update.sh" ''
      #!/usr/bin/env bash
      set -euo pipefail

      nix-update flax # does not --build by default
      # src is essentially a passthru; build it without leaving a `result` link behind.
      src="$(nix-build . -A flax.src --no-out-link)"
      nix-update flaxlib --version="$(${lib.getExe tomlq} <"$src/Cargo.toml" .package.version)" --commit
    '';
  };

  meta = {
    description = "Neural network library for JAX";
    homepage = "https://github.com/google/flax";
    # Derived from the fetched tag so it cannot drift from `src`.
    changelog = "https://github.com/google/flax/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}