at master 6.6 kB view raw
# Package expression for the `tinygrad` Python library.
#
# Besides the usual fetch/build/test wiring, the expression rewrites several
# library lookups in the upstream sources so that they point at fixed store
# paths, and optionally wires in CUDA (`cudaSupport`) and ROCm (`rocmSupport`).
{
  lib,
  stdenv,
  config,
  buildPythonPackage,
  fetchFromGitHub,

  # inputs referenced by `patches` / `postPatch`
  replaceVars,
  addDriverRunpath,
  cudaPackages,
  llvmPackages,
  ocl-icd,
  rocmPackages,

  # build-system
  setuptools,

  # optional-dependencies
  llvmlite,
  triton,
  unicorn,

  # test inputs
  pytestCheckHook,
  writableTmpDirAsHomeHook,
  blobfile,
  bottle,
  capstone,
  clang,
  hexdump,
  hypothesis,
  jax,
  librosa,
  ml-dtypes,
  networkx,
  numpy,
  onnx,
  onnxruntime,
  pillow,
  pytest-xdist,
  safetensors,
  sentencepiece,
  tiktoken,
  torch,
  tqdm,
  transformers,
  z3-solver,

  # the final package itself, only used to build `passthru.tests`
  tinygrad,

  # both toggles default to the nixpkgs-wide configuration
  cudaSupport ? config.cudaSupport,
  rocmSupport ? config.rocmSupport,
}:

let
  # Value substituted for `libnvrtc` in ./fix-dlopen-cuda.patch.
  # Without CUDA support there is no library to point at, so the placeholder
  # is filled with a hint telling the user how to enable it.
  nvrtcLibrary =
    if cudaSupport then
      "${lib.getLib cudaPackages.cuda_nvrtc}/lib/libnvrtc.so"
    else
      "Please import nixpkgs with `config.cudaSupport = true`";
