1{
2 autoAddDriverRunpath,
3 buildPythonPackage,
4 config,
5 cudaPackages,
6 fetchFromGitHub,
7 fetchurl,
8 jax,
9 lib,
10 llvmPackages,
11 numpy,
12 pkgsBuildHost,
13 python,
14 replaceVars,
15 runCommand,
16 setuptools,
17 stdenv,
18 torch,
19 warp-lang, # Self-reference to this package for passthru.tests
20 writableTmpDirAsHomeHook,
21 writeShellApplication,
22
23 # Use standalone LLVM-based JIT compiler and CPU device support
24 standaloneSupport ? true,
25
26 # Use CUDA toolchain and GPU device support
27 cudaSupport ? config.cudaSupport,
28
29 # Build Warp with MathDx support (requires CUDA support)
30 # Most linear-algebra tile operations like tile_cholesky(), tile_fft(),
31 # and tile_matmul() require Warp to be built with the MathDx library.
32 # libmathdxSupport ? cudaSupport && stdenv.hostPlatform.isLinux,
33 libmathdxSupport ? cudaSupport,
34}@args:
35assert libmathdxSupport -> cudaSupport;
36let
  # When building with CUDA support, use the CUDA-aware stdenv so the host
  # compiler matches what the CUDA toolchain expects; otherwise keep the
  # stdenv passed in through the function arguments.
  effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else args.stdenv;
  # Shadow the `stdenv` argument so any accidental direct reference fails
  # loudly at evaluation time instead of silently using the wrong toolchain.
  stdenv = builtins.throw "Use effectiveStdenv instead of stdenv directly, as it may be replaced by cudaPackages.backendStdenv";

  # Version of warp-lang itself (the libmathdx version is pinned separately below).
  version = "1.9.0";
41
  # Prebuilt NVIDIA libmathdx redistributable (integrates cuBLASDx/cuFFTDx into
  # Warp). This derivation only downloads and unpacks an upstream tarball; no
  # compilation happens here.
  libmathdx = effectiveStdenv.mkDerivation (finalAttrs: {
    # NOTE: The version used should match the version Warp requires:
    # https://github.com/NVIDIA/warp/blob/${version}/deps/libmathdx-deps.packman.xml
    pname = "libmathdx";
    version = "0.2.3";

    # The static archive is split into its own output (via moveToOutput in
    # installPhase) so the default `out` closure stays small.
    outputs = [
      "out"
      "static"
    ];

    src =
      let
        baseURL = "https://developer.nvidia.com/downloads/compute/cublasdx/redist/cublasdx";
        cudaMajorVersion = cudaPackages.cudaMajorVersion; # only 12, 13 supported
        # Upstream names artifacts with "<major>.0" regardless of the exact CUDA minor in use.
        cudaVersion = "${cudaMajorVersion}.0"; # URL example: ${baseURL}/cuda12/libmathdx-Linux-x86_64-0.2.3-cuda12.0.tar.gz
        name = lib.concatStringsSep "-" [
          finalAttrs.pname
          "Linux"
          effectiveStdenv.hostPlatform.parsed.cpu.name
          finalAttrs.version
          "cuda${cudaVersion}"
        ];

        # nix-hash --type sha256 --to-sri $(nix-prefetch-url "https://...")
        hashes = {
          "12" = {
            aarch64-linux = "sha256-d/aBC+zU2ciaw3isv33iuviXYaLGLdVDdzynGk9SFck=";
            x86_64-linux = "sha256-CHIH0s4SnA67COtHBkwVCajW/3f0VxNBmuDLXy4LFIg=";
          };
          "13" = {
            aarch64-linux = "sha256-TetJbMts8tpmj5PV4+jpnUHMcooDrXUEKL3aGWqilKI=";
            x86_64-linux = "sha256-wLJLbRpQWa6QEm8ibm1gxt3mXvkWvu0vEzpnqTIvE1M=";
          };
        };
      in
      # src is null when no hash is recorded for this (CUDA major, system)
      # pair; the derivation then fails at build time rather than eval time.
      lib.mapNullable (
        hash:
        fetchurl {
          inherit hash name;
          url = "${baseURL}/cuda${cudaMajorVersion}/${name}.tar.gz";
        }
      ) (hashes.${cudaMajorVersion}.${effectiveStdenv.hostPlatform.system} or null);

    # The tarball is extracted directly in installPhase, so the standard
    # unpack/configure/build phases are skipped.
    dontUnpack = true;
    dontConfigure = true;
    dontBuild = true;

    installPhase = ''
      runHook preInstall

      mkdir -p "$out"
      tar -xzf "$src" -C "$out"

      mkdir -p "$static"
      moveToOutput "lib/libmathdx_static.a" "$static"

      runHook postInstall
    '';

    meta = {
      description = "Library used to integrate cuBLASDx and cuFFTDx into Warp";
      homepage = "https://developer.nvidia.com/cublasdx-downloads";
      sourceProvenance = with lib.sourceTypes; [ binaryNativeCode ];
      license = with lib.licenses; [
        # By downloading and using the software, you agree to fully
        # comply with the terms and conditions of the NVIDIA Software
        # License Agreement.
        (
          nvidiaCudaRedist
          // {
            url = "https://developer.download.nvidia.cn/compute/mathdx/License.txt";
          }
        )

        # Some of the libmathdx routines were written by or derived
        # from code written by Meta Platforms, Inc. and affiliates and
        # are subject to the BSD License.
        bsd3

        # Some of the libmathdx routines were written by or derived from
        # code written by Victor Zverovich and are subject to the following
        # license:
        mit
      ];
      # Only platforms with entries in `hashes` above.
      platforms = [
        "aarch64-linux"
        "x86_64-linux"
      ];
      maintainers = with lib.maintainers; [ yzx9 ];
    };
  });
