at master 1.8 kB view raw
# Nix expression for the jaxopt Python package (github.com/google/jaxopt).
# Fetches the upstream release tag, builds it with setuptools through the
# pyproject build path, and runs the upstream test suite via pytestCheckHook.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # runtime dependencies
  absl-py,
  jax,
  matplotlib,
  numpy,
  scipy,

  # test-only dependencies
  cvxpy,
  optax,
  pytest-xdist,
  pytestCheckHook,
  scikit-learn,
}:

buildPythonPackage rec {
  pname = "jaxopt";
  version = "0.8.5";
  pyproject = true;

  # Upstream release tags carry a "jaxopt-v" prefix in front of the version.
  src = fetchFromGitHub {
    owner = "google";
    repo = "jaxopt";
    tag = "jaxopt-v${version}";
    hash = "sha256-vPXrs8J81O+27w9P/fEFr7w4xClKb8T0IASD+iNhztQ=";
  };

  build-system = [ setuptools ];

  dependencies = [
    absl-py
    jax
    matplotlib
    numpy
    scipy
  ];

  nativeCheckInputs = [
    cvxpy
    optax
    pytest-xdist
    pytestCheckHook
    scikit-learn
  ];

  # Import smoke test: the top-level package followed by a handful of its
  # submodules, in this exact order.
  pythonImportsCheck =
    [ "jaxopt" ]
    ++ map (submodule: "jaxopt.${submodule}") [
      "implicit_diff"
      "linear_solve"
      "loss"
      "tree_util"
    ];

  # Tests deselected from the pytest run; each entry sits under its reason.
  disabledTests = [
    # https://github.com/google/jaxopt/issues/592
    "test_solve_sparse"

    # https://github.com/google/jaxopt/issues/593
    # Crashes the whole test suite
    "test_dtype_consistency"

    # AssertionError: Not equal to tolerance rtol=1e-06, atol=1e-06
    # https://github.com/google/jaxopt/issues/618
    "test_binary_logit_log_likelihood"

    # Flaky numerical tests (intermittent AssertionError)
    "test_Rosenbrock2"
    "test_Rosenbrock5"
    "test_gradient1"
    "test_inv_hessian_product_pytree3"
    "test_logreg_with_intercept_manual_loop3"
    "test_multiclass_logreg6"
  ];

  meta = {
    description = "Hardware accelerated, batchable and differentiable optimizers in JAX";
    homepage = "https://jaxopt.github.io";
    # src.tag is the same "jaxopt-v<version>" tag that was fetched above.
    changelog = "https://github.com/google/jaxopt/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.bcdarwin ];
  };
}