at master 1.2 kB view raw
# Nix package expression for the `optax` Python library (version 0.2.6),
# built from the upstream GitHub tag with flit-core.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  flit-core,

  # dependencies
  absl-py,
  chex,
  jax,
  jaxlib,
  numpy,

  # tests
  callPackage,
}:

buildPythonPackage rec {
  pname = "optax";
  version = "0.2.6";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "deepmind";
    repo = "optax";
    tag = "v${version}";
    hash = "sha256-+9Q/Amb60m65ZiJsmH93e6tQmpJlMyzVUL0A7q3mS8Y=";
  };

  # Besides the regular "out", a second "testsout" output holds a copy of the
  # upstream `examples` directory (populated in postInstall below).
  # NOTE(review): presumably consumed by ./tests.nix — confirm there.
  outputs = [
    "out"
    "testsout"
  ];

  build-system = [ flit-core ];

  dependencies = [
    absl-py
    chex
    jax
    jaxlib
    numpy
  ];

  # Ship the source tree's examples in the "testsout" output; quoting keeps
  # the commands correct regardless of the output path's contents.
  postInstall = ''
    mkdir "$testsout"
    cp -R examples "$testsout/examples"
  '';

  pythonImportsCheck = [ "optax" ];

  # The test suite runs in passthru.tests.pytest instead of here, to escape
  # infinite recursion with flax.
  doCheck = false;

  passthru.tests = {
    pytest = callPackage ./tests.nix { };
  };

  meta = {
    description = "Gradient processing and optimization library for JAX";
    homepage = "https://github.com/deepmind/optax";
    # Derived from the fetched tag so the link cannot drift from `src`.
    changelog = "https://github.com/deepmind/optax/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}