# Package expression for Keras, built from the upstream GitHub sources with
# `buildPythonPackage` (pyproject build, setuptools backend).
#
# The arguments are grouped by role (build-system, runtime dependencies, test
# inputs) and are expected to be supplied by the surrounding Python package set.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  absl-py,
  distutils,
  h5py,
  ml-dtypes,
  namex,
  numpy,
  tf2onnx,
  onnxruntime,
  optree,
  packaging,
  pythonAtLeast,
  rich,
  scikit-learn,
  tensorflow,

  # tests
  dm-tree,
  jax,
  pandas,
  pydot,
  pytestCheckHook,
  tf-keras,
  torch,
  writableTmpDirAsHomeHook,
}:

buildPythonPackage rec {
  pname = "keras";
  version = "3.11.3";
  pyproject = true;

  # The upstream release tag is derived from `version`, so bumping the version
  # also requires refreshing `hash`.
  src = fetchFromGitHub {
    owner = "keras-team";
    repo = "keras";
    tag = "v${version}";
    hash = "sha256-J/NPLR9ShKhvHDU0/NpUNp95RViS2KygqvnuDHdwiP0=";
  };

  build-system = [
    setuptools
  ];

  dependencies = [
    absl-py
    h5py
    ml-dtypes
    namex
    numpy
    tf2onnx
    onnxruntime
    optree
    packaging
    rich
    scikit-learn
    tensorflow
  ]
  # distutils was removed from the Python standard library in 3.12, so it is
  # only added as an explicit dependency from that version on.
  ++ lib.optionals (pythonAtLeast "3.12") [ distutils ];

  # Modules that must be importable from the built package.
  pythonImportsCheck = [
    "keras"
    "keras._tf_keras"
  ];

  nativeCheckInputs = [
    dm-tree
    jax
    pandas
    pydot
    pytestCheckHook
    tf-keras
    torch
    writableTmpDirAsHomeHook
  ];

  # Individual tests skipped by name; each group states why it is disabled.
  disabledTests = [
    # Require unpackaged `grain`
    "test_fit_with_data_adapter_grain_dataloader"
    "test_fit_with_data_adapter_grain_datast"
    "test_fit_with_data_adapter_grain_datast_with_len"

    # Tries to install the package in the sandbox
    "test_keras_imports"

    # TypeError: this __dict__ descriptor does not support '_DictWrapper' objects
    "test_reloading_default_saved_model"

    # E   AssertionError:
    # E   - float32
    # E   + float64
    "test_angle_bool"
    "test_angle_int16"
    "test_angle_int32"
    "test_angle_int8"
    "test_angle_uint16"
    "test_angle_uint32"
    "test_angle_uint8"
    "test_bartlett_bfloat16"
    "test_bartlett_bool"
    "test_bartlett_float16"
    "test_bartlett_float32"
    "test_bartlett_float64"
    "test_bartlett_int16"
    "test_bartlett_int32"
    "test_bartlett_int64"
    "test_bartlett_int8"
    "test_bartlett_none"
    "test_bartlett_uint16"
    "test_bartlett_uint32"
    "test_bartlett_uint8"
    "test_blackman_bfloat16"
    "test_blackman_bool"
    "test_blackman_float16"
    "test_blackman_float32"
    "test_blackman_float64"
    "test_blackman_int16"
    "test_blackman_int32"
    "test_blackman_int64"
    "test_blackman_int8"
    "test_blackman_none"
    "test_blackman_uint16"
    "test_blackman_uint32"
    "test_blackman_uint8"
    "test_eye_none"
    "test_hamming_bfloat16"
    "test_hamming_bool"
    "test_hamming_float16"
    "test_hamming_float32"
    "test_hamming_float64"
    "test_hamming_int16"
    "test_hamming_int32"
    "test_hamming_int64"
    "test_hamming_int8"
    "test_hamming_none"
    "test_hamming_uint16"
    "test_hamming_uint32"
    "test_hamming_uint8"
    "test_hanning_bfloat16"
    "test_hanning_bool"
    "test_hanning_float16"
    "test_hanning_float32"
    "test_hanning_float64"
    "test_hanning_int16"
    "test_hanning_int32"
    "test_hanning_int64"
    "test_hanning_int8"
    "test_hanning_none"
    "test_hanning_uint16"
    "test_hanning_uint32"
    "test_hanning_uint8"
    "test_identity_none"
    "test_kaiser_bfloat16"
    "test_kaiser_bool"
    "test_kaiser_float16"
    "test_kaiser_float32"
    "test_kaiser_float64"
    "test_kaiser_int16"
    "test_kaiser_int32"
    "test_kaiser_int64"
    "test_kaiser_int8"
    "test_kaiser_none"
    "test_kaiser_uint16"
    "test_kaiser_uint32"
    "test_kaiser_uint8"
  ]
  # Only skipped when the host platform is aarch64 Linux.
  ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
    # Hangs forever
    "test_fit_with_data_adapter"
  ];

  # Whole test files/directories that are skipped.
  disabledTestPaths = [
    # Require unpackaged `grain`
    "keras/src/trainers/data_adapters/grain_dataset_adapter_test.py"

    # These tests succeed when run individually, but crash within the full test suite:
    # ImportError: /nix/store/4bw0x7j3wfbh6i8x3plmzknrdwdzwfla-abseil-cpp-20240722.1/lib/libabsl_cord_internal.so.2407.0.0:
    # undefined symbol: _ZN4absl12lts_2024072216strings_internal13StringifySink6AppendESt17basic_string_viewIcSt11char_traitsIcEE
    "keras/src/export/onnx_test.py"

    # Require internet access
    "integration_tests/dataset_tests"
    "keras/src/applications/applications_test.py"

    # TypeError: test_custom_fit.<locals>.CustomModel.train_step() missing 1 required positional argument: 'data'
    "integration_tests/jax_custom_fit_test.py"

    # RuntimeError: Virtual devices cannot be modified after being initialized
    "integration_tests/tf_distribute_training_test.py"

    # AttributeError: 'CustomModel' object has no attribute 'zero_grad'
    "integration_tests/torch_custom_fit_test.py"

    # Fails for an unclear reason:
    # self.assertLen(list(net.parameters()), 2)
    # AssertionError: 0 != 2
    "integration_tests/torch_workflow_test.py"

    # TypeError: this __dict__ descriptor does not support '_DictWrapper' objects
    "keras/src/backend/tensorflow/saved_model_test.py"
  ];

  meta = {
    description = "Multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch";
    homepage = "https://keras.io";
    changelog = "https://github.com/keras-team/keras/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}