at master 1.9 kB view raw
# Builds the `rlax` Python package from the tagged google-deepmind/rlax
# GitHub release using a pyproject/setuptools build, and runs its pytest
# suite with a list of known-failing tests disabled.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  # Not referenced in this expression; kept so existing callers that pass it
  # (e.g. via `.override`) continue to evaluate.
  fetchpatch,

  # build-system
  setuptools,

  # dependencies
  absl-py,
  chex,
  distrax,
  dm-env,
  jax,
  jaxlib,
  numpy,
  tensorflow-probability,

  # tests
  dm-haiku,
  optax,
  pytest-xdist,
  pytestCheckHook,
}:

# `rec` lets `src` reference `version`, and `meta.changelog` reference `src.tag`.
buildPythonPackage rec {
  pname = "rlax";
  version = "0.1.8";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google-deepmind";
    repo = "rlax";
    tag = "v${version}";
    hash = "sha256-E/zYFd5bfx58FfA3uR7hzRAIs844QzJA8TZTwmwDByk=";
  };

  build-system = [ setuptools ];

  dependencies = [
    absl-py
    chex
    distrax
    dm-env
    jax
    jaxlib
    numpy
    # Requested above under "# dependencies" but previously never used.
    # NOTE(review): believed to be listed in upstream rlax's requirements —
    # confirm against requirements.txt for this version.
    tensorflow-probability
  ];

  nativeCheckInputs = [
    dm-haiku
    optax
    pytest-xdist
    pytestCheckHook
  ];

  pythonImportsCheck = [ "rlax" ];

  disabledTests = [
    # AssertionError: Array(2, dtype=int32) != 0
    "test_categorical_sample__with_device"
    "test_categorical_sample__with_jit"
    "test_categorical_sample__without_device"
    "test_categorical_sample__without_jit"

    # RuntimeError: Attempted to set 4 devices, but 1 CPUs already available:
    # ensure that `set_n_cpu_devices` is executed before any JAX operation.
    "test_cross_replica_scatter_add0"
    "test_cross_replica_scatter_add1"
    "test_cross_replica_scatter_add2"
    "test_cross_replica_scatter_add3"
    "test_cross_replica_scatter_add4"
    "test_learn_scale_shift"
    "test_normalize_unnormalize_is_identity"
    "test_outputs_preserved"
    "test_scale_bounded"
    "test_slow_update"
    "test_unnormalize_linear"
  ];

  meta = {
    description = "Library of reinforcement learning building blocks in JAX";
    # Same organisation as `src.owner` and `changelog`.
    homepage = "https://github.com/google-deepmind/rlax";
    changelog = "https://github.com/google-deepmind/rlax/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ onny ];
  };
}