134in
135buildPythonPackage {
  pname = "warp-lang";
  inherit version;
  # PEP 517 build via the build-system declared below.
  pyproject = true;

  # TODO(@connorbaker): Some CUDA setup hook is failing when __structuredAttrs is false,
  # causing a bunch of missing math symbols (like expf) when linking against the static library
  # provided by NVCC.
  __structuredAttrs = true;

  # Force the (possibly CUDA-aware) stdenv chosen in the let-binding above.
  stdenv = effectiveStdenv;

  src = fetchFromGitHub {
    owner = "NVIDIA";
    repo = "warp";
    tag = "v${version}";
    hash = "sha256-OEg2mUsEdRKhgx0fIraqme4moKNh1RSdN7/yCT1V5+g=";
  };
153
  patches =
    # Darwin: substitute explicit nix store paths for libc++ and build for a
    # single target.
    lib.optionals effectiveStdenv.hostPlatform.isDarwin [
      (replaceVars ./darwin-libcxx.patch {
        LIBCXX_DEV = llvmPackages.libcxx.dev;
        LIBCXX_LIB = llvmPackages.libcxx;
      })
      ./darwin-single-target.patch
    ]
    # Standalone LLVM JIT: substitute nixpkgs' LLVM/Clang store paths into the
    # build and keep the C++11 ABI setting consistent.
    ++ lib.optionals standaloneSupport [
      (replaceVars ./standalone-llvm.patch {
        LLVM_DEV = llvmPackages.llvm.dev;
        LLVM_LIB = llvmPackages.llvm.lib;
        LIBCLANG_DEV = llvmPackages.libclang.dev;
        LIBCLANG_LIB = llvmPackages.libclang.lib;
      })
      ./standalone-cxx11-abi.patch
    ];
171
172 postPatch =
173 # Patch build_dll.py to use our gencode flags rather than NVIDIA's very broad defaults.
174 lib.optionalString cudaSupport ''
175 nixLog "patching $PWD/warp/build_dll.py to use our gencode flags"
176 substituteInPlace "$PWD/warp/build_dll.py" \
177 --replace-fail \
178 '*gencode_opts,' \
179 '${
180 lib.concatMapStringsSep ", " (gencodeString: ''"${gencodeString}"'') cudaPackages.flags.gencode
181 },' \
182 --replace-fail \
183 '*clang_arch_flags,' \
184 '${
185 lib.concatMapStringsSep ", " (
186 realArch: ''"--cuda-gpu-arch=${realArch}"''
187 ) cudaPackages.flags.realArches
188 },'
189 ''
190 # Patch build_dll.py to use dynamic libraries rather than static ones.
191 # NOTE: We do not patch the `nvptxcompiler_static` path because it is not available as a dynamic library.
192 + lib.optionalString cudaSupport ''
193 nixLog "patching $PWD/warp/build_dll.py to use dynamic libraries"
194 substituteInPlace "$PWD/warp/build_dll.py" \
195 --replace-fail \
196 '-lcudart_static' \
197 '-lcudart' \
198 --replace-fail \
199 '-lnvrtc_static' \
200 '-lnvrtc' \
201 --replace-fail \
202 '-lnvrtc-builtins_static' \
203 '-lnvrtc-builtins' \
204 --replace-fail \
205 '-lnvJitLink_static' \
206 '-lnvJitLink' \
207 --replace-fail \
208 '-lmathdx_static' \
209 '-lmathdx'
210 ''
211 # AssertionError: 0.4082476496696472 != 0.40824246406555176 within 5 places
212 + lib.optionalString effectiveStdenv.hostPlatform.isDarwin ''
213 nixLog "patching $PWD/warp/tests/test_fem.py to disable broken tests on darwin"
214 substituteInPlace "$PWD/warp/tests/test_codegen.py" \
215 --replace-fail \
216 'places=5' \
217 'places=4'
218 ''
219 # These tests fail on CPU and CUDA.
220 + ''
221 nixLog "patching $PWD/warp/tests/test_reload.py to disable broken tests"
222 substituteInPlace "$PWD/warp/tests/test_reload.py" \
223 --replace-fail \
224 'add_function_test(TestReload, "test_reload", test_reload, devices=devices)' \
225 "" \
226 --replace-fail \
227 'add_function_test(TestReload, "test_reload_references", test_reload_references, devices=get_test_devices("basic"))' \
228 ""
229 '';
230
  # PEP 517 build backend.
  build-system = [
    setuptools
  ];

  # Runtime Python dependencies.
  dependencies = [
    numpy
  ];

  # NOTE: While normally we wouldn't include autoAddDriverRunpath for packages built from source, since Warp
  # will be loading GPU drivers at runtime, we need to inject the path to our video drivers.
  nativeBuildInputs = lib.optionals cudaSupport [
    autoAddDriverRunpath
    cudaPackages.cuda_nvcc
  ];
