{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  pybind11,
  setuptools,
  setuptools-scm,

  # nativeBuildInputs
  cmake,
  ninja,

  # dependencies
  cloudpickle,
  importlib-metadata,
  numpy,
  orjson,
  packaging,
  pyvers,
  torch,

  # tests
  h5py,
  pytestCheckHook,
}:

# tensordict: a PyTorch tensor container, built from the upstream GitHub
# release tag with a pybind11/CMake native component.
buildPythonPackage rec {
  pname = "tensordict";
  version = "0.10.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "tensordict";
    tag = "v${version}";
    hash = "sha256-yxyA9BfN2hp1C3s+g2zBM2gVtckH3LV7luWw8DshFUs=";
  };

  # Drop the `[global]` extra from the pybind11 build requirement in
  # pyproject.toml so the plain pybind11 from build-system satisfies it.
  # `--replace-fail` aborts the build if upstream changes this string, so the
  # patch cannot silently become a no-op on a version bump.
  postPatch = ''
    substituteInPlace pyproject.toml \
      --replace-fail "pybind11[global]" "pybind11"
  '';

  build-system = [
    pybind11
    setuptools
    setuptools-scm
  ];

  nativeBuildInputs = [
    cmake
    ninja
  ];
  # cmake/ninja are only provided as tools (presumably driven by the Python
  # build backend — confirm); stop the cmake setup hook from running its own
  # configure phase on the source tree.
  dontUseCmakeConfigure = true;

  dependencies = [
    cloudpickle
    importlib-metadata
    numpy
    orjson
    packaging
    pyvers
    torch
  ];

  pythonImportsCheck = [ "tensordict" ];

  # Delete the in-tree `tensordict` source directory; otherwise it is imported
  # by the tests instead of the installed package.
  preCheck = ''
    rm -rf tensordict
  '';

  nativeCheckInputs = [
    h5py
    pytestCheckHook
  ];

  disabledTests = [
    # FileNotFoundError: [Errno 2] No such file or directory: '/build/source/tensordict/tensorclass.pyi'
    # (the missing file lives in the source directory removed in preCheck)
    "test_tensorclass_instance_methods"
    "test_tensorclass_stub_methods"

    # hangs forever on some CPUs
    "test_map_iter_interrupt_early"
  ]
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    # Hangs due to the use of a pool
    "test_chunksize_num_chunks"
    "test_index_with_generator"
    "test_map_exception"
    "test_map"
    "test_multiprocessing"
  ];

  disabledTestPaths = [
    # torch._dynamo.exc.Unsupported: Graph break due to unsupported builtin None.ReferenceType.__new__.
    "test/test_compile.py"
  ]
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    # Hangs forever
    "test/test_distributed.py"
    # Hangs after testing due to pool usage
    "test/test_h5.py"
    "test/test_memmap.py"
  ];

  meta = {
    description = "Pytorch dedicated tensor container";
    changelog = "https://github.com/pytorch/tensordict/releases/tag/${src.tag}";
    homepage = "https://github.com/pytorch/tensordict";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}