# objax package expression (raw view of the `master` branch, 1.8 kB)
# Nix package definition for `objax`, an object-oriented layer on top of JAX.
# Built from the upstream GitHub release tag with setuptools (pyproject build).
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  jax,
  jaxlib,
  keras,
  numpy,
  parameterized,
  pillow,
  pytestCheckHook,
  pythonOlder,
  scipy,
  setuptools,
  tensorboard,
  tensorflow,
}:

buildPythonPackage rec {
  pname = "objax";
  version = "1.8.0";
  pyproject = true;

  disabled = pythonOlder "3.9";

  src = fetchFromGitHub {
    owner = "google";
    repo = "objax";
    tag = "v${version}";
    hash = "sha256-WD+pmR8cEay4iziRXqF3sHUzCMBjmLJ3wZ3iYOD+hzk=";
  };

  patches = [
    # Replace deprecated `device_buffers` usage (as the patch name says).
    # Issue reported upstream: https://github.com/google/objax/issues/270
    ./replace-deprecated-device_buffers.patch
  ];

  build-system = [ setuptools ];

  # Avoid propagating the dependency on `jaxlib`, see
  # https://github.com/NixOS/nixpkgs/issues/156767
  buildInputs = [ jaxlib ];

  dependencies = [
    jax
    numpy
    parameterized
    pillow
    scipy
    tensorboard
  ];

  pythonImportsCheck = [ "objax" ];

  # Needed to tolerate two protobuf versions in the closure
  # (tensorflow brings in an older one).
  catchConflicts = false;

  nativeCheckInputs = [
    keras
    pytestCheckHook
    tensorflow
  ];

  enabledTestPaths = [ "tests/*.py" ];

  disabledTests = [
    # Test requires internet access for prefetching some weights
    "test_pretrained_keras_weight_0_ResNet50V2"
    # ModuleNotFoundError: No module named 'tree'
    "TestResNetV2Pretrained"
  ];

  meta = {
    description = "Machine learning framework that provides an Object Oriented layer for JAX";
    homepage = "https://github.com/google/objax";
    # Reuse the fetched tag so the changelog URL cannot drift from `src`.
    changelog = "https://github.com/google/objax/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.ndl ];
    # Tests test_syncbn_{0,1,2}d and other tests from tests/parallel.py fail
    broken = true;
  };
}