245
  buildInputs =
    # Standalone LLVM-based JIT compiler and CPU device support.
    lib.optionals standaloneSupport [
      llvmPackages.llvm
      llvmPackages.clang
      llvmPackages.libcxx
    ]
    # CUDA toolchain components linked against by the native extension.
    ++ lib.optionals cudaSupport [
      (lib.getOutput "static" cudaPackages.cuda_nvcc) # dependency on nvptxcompiler_static; no dynamic version available
      cudaPackages.cuda_cccl
      cudaPackages.cuda_cudart
      cudaPackages.cuda_nvcc
      cudaPackages.cuda_nvrtc
    ]
    # MathDx tile operations pull in libmathdx plus the CUDA math libraries.
    ++ lib.optionals libmathdxSupport [
      libmathdx
      cudaPackages.libcublas
      cudaPackages.libcufft
      cudaPackages.libcusolver
      cudaPackages.libnvjitlink
    ];
266
267 preBuild =
268 let
269 buildOptions =
270 lib.optionals effectiveStdenv.cc.isClang [
271 "--clang_build_toolchain"
272 ]
273 ++ lib.optionals (!standaloneSupport) [
274 "--no_standalone"
275 ]
276 ++ lib.optionals cudaSupport [
277 # NOTE: The `cuda_path` argument is the directory which contains `bin/nvcc` (i.e., the bin output).
278 "--cuda_path=${lib.getBin pkgsBuildHost.cudaPackages.cuda_nvcc}"
279 ]
280 ++ lib.optionals libmathdxSupport [
281 "--libmathdx"
282 "--libmathdx_path=${libmathdx}"
283 ]
284 ++ lib.optionals (!libmathdxSupport) [
285 "--no_libmathdx"
286 ];
287
288 buildOptionString = lib.concatStringsSep " " buildOptions;
289 in
290 ''
291 nixLog "running $PWD/build_lib.py to create components necessary to build the wheel"
292 "${python.pythonOnBuildForHost.interpreter}" "$PWD/build_lib.py" ${buildOptionString}
293 '';
294
  # Smoke test: the module must at least import in the build sandbox.
  pythonImportsCheck = [
    "warp"
  ];

  # The full unit-test suite runs as separate derivations; see passthru.tests.
  doCheck = false;
