{
  buildPythonPackage,
  cloudpickle,
  dm-haiku,
  einops,
  fetchFromGitHub,
  flax,
  hypothesis,
  jaxlib,
  keras,
  lib,
  poetry-core,
  pytestCheckHook,
  pyyaml,
  rich,
  tensorflow,
  treeo,
  torchmetrics,
  torch,
}:

# treex: a Pytree-based module system for deep learning in JAX, built from the
# upstream GitHub release tag with the poetry-core build backend.
buildPythonPackage rec {
  pname = "treex";
  version = "0.6.11";
  # `format = "pyproject"` is deprecated; `pyproject = true` together with the
  # `build-system` / `dependencies` attributes below is the current interface.
  pyproject = true;

  src = fetchFromGitHub {
    owner = "cgarciae";
    repo = "treex";
    tag = version;
    hash = "sha256-ObOnbtAT4SlrwOms1jtn7/XKZorGISGY6VuhQlC3DaQ=";
  };

  # treex declares version constraints on these dependencies that are stricter
  # than what it actually works with, so the constraints are relaxed here.
  # When this was written (2022-03-29), rich was at 11.0.0 and outside treex's
  # declared range. The packaged treeo also works with treex, but treex's
  # metadata does not allow that version; see
  # https://github.com/cgarciae/treex/issues/68.
  # certifi and flax are relaxed as well.
  # NOTE(review): re-check whether each relaxation is still needed on bumps.
  pythonRelaxDeps = [
    "certifi"
    "flax"
    "rich"
    "treeo"
  ];

  build-system = [ poetry-core ];

  # jaxlib is a (non-propagated) build input rather than a runtime dependency.
  # NOTE(review): presumably so consumers choose their own jaxlib variant
  # (e.g. CPU vs. CUDA), as other JAX-ecosystem packages do; confirm.
  buildInputs = [ jaxlib ];

  dependencies = [
    einops
    flax
    pyyaml
    rich
    treeo
    torch
  ];

  nativeCheckInputs = [
    cloudpickle
    dm-haiku
    hypothesis
    keras
    pytestCheckHook
    tensorflow
    torchmetrics
  ];

  pythonImportsCheck = [ "treex" ];

  meta = {
    description = "Pytree Module system for Deep Learning in JAX";
    homepage = "https://github.com/cgarciae/treex";
    license = lib.licenses.mit;
    maintainers = [ lib.maintainers.ndl ];
  };
}