# Package definition for lineax, built from the upstream GitHub tag with the
# hatchling build backend. The arguments below are supplied by the caller
# (the Python package set), grouped by the role they play in this derivation.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  hatchling,

  # dependencies
  equinox,
  jax,
  jaxtyping,
  typing-extensions,

  # tests
  beartype,
  pytest,
  python,
}:

buildPythonPackage rec {
  pname = "lineax";
  version = "0.0.8";
  # Build from pyproject.toml using the backend listed in `build-system`.
  pyproject = true;

  src = fetchFromGitHub {
    owner = "patrick-kidger";
    repo = "lineax";
    # The fetched tag is derived from `version`; `hash` pins the fetched
    # source and must be updated whenever `version` changes.
    tag = "v${version}";
    hash = "sha256-VMTDCExgxfCcd/3UZAglfAxAFaSjzFJJuvSWJAx2tJs=";
  };

  build-system = [ hatchling ];

  dependencies = [
    equinox
    jax
    jaxtyping
    typing-extensions
  ];

  # Smoke test: the top-level module must be importable after installation.
  pythonImportsCheck = [ "lineax" ];

  # NOTE(review): pytest is not invoked directly in checkPhase below;
  # presumably the upstream `tests` entry point drives it internally —
  # confirm against the tests/README.md linked below.
  nativeCheckInputs = [
    beartype
    pytest
  ];

  # Intentionally not invoking pytest directly, as doing so causes JAX to run
  # out of memory. The test suite is run as a module instead; see
  # https://github.com/patrick-kidger/lineax/blob/1909d190c1963d5f2d991508c1b2714f2266048b/tests/README.md
  checkPhase = ''
    runHook preCheck

    ${python.interpreter} -m tests

    runHook postCheck
  '';

  meta = {
    description = "Linear solvers in JAX and Equinox";
    homepage = "https://github.com/patrick-kidger/lineax";
    # Uses `src.tag` so the changelog link always follows the fetched tag.
    changelog = "https://github.com/patrick-kidger/lineax/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}