# Package expression for optax, built from its GitHub sources with the
# flit-core build backend.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  flit-core,

  # dependencies
  absl-py,
  chex,
  jax,
  jaxlib,
  numpy,

  # tests
  callPackage,
}:

buildPythonPackage rec {
  pname = "optax";
  version = "0.2.6";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "deepmind";
    repo = "optax";
    tag = "v${version}";
    hash = "sha256-+9Q/Amb60m65ZiJsmH93e6tQmpJlMyzVUL0A7q3mS8Y=";
  };

  # "testsout" is an extra output that postInstall below fills with the
  # upstream examples directory.
  outputs = [
    "out"
    "testsout"
  ];

  build-system = [ flit-core ];

  dependencies = [
    absl-py
    chex
    jax
    jaxlib
    numpy
  ];

  # Copy the examples from the source tree into the separate "testsout" output.
  # NOTE(review): presumably consumed by ./tests.nix — confirm against that file.
  postInstall = ''
    mkdir $testsout
    cp -R examples $testsout/examples
  '';

  pythonImportsCheck = [ "optax" ];

  # The test suite is not run as part of this build; it lives in
  # passthru.tests.pytest instead, to escape infinite recursion with flax.
  doCheck = false;

  passthru.tests = {
    pytest = callPackage ./tests.nix { };
  };

  meta = {
    description = "Gradient processing and optimization library for JAX";
    homepage = "https://github.com/deepmind/optax";
    changelog = "https://github.com/deepmind/optax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ ndl ];
  };
}