# Inputs are grouped by the role they play in the derivation below; names that
# serve several roles (e.g. numpy/torch, gymnasium, tqdm) are listed once.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # -- build-system --
  cmake,
  ninja,
  numpy,
  pybind11,
  setuptools,
  torch,

  # -- runtime dependencies --
  cloudpickle,
  packaging,
  pyvers,
  tensordict,

  # -- optional-dependencies, one sub-group per extra --
  # extra: atari
  gymnasium,
  # extra: brax
  brax,
  jax,
  # extra: checkpointing
  torchsnapshot,
  # extra: dm-control
  dm-control,
  # extra: gym-continuous
  mujoco,
  # extra: llm
  accelerate,
  datasets,
  einops,
  immutabledict,
  langdetect,
  nltk,
  playwright,
  protobuf,
  safetensors,
  sentencepiece,
  transformers,
  vllm,
  # extra: marl
  pettingzoo,
  # extra: offline-data
  h5py,
  huggingface-hub,
  minari,
  pandas,
  pillow,
  requests,
  scikit-learn,
  torchvision,
  tqdm,
  # extra: rendering
  moviepy,
  # extra: utils
  git,
  hydra-core,
  tensorboard,
  wandb,

  # -- test-only inputs --
  imageio,
  pytest-rerunfailures,
  pytestCheckHook,
  pyyaml,
  scipy,
}:
73
buildPythonPackage rec {
  pname = "torchrl";
  version = "0.10.0";
  pyproject = true;

  # Upstream's GitHub repository is `pytorch/rl`; releases are tagged `v<version>`.
  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "rl";
    tag = "v${version}";
    hash = "sha256-DqLB1JnQ96cxVEzcXra1hFVfrN7eXTlTwPtlPClnaBA=";
  };

  # Rewrite the `pybind11[global]` requirement in pyproject.toml to plain
  # `pybind11` so the pybind11 build input satisfies it
  # (presumably the `global` extra is not provided by the packaged pybind11 — confirm).
  postPatch = ''
    substituteInPlace pyproject.toml \
      --replace-fail "pybind11[global]" "pybind11"
  '';

  # Tooling needed to build the C++ extension; numpy and torch are also runtime deps.
  build-system = [
    cmake
    ninja
    numpy
    pybind11
    setuptools
    torch
  ];
  # cmake is only a tool for the Python build here; keep the cmake setup hook
  # from running its own configure phase.
  dontUseCmakeConfigure = true;
