# jaxtyping: type annotations and runtime checking for JAX arrays and PyTrees.
#
# The derivation is bound to `self` (see the `let` below) so that
# `passthru.tests.check` can re-use it via `overridePythonAttrs` and rebuild
# it with the test suite enabled.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  hatchling,

  # dependencies
  wadler-lindig,

  # tests
  cloudpickle,
  equinox,
  ipython,
  jax,
  jaxlib,
  pytestCheckHook,
  tensorflow,
  torch,
}:

let
  self = buildPythonPackage rec {
    pname = "jaxtyping";
    version = "0.3.2";
    pyproject = true;

    src = fetchFromGitHub {
      owner = "google";
      repo = "jaxtyping";
      tag = "v${version}";
      hash = "sha256-zRuTOt9PqFGDZbSGvkzxIWIi3z+vU0FmAEecPRcGy2w=";
    };

    build-system = [ hatchling ];

    dependencies = [
      wadler-lindig
    ];

    pythonImportsCheck = [ "jaxtyping" ];

    # Only used by `passthru.tests.check`; the main build does not run tests.
    nativeCheckInputs = [
      cloudpickle
      equinox
      ipython
      jax
      jaxlib
      pytestCheckHook
      tensorflow
      torch
    ];

    # The test suite needs equinox, which itself depends on jaxtyping, so
    # running it here would create a dependency cycle. Tests are run from
    # `passthru.tests.check` instead.
    doCheck = false;

    # Enable tests via passthru to avoid cyclic dependency with equinox.
    passthru.tests = {
      check = self.overridePythonAttrs {
        # Actually run the test suite; this is the whole point of this
        # passthru test.
        doCheck = true;
        # equinox (a test input) brings in its own jaxtyping, so the check
        # closure would contain two jaxtyping copies; skip the duplicate
        # package conflict check for this test-only build.
        catchConflicts = false;
      };
    };

    meta = {
      description = "Type annotations and runtime checking for JAX arrays and PyTrees";
      homepage = "https://github.com/patrick-kidger/jaxtyping";
      changelog = "https://github.com/patrick-kidger/jaxtyping/releases/tag/v${version}";
      license = lib.licenses.mit;
      maintainers = with lib.maintainers; [ GaetanLepage ];
    };
  };
in
self