{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build backend (goes into build-system)
  setuptools,

  # runtime dependencies
  absl-py,
  jax,
  matplotlib,
  numpy,
  scipy,

  # used only in nativeCheckInputs
  cvxpy,
  optax,
  pytest-xdist,
  pytestCheckHook,
  scikit-learn,
}:

# Package expression for jaxopt, built from the google/jaxopt GitHub sources.
let
  # Single source of truth for the release; interpolated into the git tag
  # and the changelog URL below.
  version = "0.8.5";
in
buildPythonPackage {
  pname = "jaxopt";
  inherit version;
  pyproject = true;

  # Upstream release tags carry a "jaxopt-v" prefix in front of the version.
  src = fetchFromGitHub {
    owner = "google";
    repo = "jaxopt";
    tag = "jaxopt-v${version}";
    hash = "sha256-vPXrs8J81O+27w9P/fEFr7w4xClKb8T0IASD+iNhztQ=";
  };

  build-system = [ setuptools ];

  dependencies = [
    absl-py
    jax
    matplotlib
    numpy
    scipy
  ];

  # Top-level package plus a handful of submodules that must be importable.
  pythonImportsCheck = [
    "jaxopt"
    "jaxopt.implicit_diff"
    "jaxopt.linear_solve"
    "jaxopt.loss"
    "jaxopt.tree_util"
  ];

  nativeCheckInputs = [
    cvxpy
    optax
    pytest-xdist
    pytestCheckHook
    scikit-learn
  ];

  disabledTests = [
    # Tracked upstream: https://github.com/google/jaxopt/issues/592
    "test_solve_sparse"

    # Crashes the whole test suite.
    # Tracked upstream: https://github.com/google/jaxopt/issues/593
    "test_dtype_consistency"

    # Fails with: AssertionError: Not equal to tolerance rtol=1e-06, atol=1e-06
    # Tracked upstream: https://github.com/google/jaxopt/issues/618
    "test_binary_logit_log_likelihood"

    # Flaky numerical tests that fail with AssertionError.
    "test_Rosenbrock2"
    "test_Rosenbrock5"
    "test_gradient1"
    "test_inv_hessian_product_pytree3"
    "test_logreg_with_intercept_manual_loop3"
    "test_multiclass_logreg6"
  ];

  meta = {
    description = "Hardware accelerated, batchable and differentiable optimizers in JAX";
    homepage = "https://jaxopt.github.io";
    changelog = "https://github.com/google/jaxopt/releases/tag/jaxopt-v${version}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.bcdarwin ];
  };
}