{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  flit-core,

  # dependencies
  absl-py,
  aiofiles,
  etils,
  humanize,
  importlib-resources,
  jax,
  msgpack,
  nest-asyncio,
  numpy,
  protobuf,
  pyyaml,
  simplejson,
  tensorstore,
  typing-extensions,

  # tests
  chex,
  google-cloud-logging,
  mock,
  optax,
  portpicker,
  pytest-xdist,
  pytestCheckHook,
  safetensors,
}:

let
  version = "0.11.25";

  # The whole google/orbax repository is fetched at the release tag; only its
  # `checkpoint/` subdirectory is built (see `sourceRoot`).
  src = fetchFromGitHub {
    owner = "google";
    repo = "orbax";
    tag = "v${version}";
    hash = "sha256-myhPWKP2uI9NQKZki1Rr+B6Kusn0qNWREKHkiDrSheA=";
  };
in
buildPythonPackage {
  pname = "orbax-checkpoint";
  inherit version src;
  pyproject = true;

  # Build from the checkpoint project inside the unpacked repository.
  sourceRoot = "${src.name}/checkpoint";

  build-system = [ flit-core ];

  # Drop upstream's version constraint on jax so the jax from this package set
  # is accepted.
  pythonRelaxDeps = [
    "jax"
  ];

  dependencies = [
    absl-py
    aiofiles
    etils
    humanize
    importlib-resources
    jax
    msgpack
    nest-asyncio
    numpy
    protobuf
    pyyaml
    simplejson
    tensorstore
    typing-extensions
  ];

  nativeCheckInputs = [
    chex
    google-cloud-logging
    mock
    optax
    portpicker
    pytest-xdist
    pytestCheckHook
    safetensors
  ];

  # Smoke test: both the top-level package and the checkpoint subpackage import.
  pythonImportsCheck = [
    "orbax"
    "orbax.checkpoint"
  ];

  disabledTests = [
    # Flaky timing assertion:
    # AssertionError: 2 not greater than 2.0046136379241943
    "test_async_mkdir_parallel"
    "test_async_mkdir_sequential"
  ]
  # Darwin only; presumably a filesystem impurity:
  #   self.assertFalse(os.path.exists(dst_dir))
  #   AssertionError: True is not false
  ++ lib.optional stdenv.hostPlatform.isDarwin "test_create_snapshot";

  disabledTestPaths = [
    # absl.flags._exceptions.DuplicateFlagError: The flag 'num_processes' is defined twice.
    # First from multiprocess_test, Second from orbax.checkpoint._src.testing.multiprocess_test.
    # Description from first occurrence: Number of processes to use.
    # https://github.com/google/orbax/issues/1580
    "orbax/checkpoint/experimental/emergency/"

    # FileNotFoundError: [Errno 2] No such file or directory:
    # '/build/absl_testing/DefaultSnapshotTest/runTest/root/path/to/source/data.txt'
    "orbax/checkpoint/_src/path/snapshot/snapshot_test.py"

    # These tests need flax, which itself depends on orbax-checkpoint
    # (circular dependency).
    "orbax/checkpoint/_src/metadata/empty_values_test.py"
    "orbax/checkpoint/_src/metadata/tree_rich_types_test.py"
    "orbax/checkpoint/_src/metadata/tree_test.py"
    "orbax/checkpoint/_src/testing/test_tree_utils.py"
    "orbax/checkpoint/_src/tree/parts_of_test.py"
    "orbax/checkpoint/_src/tree/structure_utils_test.py"
    "orbax/checkpoint/_src/tree/utils_test.py"
    "orbax/checkpoint/single_host_test.py"
    "orbax/checkpoint/transform_utils_test.py"
  ];

  meta = {
    description = "Orbax provides common utility libraries for JAX users";
    homepage = "https://github.com/google/orbax/tree/main/checkpoint";
    changelog = "https://github.com/google/orbax/blob/v${version}/checkpoint/CHANGELOG.md";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ fab ];
  };
}