{
  lib,
  stdenv,
  config,
  buildPythonPackage,
  fetchFromGitHub,

  # patches
  replaceVars,
  addDriverRunpath,
  cudaPackages,
  llvmPackages,
  ocl-icd,
  rocmPackages,

  # build-system
  setuptools,

  # optional-dependencies
  llvmlite,
  triton,
  unicorn,

  # tests
  pytestCheckHook,
  writableTmpDirAsHomeHook,
  blobfile,
  bottle,
  capstone,
  clang,
  hexdump,
  hypothesis,
  jax,
  librosa,
  ml-dtypes,
  networkx,
  numpy,
  onnx,
  onnxruntime,
  pillow,
  pytest-xdist,
  safetensors,
  sentencepiece,
  tiktoken,
  torch,
  tqdm,
  transformers,
  z3-solver,

  # passthru
  tinygrad,

  # Backend toggles; they default to the global nixpkgs `config` values.
  cudaSupport ? config.cudaSupport,
  rocmSupport ? config.rocmSupport,
}:

buildPythonPackage rec {
  pname = "tinygrad";
  version = "0.11.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "tinygrad";
    repo = "tinygrad";
    tag = "v${version}";
    hash = "sha256-VG2rhkiwPFN3JYSBbqrwCdqhdGE8GY6oEatMSCydhw8=";
  };

  patches = [
    # `replaceVars` fills the `@driverLink@` and `@libnvrtc@` placeholders of the patch.
    # NOTE(review): the patch body is not visible here; presumably it makes tinygrad load the
    # CUDA driver / NVRTC from these absolute paths — confirm against fix-dlopen-cuda.patch.
    (replaceVars ./fix-dlopen-cuda.patch {
      inherit (addDriverRunpath) driverLink;
      # Without CUDA support there is no libnvrtc to point at, so a human-readable hint is
      # substituted instead of a store path.
      libnvrtc =
        if cudaSupport then
          "${lib.getLib cudaPackages.cuda_nvrtc}/lib/libnvrtc.so"
        else
          "Please import nixpkgs with `config.cudaSupport = true`";
    })
  ];

  postPatch =
    # Hardcode the C compiler used by the CPU backend instead of reading `$CC` (default `clang`).
    # Use the unwrapped variant to enable the "native" features currently unavailable in the sandbox.
    ''
      substituteInPlace tinygrad/runtime/ops_cpu.py \
        --replace-fail "getenv(\"CC\", 'clang')" "'${lib.getExe llvmPackages.clang-unwrapped}'"
    ''
    # Replace runtime `ctypes.util.find_library` lookups with absolute store paths, since
    # those lookups do not search the Nix store.
    + ''
      substituteInPlace tinygrad/runtime/autogen/libc.py \
        --replace-fail "ctypes.util.find_library('c')" "'${stdenv.cc.libc}/lib/libc.so.6'"
    ''
    + ''
      substituteInPlace tinygrad/runtime/support/llvm.py \
        --replace-fail "ctypes.util.find_library('LLVM')" "'${lib.getLib llvmPackages.llvm}/lib/libLLVM.so'"
    ''
    + lib.optionalString stdenv.hostPlatform.isLinux ''
      substituteInPlace tinygrad/runtime/autogen/opencl.py \
        --replace-fail "ctypes.util.find_library('OpenCL')" "'${ocl-icd}/lib/libOpenCL.so'"
    ''
    # test/test_tensor.py imports the PTX variable from tinygrad/runtime/support/compiler_cuda.py.
    # This import leads to loading the libnvrtc.so library, which is not substituted when cudaSupport = false.
    # -> As a fix, we hardcode this variable to False.
    + lib.optionalString (!cudaSupport) ''
      substituteInPlace test/test_tensor.py \
        --replace-fail "from tinygrad.runtime.support.compiler_cuda import PTX" "PTX = False"
    ''
    # `cuda_fp16.h` and co. are needed at runtime to compile kernels:
    # swap the FHS include directories for the cuda_cudart headers.
    + lib.optionalString cudaSupport ''
      substituteInPlace tinygrad/runtime/support/compiler_cuda.py \
        --replace-fail \
          '"-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include"' \
          '"-I${lib.getDev cudaPackages.cuda_cudart}/include/"'
    ''
    # Point the hardcoded /opt/rocm locations at the corresponding nixpkgs ROCm packages.
    # The replacements for compiler_hip.py are applied in order within a single invocation.
    + lib.optionalString rocmSupport ''
      substituteInPlace tinygrad/runtime/autogen/hip.py \
        --replace-fail "/opt/rocm/" "${rocmPackages.clr}/"

      substituteInPlace tinygrad/runtime/support/compiler_hip.py \
        --replace-fail "/opt/rocm/include" "${rocmPackages.clr}/include" \
        --replace-fail "/opt/rocm/llvm" "${rocmPackages.llvm.llvm}"

      substituteInPlace tinygrad/runtime/autogen/comgr.py \
        --replace-fail "/opt/rocm/" "${rocmPackages.rocm-comgr}/"
    '';

  build-system = [ setuptools ];

  optional-dependencies = {
    llvm = [ llvmlite ];
    arm = [ unicorn ];
    triton = [ triton ];
  };

  pythonImportsCheck = [
    "tinygrad"
  ]
  # The NV runtime module is only import-checked when CUDA support is enabled.
  ++ lib.optionals cudaSupport [
    "tinygrad.runtime.ops_nv"
  ];

  nativeCheckInputs = [
    pytestCheckHook
    writableTmpDirAsHomeHook

    blobfile
    bottle
    capstone
    clang
    hexdump
    hypothesis
    jax
    librosa
    ml-dtypes
    networkx
    numpy
    onnx
    onnxruntime
    pillow
    pytest-xdist
    safetensors
    sentencepiece
    tiktoken
    torch
    tqdm
    transformers
    z3-solver
  ]
  # Also pull in networkx's `extra` optional dependency group.
  ++ networkx.optional-dependencies.extra;

  disabledTests = [
    # RuntimeError: Attempting to relocate against an undefined symbol 'fmaxf'
    "test_backward_sum_acc_dtype"
    "test_failure_27"

    # Flaky:
    # AssertionError: 2.1376906810000946 not less than 2.0
    "test_recursive_pad"

    # Require internet access
    "testCopySHMtoDefault"
    "test_benchmark_openpilot_model"
    "test_bn_alone"
    "test_bn_linear"
    "test_bn_mnist"
    "test_car"
    "test_chicken"
    "test_chicken_bigbatch"
    "test_conv_mnist"
    "test_data_parallel_resnet"
    "test_dataset_is_realized"
    "test_e2e_big"
    "test_fetch_small"
    "test_huggingface_enet_safetensors"
    "test_index_mnist"
    "test_linear_mnist"
    "test_llama_basic"
    "test_llama_bytes"
    "test_llama_control_char"
    "test_llama_early_tokenize"
    "test_llama_pat"
    "test_llama_repeat"
    "test_llama_special1"
    "test_llama_special2"
    "test_load_convnext"
    "test_load_enet"
    "test_load_enet_alt"
    "test_load_llama2bfloat"
    "test_load_resnet"
    "test_mnist_val"
    "test_openpilot_model"
    "test_resnet"
    "test_shufflenet"
    "test_transcribe_batch12"
    "test_transcribe_batch21"
    "test_transcribe_file1"
    "test_transcribe_file2"
    "test_transcribe_long"
    "test_transcribe_long_no_batch"
    "test_vgg7"
  ]
  # NOTE: aarch64-linux is currently listed in meta.badPlatforms; these exclusions are kept
  # so the test suite is ready if that restriction is lifted.
  ++ lib.optionals (stdenv.hostPlatform.system == "aarch64-linux") [
    # Fail with AssertionError
    "test_casts_from"
    "test_casts_to"
    "test_int8"
    "test_int8_to_uint16_negative"
  ];

  disabledTestPaths = [
    # Require internet access
    "test/models/test_mnist.py"
    "test/models/test_real_world.py"
    "test/testextra/test_lr_scheduler.py"

    # Files under this directory are not considered as tests by upstream and should be skipped
    "extra/"
  ];

  # Build the CUDA-enabled variant as an extra test.
  passthru.tests = {
    withCuda = tinygrad.override { cudaSupport = true; };
  };

  meta = {
    description = "Simple and powerful neural network framework";
    homepage = "https://github.com/tinygrad/tinygrad";
    changelog = "https://github.com/tinygrad/tinygrad/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ GaetanLepage ];
    badPlatforms = [
      # Fatal Python error: Aborted
      # onnxruntime/capi/_pybind_state.py", line 32 in <module>
      "aarch64-linux"

      # Tests segfault on darwin
      lib.systems.inspect.patterns.isDarwin
    ];
  };
}