# nixpkgs derivation for brax, built from the upstream GitHub tag with hatchling.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  stdenv,

  # build-system
  hatchling,

  # dependencies
  absl-py,
  etils,
  flask,
  flask-cors,
  flax,
  jax,
  jaxlib,
  jaxopt,
  jinja2,
  ml-collections,
  mujoco,
  mujoco-mjx,
  numpy,
  optax,
  orbax-checkpoint,
  pillow,
  scipy,
  tensorboardx,
  typing-extensions,

  # tests
  dm-env,
  gym,
  pytestCheckHook,
  pytest-xdist,
  transforms3d,
}:

buildPythonPackage rec {
  pname = "brax";
  version = "0.13.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google";
    repo = "brax";
    tag = "v${version}";
    hash = "sha256-mSFbFzSrfAvAE6y7atUeucUkpp/20KP70j5xPm/xvB0=";
  };

  build-system = [
    hatchling
  ];

  dependencies = [
    absl-py
    etils
    flask
    flask-cors
    flax
    jax
    jaxlib
    jaxopt
    jinja2
    ml-collections
    mujoco
    mujoco-mjx
    numpy
    optax
    orbax-checkpoint
    pillow
    scipy
    tensorboardx
    typing-extensions
  ];

  # Only needed to run the upstream test suite (pytest, parallelised via xdist).
  nativeCheckInputs = [
    dm-env
    gym
    pytestCheckHook
    pytest-xdist
    transforms3d
  ];

  disabledTests = [
    # AttributeError: 'functools.partial' object has no attribute 'value'
    "testModelEncoding0"
    "testModelEncoding1"
    "testTrain"
    "testTrainDomainRandomize"
  ]
  # Extra exclusions that apply only when the host platform is aarch64.
  ++ lib.optionals stdenv.hostPlatform.isAarch64 [
    # Flaky:
    # AssertionError: Array(-0.00135638, dtype=float32) != 0.0 within 0.001 delta (Array(0.00135638, dtype=float32) difference)
    "test_pendulum_period2"
  ];

  disabledTestPaths = [
    # ValueError: matmul: Input operand 1 has a mismatch in its core dimension
    "brax/generalized/constraint_test.py"
  ];

  pythonImportsCheck = [
    "brax"
  ];

  meta = {
    description = "Massively parallel rigidbody physics simulation on accelerator hardware";
    homepage = "https://github.com/google/brax";
    # Points at the same tag that `src` fetches.
    changelog = "https://github.com/google/brax/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ nim65s ];
  };
}