{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  hatchling,

  # dependencies
  equinox,
  jax,
  jaxtyping,
  lineax,
  typing-extensions,

  # tests
  beartype,
  jaxlib,
  optax,
  pytestCheckHook,
  pytest-xdist,
}:

# optimistix, built from the upstream GitHub release tag with hatchling
# as the PEP 517 build backend.
buildPythonPackage rec {
  pname = "optimistix";
  version = "0.0.10";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "patrick-kidger";
    repo = "optimistix";
    # Upstream tags releases as "v<version>".
    tag = "v${version}";
    hash = "sha256-stVPHzv0XNd0I31N2Cj0QYrMmhImyx0cablqZfKBFrM=";
  };

  build-system = [ hatchling ];

  dependencies = [
    equinox
    jax
    jaxtyping
    lineax
    typing-extensions
  ];

  # --- checks -------------------------------------------------------------

  # Smoke test: the top-level module must import after installation.
  pythonImportsCheck = [ "optimistix" ];

  # jaxlib is listed explicitly here, alongside the test-only libraries
  # and the pytest hooks (pytest-xdist allows parallel test execution).
  nativeCheckInputs = [
    beartype
    jaxlib
    optax
    pytestCheckHook
    pytest-xdist
  ];

  # Since jax 0.5.3 the suite emits:
  #   DeprecationWarning: shape requires ndarray or scalar arguments, got
  #   <class 'jax._src.api.ShapeDtypeStruct'> at position 0. In a future JAX
  #   release this will be an error.
  # All DeprecationWarnings are therefore ignored during the test run.
  pytestFlags = [ "-Wignore::DeprecationWarning" ];

  disabledTests = [
    # Numerical mismatch:
    #   assert Array(False, dtype=bool)
    #   + where Array(False, dtype=bool) = tree_allclose(Array(0.12993518, dtype=float64), Array(0., dtype=float64, weak_type=True), atol=0.0001, rtol=0.0001)
    "test_least_squares"
  ];

  meta = {
    description = "Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox";
    homepage = "https://github.com/patrick-kidger/optimistix";
    # src.tag is "v${version}", so this points at the matching release notes.
    changelog = "https://github.com/patrick-kidger/optimistix/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.GaetanLepage ];
  };
}