in
buildPythonPackage rec {
  pname = "tinygrad";
  version = "0.11.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "tinygrad";
    repo = "tinygrad";
    tag = "v${version}";
    hash = "sha256-VG2rhkiwPFN3JYSBbqrwCdqhdGE8GY6oEatMSCydhw8=";
  };

  patches = [
    # The patch is a template: `driverLink` and `libnvrtc` are filled in here.
    (replaceVars ./fix-dlopen-cuda.patch {
      driverLink = addDriverRunpath.driverLink;
      libnvrtc = nvrtcLibrary;
    })
  ];

  # Each list element is one shell snippet; they are concatenated in order.
  # Every substitution uses `--replace-fail`, so the build aborts if upstream
  # changes one of the patched lines.
  postPatch = lib.concatStrings [
    # Hard-code the C compiler in the CPU runtime instead of reading `CC`.
    # The unwrapped clang is used on purpose: it enables the "native" features
    # that are currently unavailable in the sandbox.
    ''
      substituteInPlace tinygrad/runtime/ops_cpu.py \
        --replace-fail "getenv(\"CC\", 'clang')" "'${lib.getExe llvmPackages.clang-unwrapped}'"
    ''

    # Replace the dynamic lookup of libc with a fixed path.
    ''
      substituteInPlace tinygrad/runtime/autogen/libc.py \
        --replace-fail "ctypes.util.find_library('c')" "'${stdenv.cc.libc}/lib/libc.so.6'"
    ''

    # Same for libLLVM.
    ''
      substituteInPlace tinygrad/runtime/support/llvm.py \
        --replace-fail "ctypes.util.find_library('LLVM')" "'${lib.getLib llvmPackages.llvm}/lib/libLLVM.so'"
    ''

    # The OpenCL loader from ocl-icd is only wired in on Linux.
    (lib.optionalString stdenv.hostPlatform.isLinux ''
      substituteInPlace tinygrad/runtime/autogen/opencl.py \
        --replace-fail "ctypes.util.find_library('OpenCL')" "'${ocl-icd}/lib/libOpenCL.so'"
    '')

    # test/test_tensor.py imports the PTX variable from the cuda_compiler.py file.
    # That import loads libnvrtc.so, whose path is not substituted when
    # cudaSupport = false, so the variable is hard-coded to False instead.
    (lib.optionalString (!cudaSupport) ''
      substituteInPlace test/test_tensor.py \
        --replace-fail "from tinygrad.runtime.support.compiler_cuda import PTX" "PTX = False"
    '')

    # `cuda_fp16.h` and co. are needed at runtime to compile kernels, so the
    # FHS include directories are replaced by the cudart headers.
    (lib.optionalString cudaSupport ''
      substituteInPlace tinygrad/runtime/support/compiler_cuda.py \
        --replace-fail \
        '"-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include"' \
        '"-I${lib.getDev cudaPackages.cuda_cudart}/include/"'
    '')

    # Redirect every `/opt/rocm` reference to the matching ROCm package.
    (lib.optionalString rocmSupport ''
      substituteInPlace tinygrad/runtime/autogen/hip.py \
        --replace-fail "/opt/rocm/" "${rocmPackages.clr}/"

      substituteInPlace tinygrad/runtime/support/compiler_hip.py \
        --replace-fail "/opt/rocm/include" "${rocmPackages.clr}/include"

      substituteInPlace tinygrad/runtime/support/compiler_hip.py \
        --replace-fail "/opt/rocm/llvm" "${rocmPackages.llvm.llvm}"

      substituteInPlace tinygrad/runtime/autogen/comgr.py \
        --replace-fail "/opt/rocm/" "${rocmPackages.rocm-comgr}/"
    '')
  ];

  build-system = [ setuptools ];

  optional-dependencies = {
    llvm = [ llvmlite ];
    arm = [ unicorn ];
    triton = [ triton ];
  };

  # The NV runtime module is only import-checked for CUDA-enabled builds.
  pythonImportsCheck = [ "tinygrad" ] ++ lib.optionals cudaSupport [ "tinygrad.runtime.ops_nv" ];

  nativeCheckInputs = [
    pytestCheckHook
    writableTmpDirAsHomeHook

    blobfile
    bottle
    capstone
    clang
    hexdump
    hypothesis
    jax
    librosa
    ml-dtypes
    networkx
    numpy
    onnx
    onnxruntime
    pillow
    pytest-xdist
    safetensors
    sentencepiece
    tiktoken
    torch
    tqdm
    transformers
    z3-solver
  ]
  ++ networkx.optional-dependencies.extra;

  disabledTests = [
    # RuntimeError: Attempting to relocate against an undefined symbol 'fmaxf'
    "test_backward_sum_acc_dtype"
    "test_failure_27"

    # Flaky timing assertion:
    # AssertionError: 2.1376906810000946 not less than 2.0
    "test_recursive_pad"

    # Need network access
    "testCopySHMtoDefault"
    "test_benchmark_openpilot_model"
    "test_bn_alone"
    "test_bn_linear"
    "test_bn_mnist"
    "test_car"
    "test_chicken"
    "test_chicken_bigbatch"
    "test_conv_mnist"
    "test_data_parallel_resnet"
    "test_dataset_is_realized"
    "test_e2e_big"
    "test_fetch_small"
    "test_huggingface_enet_safetensors"
    "test_index_mnist"
    "test_linear_mnist"
    "test_llama_basic"
    "test_llama_bytes"
    "test_llama_control_char"
    "test_llama_early_tokenize"
    "test_llama_pat"
    "test_llama_repeat"
    "test_llama_special1"
    "test_llama_special2"
    "test_load_convnext"
    "test_load_enet"
    "test_load_enet_alt"
    "test_load_llama2bfloat"
    "test_load_resnet"
    "test_mnist_val"
    "test_openpilot_model"
    "test_resnet"
    "test_shufflenet"
    "test_transcribe_batch12"
    "test_transcribe_batch21"
    "test_transcribe_file1"
    "test_transcribe_file2"
    "test_transcribe_long"
    "test_transcribe_long_no_batch"
    "test_vgg7"
  ]
  ++ lib.optionals (stdenv.hostPlatform.system == "aarch64-linux") [
    # These fail with AssertionError on aarch64-linux
    "test_casts_from"
    "test_casts_to"
    "test_int8"
    "test_int8_to_uint16_negative"
  ];

  disabledTestPaths = [
    # Need network access
    "test/models/test_mnist.py"
    "test/models/test_real_world.py"
    "test/testextra/test_lr_scheduler.py"

    # Upstream does not treat the files in this directory as tests; skip them
    "extra/"
  ];

  # Extra test derivation: the same package rebuilt with CUDA enabled.
  passthru.tests.withCuda = tinygrad.override { cudaSupport = true; };

  meta = {
    description = "Simple and powerful neural network framework";
    homepage = "https://github.com/tinygrad/tinygrad";
    changelog = "https://github.com/tinygrad/tinygrad/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = [ lib.maintainers.GaetanLepage ];
    badPlatforms = [
      # Fatal Python error: Aborted
      # onnxruntime/capi/_pybind_state.py", line 32 in <module>
      "aarch64-linux"

      # The test suite segfaults on darwin
      lib.systems.inspect.patterns.isDarwin
    ];
  };
}