# nixpkgs derivation for rlax (JAX reinforcement-learning building blocks),
# built from the upstream GitHub release tag with setuptools.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  absl-py,
  chex,
  distrax,
  dm-env,
  jax,
  jaxlib,
  numpy,
  tensorflow-probability,

  # tests
  dm-haiku,
  optax,
  pytest-xdist,
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "rlax";
  version = "0.1.8";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google-deepmind";
    repo = "rlax";
    tag = "v${version}";
    hash = "sha256-E/zYFd5bfx58FfA3uR7hzRAIs844QzJA8TZTwmwDByk=";
  };

  build-system = [
    setuptools
  ];

  # Runtime dependencies. Every argument in the "# dependencies" group above
  # belongs here; tensorflow-probability was previously declared but not
  # propagated.
  dependencies = [
    absl-py
    chex
    distrax
    dm-env
    jax
    jaxlib
    numpy
    tensorflow-probability
  ];

  # Test-only inputs; pytest-xdist lets pytestCheckHook run tests in parallel.
  nativeCheckInputs = [
    dm-haiku
    optax
    pytest-xdist
    pytestCheckHook
  ];

  pythonImportsCheck = [ "rlax" ];

  disabledTests = [
    # AssertionError: Array(2, dtype=int32) != 0
    "test_categorical_sample__with_device"
    "test_categorical_sample__with_jit"
    "test_categorical_sample__without_device"
    "test_categorical_sample__without_jit"

    # RuntimeError: Attempted to set 4 devices, but 1 CPUs already available:
    # ensure that `set_n_cpu_devices` is executed before any JAX operation.
    "test_cross_replica_scatter_add0"
    "test_cross_replica_scatter_add1"
    "test_cross_replica_scatter_add2"
    "test_cross_replica_scatter_add3"
    "test_cross_replica_scatter_add4"
    "test_learn_scale_shift"
    "test_normalize_unnormalize_is_identity"
    "test_outputs_preserved"
    "test_scale_bounded"
    "test_slow_update"
    "test_unnormalize_linear"
  ];

  meta = {
    description = "Library of reinforcement learning building blocks in JAX";
    # Same organisation as `src` and `changelog`.
    homepage = "https://github.com/google-deepmind/rlax";
    changelog = "https://github.com/google-deepmind/rlax/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ onny ];
  };
}