{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  jax,
  jaxlib,
  keras,
  numpy,
  parameterized,
  pillow,
  pytestCheckHook,
  pythonOlder,
  scipy,
  setuptools,
  tensorboard,
  tensorflow,
}:

buildPythonPackage rec {
  pname = "objax";
  version = "1.8.0";
  pyproject = true;

  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "google";
    repo = "objax";
    tag = "v${version}";
    hash = "sha256-WD+pmR8cEay4iziRXqF3sHUzCMBjmLJ3wZ3iYOD+hzk=";
  };

  patches = [
    # Judging by its filename, this patch replaces deprecated `device_buffers` usage.
    # Issue reported upstream: https://github.com/google/objax/issues/270
    ./replace-deprecated-device_buffers.patch
  ];

  build-system = [ setuptools ];

  # Avoid propagating the dependency on `jaxlib`, see
  # https://github.com/NixOS/nixpkgs/issues/156767
  buildInputs = [ jaxlib ];

  dependencies = [
    jax
    numpy
    parameterized
    pillow
    scipy
    tensorboard
  ];

  pythonImportsCheck = [ "objax" ];

  # Needed to ignore the presence of two protobuf versions in the closure
  # (tensorflow brings in an older one).
  catchConflicts = false;

  nativeCheckInputs = [
    keras
    pytestCheckHook
    tensorflow
  ];

  enabledTestPaths = [ "tests/*.py" ];

  disabledTests = [
    # Test requires internet access to prefetch pretrained weights
    "test_pretrained_keras_weight_0_ResNet50V2"
    # ModuleNotFoundError: No module named 'tree'
    "TestResNetV2Pretrained"
  ];

  meta = {
    description = "Machine learning framework that provides an Object Oriented layer for JAX";
    homepage = "https://github.com/google/objax";
    # Reuse the fetched tag so the changelog link cannot drift from `src`.
    changelog = "https://github.com/google/objax/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
    # Tests test_syncbn_{0,1,2}d and other tests from tests/parallel.py fail
    broken = true;
  };
}