{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-time tooling
  hatchling,

  # runtime Python dependencies
  equinox,
  jax,
  jaxtyping,
  lineax,
  typing-extensions,

  # check-phase inputs
  beartype,
  jaxlib,
  optax,
  pytestCheckHook,
  pytest-xdist,
}:

buildPythonPackage rec {
  pname = "optimistix";
  version = "0.0.10";
  pyproject = true;

  # Upstream sources, pinned to the release tag `v<version>`.
  src = fetchFromGitHub {
    owner = "patrick-kidger";
    repo = "optimistix";
    tag = "v${version}";
    hash = "sha256-stVPHzv0XNd0I31N2Cj0QYrMmhImyx0cablqZfKBFrM=";
  };

  build-system = [ hatchling ];

  dependencies = [
    equinox
    jax
    jaxtyping
    lineax
    typing-extensions
  ];

  # Packages only needed to run the upstream test suite.
  nativeCheckInputs = [
    beartype
    jaxlib
    optax
    pytestCheckHook
    pytest-xdist
  ];

  # The installed package must at least be importable.
  pythonImportsCheck = [ "optimistix" ];

  pytestFlags = [
    # Needed since jax 0.5.3, which emits e.g.:
    #   DeprecationWarning: shape requires ndarray or scalar arguments, got <class 'jax._src.api.ShapeDtypeStruct'> at position 0. In a future JAX release this will be an error.
    # All DeprecationWarnings are ignored (presumably the suite would otherwise
    # error on them -- TODO confirm against upstream pytest config).
    "-Wignore::DeprecationWarning"
  ];

  disabledTests = [
    # Numerical mismatch: the check
    #   tree_allclose(Array(0.12993518, dtype=float64), Array(0., dtype=float64, weak_type=True), atol=0.0001, rtol=0.0001)
    # evaluates to Array(False, dtype=bool), so the assertion fails.
    "test_least_squares"
  ];

  meta = {
    description = "Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox";
    homepage = "https://github.com/patrick-kidger/optimistix";
    changelog = "https://github.com/patrick-kidger/optimistix/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.GaetanLepage ];
  };
}