100
101 dependencies = [
102 cloudpickle
103 numpy
104 packaging
105 tensordict
106 pyvers
107 torch
108 ];
109
110 optional-dependencies = {
111 atari = [
112 gymnasium
113 ]
114 ++ gymnasium.optional-dependencies.atari;
115 brax = [
116 brax
117 jax
118 ];
119 checkpointing = [ torchsnapshot ];
120 dm-control = [ dm-control ];
121 gym-continuous = [
122 gymnasium
123 mujoco
124 ];
125 llm = [
126 accelerate
127 datasets
128 einops
129 immutabledict
130 langdetect
131 nltk
132 playwright
133 protobuf
134 safetensors
135 sentencepiece
136 transformers
137 vllm
138 ];
139 marl = [
140 # dm-meltingpot (unpackaged)
141 pettingzoo
142 # vmas (unpackaged)
143 ];
144 offline-data = [
145 h5py
146 huggingface-hub
147 minari
148 pandas
149 pillow
150 requests
151 scikit-learn
152 torchvision
153 tqdm
154 ];
155 open-spiel = [
156 # open-spiel (unpackaged)
157 ];
158 rendering = [ moviepy ];
159 replay-buffer = [ torch ];
160 utils = [
161 git
162 hydra-core
163 # hydra-submitit-launcher (unpackaged)
164 tensorboard
165 tqdm
166 wandb
167 ];
168 };
169
170 # torchrl needs to create a folder to store datasets
171 preBuild = ''
172 export D4RL_DATASET_DIR=$(mktemp -d)
173 '';
174
175 pythonImportsCheck = [ "torchrl" ];
176
177 # We have to delete the source because otherwise it is used instead of the installed package.
178 preCheck = ''
179 rm -rf torchrl
180
181 export XDG_RUNTIME_DIR=$(mktemp -d)
182 '';
183
184 nativeCheckInputs = [
185 h5py
186 gymnasium
187 imageio
188 pytest-rerunfailures
189 pytestCheckHook
190 pyyaml
191 scipy
192 torchvision
193 ]
194 ++ optional-dependencies.atari
195 ++ optional-dependencies.gym-continuous
196 ++ optional-dependencies.llm
197 ++ optional-dependencies.rendering;
198
199 disabledTests = [
200 # Require network
201 "test_create_or_load_dataset"
202 "test_from_text_env_tokenizer"
203 "test_from_text_env_tokenizer_catframes"
204 "test_from_text_rb_slicesampler"
205 "test_generate"
206 "test_get_dataloader"
207 "test_get_scores"
208 "test_preproc_data"
209 "test_prompt_tensordict_tokenizer"
210 "test_reward_model"
211 "test_tensordict_tokenizer"
212 "test_transform_compose"
213 "test_transform_model"
214 "test_transform_no_env"
215 "test_transform_rb"
216
217 # ray.exceptions.RuntimeEnvSetupError: Failed to set up runtime environment
218 "TestRayCollector"
219
220 # torchrl is incompatible with gymnasium>=1.0
221 # https://github.com/pytorch/rl/discussions/2483
222 "test_resetting_strategies"
223 "test_torchrl_to_gym"
224 "test_vecenvs_nan"
225
226 # gym.error.VersionNotFound: Environment version `v5` for environment `HalfCheetah` doesn't exist.
227 "test_collector_run"
228 "test_transform_inverse"
229
230 # OSError: Unable to synchronously create file (unable to truncate a file which is already open)
231 "test_multi_env"
232 "test_simple_env"
233
234 # ImportWarning: Ignoring non-library in plugin directory:
235 # /nix/store/cy8vwf1dacp3xfwnp9v6a1sz8bic8ylx-python3.12-mujoco-3.3.2/lib/python3.12/site-packages/mujoco/plugin/libmujoco.so.3.3.2
236 "test_auto_register"
237 "test_info_dict_reader"
238
239 # mujoco.FatalError: an OpenGL platform library has not been loaded into this process, this most likely means that a valid OpenGL context has not been created before mjr_makeContext was called
240 "test_vecenvs_env"
241
242 # ValueError: Can't write images with one color channel.
243 "test_log_video"
244
245 # Those tests require the ALE environments (provided by unpackaged shimmy)
246 "test_collector_env_reset"
247 "test_gym"
248 "test_gym_fake_td"
249 "test_recorder"
250 "test_recorder_load"
251 "test_rollout"
252 "test_parallel_trans_env_check"
253 "test_serial_trans_env_check"
254 "test_single_trans_env_check"
255 "test_td_creation_from_spec"
256 "test_trans_parallel_env_check"
257 "test_trans_serial_env_check"
258 "test_transform_env"
259
260 # undeterministic
261 "test_distributed_collector_updatepolicy"
262 "test_timeit"
263
264 # On a 24 threads system
265 # assert torch.get_num_threads() == max(1, init_threads - 3)
266 # AssertionError: assert 23 == 21
267 "test_auto_num_threads"
268
269 # Flaky (hangs indefinitely on some CPUs)
270 "test_gae_multidim"
271 "test_gae_param_as_tensor"
272 ]
273 ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
274 # Flaky
275 # AssertionError: assert tensor([51.]) == ((5 * 11) + 2)
276 "test_vecnorm_parallel_auto"
277 ];
278
279 disabledTestPaths = [
280 # ERROR collecting test/smoke_test.py
281 # import file mismatch:
282 # imported module 'smoke_test' has this __file__ attribute:
283 # /build/source/test/llm/smoke_test.py
284 # which is not the same as the test file we want to collect:
285 # /build/source/test/smoke_test.py
286 "test/llm"
287 ];
288
289 meta = {
290 description = "Modular, primitive-first, python-first PyTorch library for Reinforcement Learning";
291 homepage = "https://github.com/pytorch/rl";
292 changelog = "https://github.com/pytorch/rl/releases/tag/v${version}";
293 license = lib.licenses.mit;
294 maintainers = with lib.maintainers; [ GaetanLepage ];
295 };
296}