# Nix package expression for torchrl, built from the `pytorch/rl` GitHub
# repository at tag `v${version}` using `buildPythonPackage` with a
# pyproject-based build.
#
# The function arguments are grouped by the role they play below:
# build-system inputs, runtime dependencies, optional-dependency groups and
# test-only inputs.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  cmake,
  ninja,
  numpy,
  pybind11,
  setuptools,
  torch,

  # dependencies
  cloudpickle,
  packaging,
  pyvers,
  tensordict,

  # optional-dependencies
  # atari
  gymnasium,
  # brax
  brax,
  jax,
  # checkpointing
  torchsnapshot,
  # dm-control
  dm-control,
  # gym-continuous
  mujoco,
  # llm
  accelerate,
  datasets,
  einops,
  immutabledict,
  langdetect,
  nltk,
  playwright,
  protobuf,
  safetensors,
  sentencepiece,
  transformers,
  vllm,
  # marl
  pettingzoo,
  # offline-data
  h5py,
  huggingface-hub,
  minari,
  pandas,
  pillow,
  requests,
  scikit-learn,
  torchvision,
  tqdm,
  # rendering
  moviepy,
  # utils
  git,
  hydra-core,
  tensorboard,
  wandb,

  # tests
  imageio,
  pytest-rerunfailures,
  pytestCheckHook,
  pyyaml,
  scipy,
}:

buildPythonPackage rec {
  pname = "torchrl";
  version = "0.10.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "rl";
    tag = "v${version}";
    hash = "sha256-DqLB1JnQ96cxVEzcXra1hFVfrN7eXTlTwPtlPClnaBA=";
  };

  # Replace the `pybind11[global]` build requirement with plain `pybind11`.
  # NOTE(review): presumably the `global` extra is not provided by the packaged
  # pybind11 - confirm.
  postPatch = ''
    substituteInPlace pyproject.toml \
      --replace-fail "pybind11[global]" "pybind11"
  '';

  build-system = [
    cmake
    ninja
    numpy
    pybind11
    setuptools
    torch
  ];
  # cmake is listed as a build-system input only; its configure phase is
  # disabled here.
  # NOTE(review): presumably the Python build backend invokes cmake itself - confirm.
  dontUseCmakeConfigure = true;

  dependencies = [
    cloudpickle
    numpy
    packaging
    tensordict
    pyvers
    torch
  ];

  optional-dependencies = {
    atari = [
      gymnasium
    ]
    ++ gymnasium.optional-dependencies.atari;
    brax = [
      brax
      jax
    ];
    checkpointing = [ torchsnapshot ];
    dm-control = [ dm-control ];
    gym-continuous = [
      gymnasium
      mujoco
    ];
    llm = [
      accelerate
      datasets
      einops
      immutabledict
      langdetect
      nltk
      playwright
      protobuf
      safetensors
      sentencepiece
      transformers
      vllm
    ];
    marl = [
      # dm-meltingpot (unpackaged)
      pettingzoo
      # vmas (unpackaged)
    ];
    offline-data = [
      h5py
      huggingface-hub
      minari
      pandas
      pillow
      requests
      scikit-learn
      torchvision
      tqdm
    ];
    open-spiel = [
      # open-spiel (unpackaged)
    ];
    rendering = [ moviepy ];
    replay-buffer = [ torch ];
    utils = [
      git
      hydra-core
      # hydra-submitit-launcher (unpackaged)
      tensorboard
      tqdm
      wandb
    ];
  };

  # torchrl needs to create a folder to store datasets
  preBuild = ''
    export D4RL_DATASET_DIR=$(mktemp -d)
  '';

  pythonImportsCheck = [ "torchrl" ];

  # We have to delete the source because otherwise it is used instead of the installed package.
  preCheck = ''
    rm -rf torchrl

    export XDG_RUNTIME_DIR=$(mktemp -d)
  '';

  # The test environment pulls in several optional-dependency groups defined
  # above (resolved through `rec`).
  nativeCheckInputs = [
    h5py
    gymnasium
    imageio
    pytest-rerunfailures
    pytestCheckHook
    pyyaml
    scipy
    torchvision
  ]
  ++ optional-dependencies.atari
  ++ optional-dependencies.gym-continuous
  ++ optional-dependencies.llm
  ++ optional-dependencies.rendering;

  disabledTests = [
    # Require network
    "test_create_or_load_dataset"
    "test_from_text_env_tokenizer"
    "test_from_text_env_tokenizer_catframes"
    "test_from_text_rb_slicesampler"
    "test_generate"
    "test_get_dataloader"
    "test_get_scores"
    "test_preproc_data"
    "test_prompt_tensordict_tokenizer"
    "test_reward_model"
    "test_tensordict_tokenizer"
    "test_transform_compose"
    "test_transform_model"
    "test_transform_no_env"
    "test_transform_rb"

    # ray.exceptions.RuntimeEnvSetupError: Failed to set up runtime environment
    "TestRayCollector"

    # torchrl is incompatible with gymnasium>=1.0
    # https://github.com/pytorch/rl/discussions/2483
    "test_resetting_strategies"
    "test_torchrl_to_gym"
    "test_vecenvs_nan"

    # gym.error.VersionNotFound: Environment version `v5` for environment `HalfCheetah` doesn't exist.
    "test_collector_run"
    "test_transform_inverse"

    # OSError: Unable to synchronously create file (unable to truncate a file which is already open)
    "test_multi_env"
    "test_simple_env"

    # ImportWarning: Ignoring non-library in plugin directory:
    # /nix/store/cy8vwf1dacp3xfwnp9v6a1sz8bic8ylx-python3.12-mujoco-3.3.2/lib/python3.12/site-packages/mujoco/plugin/libmujoco.so.3.3.2
    "test_auto_register"
    "test_info_dict_reader"

    # mujoco.FatalError: an OpenGL platform library has not been loaded into this process, this most likely means that a valid OpenGL context has not been created before mjr_makeContext was called
    "test_vecenvs_env"

    # ValueError: Can't write images with one color channel.
    "test_log_video"

    # Those tests require the ALE environments (provided by unpackaged shimmy)
    "test_collector_env_reset"
    "test_gym"
    "test_gym_fake_td"
    "test_recorder"
    "test_recorder_load"
    "test_rollout"
    "test_parallel_trans_env_check"
    "test_serial_trans_env_check"
    "test_single_trans_env_check"
    "test_td_creation_from_spec"
    "test_trans_parallel_env_check"
    "test_trans_serial_env_check"
    "test_transform_env"

    # Non-deterministic
    "test_distributed_collector_updatepolicy"
    "test_timeit"

    # On a 24 threads system
    # assert torch.get_num_threads() == max(1, init_threads - 3)
    # AssertionError: assert 23 == 21
    "test_auto_num_threads"

    # Flaky (hangs indefinitely on some CPUs)
    "test_gae_multidim"
    "test_gae_param_as_tensor"
  ]
  ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
    # Flaky
    # AssertionError: assert tensor([51.]) == ((5 * 11) + 2)
    "test_vecnorm_parallel_auto"
  ];

  disabledTestPaths = [
    # ERROR collecting test/smoke_test.py
    # import file mismatch:
    # imported module 'smoke_test' has this __file__ attribute:
    #   /build/source/test/llm/smoke_test.py
    # which is not the same as the test file we want to collect:
    #   /build/source/test/smoke_test.py
    "test/llm"
  ];

  meta = {
    description = "Modular, primitive-first, python-first PyTorch library for Reinforcement Learning";
    homepage = "https://github.com/pytorch/rl";
    changelog = "https://github.com/pytorch/rl/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}