301
  passthru = {
    # Make libmathdx available for introspection.
    inherit libmathdx;

    # Scripts which provide test packages and implement test logic.
    testers.unit-tests = writeShellApplication {
      name = "warp-lang-unit-tests";
      runtimeInputs = [
        # Use the references from args
        (python.withPackages (_: [
          warp-lang
          jax
          torch
        ]))
        # Disable paddlepaddle interop tests: malloc(): unaligned tcache chunk detected
        # (paddlepaddle.override { inherit cudaSupport; })
      ];
      text = ''
        python3 -m warp.tests
      '';
    };

    # Tests run within the Nix sandbox.
    tests =
      let
        # Builds one sandbox test derivation per (cudaSupport, libmathdxSupport)
        # combination by re-instantiating warp-lang with those feature flags.
        mkUnitTests =
          {
            cudaSupport,
            libmathdxSupport,
          }:
          let
            name =
              "warp-lang-unit-tests-cpu" # CPU is baseline
              + lib.optionalString cudaSupport "-cuda"
              + lib.optionalString libmathdxSupport "-libmathdx";

            # Recursive override: the package's own `warp-lang` argument (used
            # by passthru.testers above) must point back at this override, not
            # at the callPackage default.
            warp-lang' = warp-lang.override {
              inherit cudaSupport libmathdxSupport;
              # Make sure the warp-lang provided through callPackage is replaced with the override we're making.
              warp-lang = warp-lang';
            };
          in
          runCommand name
            {
              nativeBuildInputs = [
                warp-lang'.passthru.testers.unit-tests
                writableTmpDirAsHomeHook
              ];
              # CUDA variants only run on builders exposing the `cuda` feature.
              requiredSystemFeatures = lib.optionals cudaSupport [ "cuda" ];
            }
            ''
              nixLog "running ${name}"

              if warp-lang-unit-tests; then
                nixLog "${name} passed"
                touch "$out"
              else
                nixErrorLog "${name} failed"
                exit 1
              fi
            '';
      in
      {
        cpu = mkUnitTests {
          cudaSupport = false;
          libmathdxSupport = false;
        };
        cuda = {
          cudaOnly = mkUnitTests {
            cudaSupport = true;
            libmathdxSupport = false;
          };
          cudaWithLibmathDx = mkUnitTests {
            cudaSupport = true;
            libmathdxSupport = true;
          };
        };
      };
  };
381
  meta = {
    description = "Python framework for high performance GPU simulation and graphics";
    longDescription = ''
      Warp is a Python framework for writing high-performance simulation
      and graphics code. Warp takes regular Python functions and JIT
      compiles them to efficient kernel code that can run on the CPU or
      GPU.

      Warp is designed for spatial computing and comes with a rich set
      of primitives that make it easy to write programs for physics
      simulation, perception, robotics, and geometry processing. In
      addition, Warp kernels are differentiable and can be used as part
      of machine-learning pipelines with frameworks such as PyTorch,
      JAX and Paddle.
    '';
    homepage = "https://github.com/NVIDIA/warp";
    changelog = "https://github.com/NVIDIA/warp/blob/v${version}/CHANGELOG.md";
    license = lib.licenses.asl20;
    # Darwin is supported via the darwin-specific patches above; CUDA/libmathdx
    # variants are effectively linux-only (see libmathdx.meta.platforms).
    platforms = with lib.platforms; linux ++ darwin;
    maintainers = with lib.maintainers; [ yzx9 ];
  };
403}