1{
2 stdenv,
3 lib,
4 fetchFromGitHub,
5 fetchFromGitLab,
6 fetchpatch,
7 git-unroll,
8 buildPythonPackage,
9 python,
10 runCommand,
11 writeShellScript,
12 config,
13 cudaSupport ? config.cudaSupport,
14 cudaPackages,
15 autoAddDriverRunpath,
16 effectiveMagma ?
17 if cudaSupport then
18 magma-cuda-static
19 else if rocmSupport then
20 magma-hip
21 else
22 magma,
23 magma,
24 magma-hip,
25 magma-cuda-static,
26 # Use the system NCCL as long as we're targeting CUDA on a supported platform.
27 useSystemNccl ? (cudaSupport && !cudaPackages.nccl.meta.unsupported || rocmSupport),
28 MPISupport ? false,
29 mpi,
30 buildDocs ? false,
31
32 # tests.cudaAvailable:
33 callPackage,
34
35 # Native build inputs
36 cmake,
37 symlinkJoin,
38 which,
39 pybind11,
40 pkg-config,
41 removeReferencesTo,
42
43 # Build inputs
44 apple-sdk_13,
45 openssl,
46 numactl,
47 llvmPackages,
48
49 # dependencies
50 astunparse,
51 binutils,
52 expecttest,
53 filelock,
54 fsspec,
55 hypothesis,
56 jinja2,
57 networkx,
58 packaging,
59 psutil,
60 pyyaml,
61 requests,
62 sympy,
63 types-dataclasses,
64 typing-extensions,
65 # ROCm build and `torch.compile` requires `triton`
66 tritonSupport ? (!stdenv.hostPlatform.isDarwin),
67 triton,
68
69 # TODO: 1. callPackage needs to learn to distinguish between the task
70 # of "asking for an attribute from the parent scope" and
71 # the task of "exposing a formal parameter in .override".
72 # TODO: 2. We should probably abandon attributes such as `torchWithCuda` (etc.)
  #            as they routinely end up consuming the wrong arguments
74 # (dependencies without cuda support).
75 # Instead we should rely on overlays and nixpkgsFun.
76 # (@SomeoneSerge)
77 _tritonEffective ? if cudaSupport then triton-cuda else triton,
78 triton-cuda,
79
80 # Disable MKLDNN on aarch64-darwin, it negatively impacts performance,
81 # this is also what official pytorch build does
82 mklDnnSupport ? !(stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isAarch64),
83
84 # virtual pkg that consistently instantiates blas across nixpkgs
85 # See https://github.com/NixOS/nixpkgs/pull/83888
86 blas,
87
88 # ninja (https://ninja-build.org) must be available to run C++ extensions tests,
89 ninja,
90
91 # dependencies for torch.utils.tensorboard
92 pillow,
93 six,
94 tensorboard,
95 protobuf,
96
97 # ROCm dependencies
98 rocmSupport ? config.rocmSupport,
99 rocmPackages,
100 gpuTargets ? [ ],
101
102 vulkanSupport ? false,
103 vulkan-headers,
104 vulkan-loader,
105 shaderc,
106}:
107
108let
109 inherit (lib)
110 attrsets
111 lists
112 strings
113 trivial
114 ;
115 inherit (cudaPackages) cudnn flags nccl;
116
117 triton = throw "python3Packages.torch: use _tritonEffective instead of triton to avoid divergence";
118
119 setBool = v: if v then "1" else "0";
120
121 # https://github.com/pytorch/pytorch/blob/v2.8.0/torch/utils/cpp_extension.py#L2411-L2414
122 supportedTorchCudaCapabilities =
123 let
124 real = [
125 "3.5"
126 "3.7"
127 "5.0"
128 "5.2"
129 "5.3"
130 "6.0"
131 "6.1"
132 "6.2"
133 "7.0"
134 "7.2"
135 "7.5"
136 "8.0"
137 "8.6"
138 "8.7"
139 "8.9"
140 "9.0"
141 "9.0a"
142 "10.0"
143 "10.0"
144 "10.0a"
145 "10.1"
146 "10.1a"
147 "10.3"
148 "10.3a"
149 "12.0"
150 "12.0a"
151 "12.1"
152 "12.1a"
153 ];
154 ptx = lists.map (x: "${x}+PTX") real;
155 in
156 real ++ ptx;
157
  # NOTE: The lists.subtractLists function is perhaps a bit unintuitive. It subtracts the elements
  # of the first list *from* the second list. That means:
  # lists.subtractLists a b = b - a

  # For CUDA: split the requested capabilities into the ones torch accepts and the rest.
  supportedCudaCapabilities = lists.intersectLists flags.cudaCapabilities supportedTorchCudaCapabilities;
  unsupportedCudaCapabilities = lists.subtractLists supportedCudaCapabilities flags.cudaCapabilities;

  # Jetson (CUDA on aarch64) builds get an extra runpath hook in nativeBuildInputs.
  isCudaJetson = cudaSupport && cudaPackages.flags.isJetsonBuild;
167
  # Abort evaluation (trivial.throwIf — not warnIf, despite the old comment)
  # when none of the requested GPU targets is supported, listing the
  # unsupported targets in the error message; otherwise pass `supported` through.
  gpuArchWarner =
    supported: unsupported:
    trivial.throwIf (supported == [ ]) (
      "No supported GPU targets specified. Requested GPU targets: "
      + strings.concatStringsSep ", " unsupported
    ) supported;

  # Create the gpuTargetString: an explicit `gpuTargets` argument wins,
  # otherwise derive the list from the CUDA or ROCm package sets.
  # Joined with ";" as TORCH_CUDA_ARCH_LIST / PYTORCH_ROCM_ARCH expect.
  gpuTargetString = strings.concatStringsSep ";" (
    if gpuTargets != [ ] then
      # If gpuTargets is specified, it always takes priority.
      gpuTargets
    else if cudaSupport then
      gpuArchWarner supportedCudaCapabilities unsupportedCudaCapabilities
    else if rocmSupport then
      lib.lists.subtractLists [
        # Remove RDNA1 gfx101x archs from default ROCm support list to avoid
        # use of undeclared identifier 'CK_BUFFER_RESOURCE_3RD_DWORD'
        # TODO: Retest after ROCm 6.4 or torch 2.8
        "gfx1010"
        "gfx1012"

        # Strix Halo seems to be broken as well, see
        # https://github.com/NixOS/nixpkgs/pull/440359.
        "gfx1151"
      ] (rocmPackages.clr.localGpuTargets or rocmPackages.clr.gpuTargets)
    else
      throw "No GPU targets specified"
  );
198
  # All ROCm components torch needs, merged into one prefix so that a single
  # path can be exported as ROCM_PATH / ROCM_SOURCE_DIR in preConfigure.
  rocmtoolkit_joined = symlinkJoin {
    name = "rocm-merged";

    paths = with rocmPackages; [
      rocm-core
      clr
      rccl
      miopen
      aotriton
      composable_kernel
      rocrand
      rocblas
      rocsparse
      hipsparse
      rocthrust
      rocprim
      hipcub
      roctracer
      rocfft
      rocsolver
      hipfft
      hiprand
      hipsolver
      hipblas-common
      hipblas
      hipblaslt
      rocminfo
      rocm-comgr
      rocm-device-libs
      rocm-runtime
      rocm-smi
      clr.icd
      hipify
    ];

    # Fix `setuptools` not being found
    postBuild = ''
      rm -rf $out/nix-support
    '';
  };
239
  # Named conditions under which this build is known broken; only the attrs
  # whose value is true survive the filter. Consumed by meta.broken and
  # exposed via passthru for debugging.
  brokenConditions = attrsets.filterAttrs (_: cond: cond) {
    "CUDA and ROCm are mutually exclusive" = cudaSupport && rocmSupport;
    "CUDA is not targeting Linux" = cudaSupport && !stdenv.hostPlatform.isLinux;
    "Unsupported CUDA version" =
      cudaSupport
      && !(builtins.elem cudaPackages.cudaMajorVersion [
        "11"
        "12"
      ]);
    "MPI cudatoolkit does not match cudaPackages.cudatoolkit" =
      MPISupport && cudaSupport && (mpi.cudatoolkit != cudaPackages.cudatoolkit);
    # This used to be a deep package set comparison between cudaPackages and
    # effectiveMagma.cudaPackages, making torch too strict in cudaPackages.
    # In particular, this triggered warnings from cuda's `aliases.nix`
    "Magma cudaPackages does not match cudaPackages" =
      cudaSupport
      && (effectiveMagma.cudaPackages.cudaMajorMinorVersion != cudaPackages.cudaMajorMinorVersion);
  };
258
  # Maintainer helper (exposed via passthru) that regenerates ./src.nix for a
  # new release: prints the expected file header, then the `git-unroll` output
  # for the given tag. The nested "'"'$1'"'" quoting splices the script's $1
  # into a quoted Nix string in the generated file — do not "simplify" it.
  unroll-src = writeShellScript "unroll-src" ''
    echo "{
      version,
      fetchFromGitLab,
      fetchFromGitHub,
      runCommand,
    }:
    assert version == "'"'$1'"'";"
    ${lib.getExe git-unroll} https://github.com/pytorch/pytorch v$1
    echo
    echo "# Update using: unroll-src [version]"
  '';

  # CUDA builds must use the NVCC-compatible compiler stack.
  stdenv' = if cudaSupport then cudaPackages.backendStdenv else stdenv;
273in
274buildPythonPackage rec {
  pname = "torch";
  # Don't forget to update torch-bin to the same version.
  version = "2.8.0";
  pyproject = true;

  # Use the CUDA-aware stdenv selected above (stdenv' from the let-block).
  stdenv = stdenv';

  outputs = [
    "out" # output standard python package
    "dev" # output libtorch headers
    "lib" # output libtorch libraries
    "cxxdev" # propagated deps for the cmake consumers of torch
  ];
  cudaPropagateToOutput = "cxxdev";

  # Pinned source; regenerate ./src.nix with the passthru `unroll-src` script.
  src = callPackage ./src.nix {
    inherit
      version
      fetchFromGitHub
      fetchFromGitLab
      runCommand
      ;
  };
298
299 patches = [
300 ./clang19-template-warning.patch
301
302 # Do not override PYTHONPATH, otherwise, the build fails with:
303 # ModuleNotFoundError: No module named 'typing_extensions'
304 (fetchpatch {
305 name = "cmake-build-preserve-PYTHONPATH";
306 url = "https://github.com/pytorch/pytorch/commit/231c72240d80091f099c95e326d3600cba866eee.patch";
307 hash = "sha256-BBCjxzz2TUkx4nXRyRILA82kMwyb/4+C3eOtYqf5dhk=";
308 })
309
310 # Fixes GCC-14 compatibility on ARM
311 # Adapted from https://github.com/pytorch/pytorch/pull/157867
312 # TODO: remove at the next release
313 ./gcc-14-arm-compat.path
314 ]
315 ++ lib.optionals cudaSupport [
316 ./fix-cmake-cuda-toolkit.patch
317 ./nvtx3-hpp-path-fix.patch
318 ]
319 ++ lib.optionals stdenv.hostPlatform.isLinux [
320 # Propagate CUPTI to Kineto by overriding the search path with environment variables.
321 # https://github.com/pytorch/pytorch/pull/108847
322 ./pytorch-pr-108847.patch
323 ]
324 ++ lib.optionals (lib.getName blas.provider == "mkl") [
325 # The CMake install tries to add some hardcoded rpaths, incompatible
326 # with the Nix store, which fails. Simply remove this step to get
327 # rpaths that point to the Nix store.
328 ./disable-cmake-mkl-rpath.patch
329 ];
330
  # Source tree fixups before configure. Each '' … '' segment is a shell
  # snippet; `\''${…}` escapes keep cmake's ${…} literal in the Nix string.
  postPatch = ''
    substituteInPlace pyproject.toml \
      --replace-fail "setuptools>=62.3.0,<80.0" "setuptools"
  ''
  # Provide path to openssl binary for inductor code cache hash
  # InductorError: FileNotFoundError: [Errno 2] No such file or directory: 'openssl'
  + ''
    substituteInPlace torch/_inductor/codecache.py \
      --replace-fail '"openssl"' '"${lib.getExe openssl}"'
  ''
  # Downgrade conflicting-CUDA detection to a warning and drop the hardcoded
  # CUDAToolkit_ROOT; relax gloo's minimum CUDAToolkit version constraint.
  + ''
    substituteInPlace cmake/public/cuda.cmake \
      --replace-fail \
        'message(FATAL_ERROR "Found two conflicting CUDA' \
        'message(WARNING "Found two conflicting CUDA' \
      --replace-warn \
        "set(CUDAToolkit_ROOT" \
        "# Upstream: set(CUDAToolkit_ROOT"
    substituteInPlace third_party/gloo/cmake/Cuda.cmake \
      --replace-warn "find_package(CUDAToolkit 7.0" "find_package(CUDAToolkit"
  ''
  # annotations (3.7), print_function (3.0), with_statement (2.6) are all supported
  + ''
    sed -i -e "/from __future__ import/d" **.py
    substituteInPlace third_party/NNPACK/CMakeLists.txt \
      --replace-fail "PYTHONPATH=" 'PYTHONPATH=$ENV{PYTHONPATH}:'
  ''
  # flag from cmakeFlags doesn't work, not clear why
  # setting it at the top of NNPACK's own CMakeLists does
  + ''
    sed -i '2s;^;set(PYTHON_SIX_SOURCE_DIR ${six.src})\n;' third_party/NNPACK/CMakeLists.txt
  ''
  # Ensure that torch profiler unwind uses addr2line from nix
  + ''
    substituteInPlace torch/csrc/profiler/unwind/unwind.cpp \
      --replace-fail 'addr2line_binary_ = "addr2line"' 'addr2line_binary_ = "${lib.getExe' binutils "addr2line"}"'
  ''
  + lib.optionalString rocmSupport ''
    # https://github.com/facebookincubator/gloo/pull/297
    substituteInPlace third_party/gloo/cmake/Hipify.cmake \
      --replace-fail "\''${HIPIFY_COMMAND}" "python \''${HIPIFY_COMMAND}"

    # Doesn't pick up the environment variable?
    substituteInPlace third_party/kineto/libkineto/CMakeLists.txt \
      --replace-fail "\''$ENV{ROCM_SOURCE_DIR}" "${rocmtoolkit_joined}"

    # Use composable kernel as dependency, rather than built-in third-party
    substituteInPlace aten/src/ATen/CMakeLists.txt \
      --replace-fail "list(APPEND ATen_HIP_INCLUDE \''${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)" "" \
      --replace-fail "list(APPEND ATen_HIP_INCLUDE \''${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)" ""
  ''
  # Detection of NCCL version doesn't work particularly well when using the static binary.
  + lib.optionalString cudaSupport ''
    substituteInPlace cmake/Modules/FindNCCL.cmake \
      --replace-fail \
        'message(FATAL_ERROR "Found NCCL header version and library version' \
        'message(WARNING "Found NCCL header version and library version'
  ''
  # Remove PyTorch's FindCUDAToolkit.cmake and use CMake's default.
  # NOTE: Parts of pytorch rely on unmaintained FindCUDA.cmake with custom patches to support e.g.
  # newer architectures (sm_90a). We do want to delete vendored patches, but have to keep them
  # until https://github.com/pytorch/pytorch/issues/76082 is addressed
  + lib.optionalString cudaSupport ''
    rm cmake/Modules/FindCUDAToolkit.cmake
  '';
396
  # NOTE(@connorbaker): Though we do not disable Gloo or MPI when building with CUDA support, caution should be taken
  # when using the different backends. Gloo's GPU support isn't great, and MPI and CUDA can't be used at the same time
  # without extreme care to ensure they don't lock each other out of shared resources.
  # For more, see https://github.com/open-mpi/ompi/issues/7733#issuecomment-629806195.
  preConfigure =
    # CUDA: arch list plus CUPTI paths (kineto/profiler).
    lib.optionalString cudaSupport ''
      export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
      export CUPTI_INCLUDE_DIR=${lib.getDev cudaPackages.cuda_cupti}/include
      export CUPTI_LIBRARY_DIR=${lib.getLib cudaPackages.cuda_cupti}/lib
    ''
    + lib.optionalString (cudaSupport && cudaPackages ? cudnn) ''
      export CUDNN_INCLUDE_DIR=${lib.getLib cudnn}/include
      export CUDNN_LIB_DIR=${lib.getLib cudnn}/lib
    ''
    # ROCm: point at the merged toolkit and run the HIPify conversion step.
    + lib.optionalString rocmSupport ''
      export ROCM_PATH=${rocmtoolkit_joined}
      export ROCM_SOURCE_DIR=${rocmtoolkit_joined}
      export PYTORCH_ROCM_ARCH="${gpuTargetString}"
      export CMAKE_CXX_FLAGS="-I${rocmtoolkit_joined}/include -I${rocmtoolkit_joined}/include/rocblas"
      python tools/amd_build/build_amd.py
    '';
418
419 # Use pytorch's custom configurations
420 dontUseCmakeConfigure = true;
421
422 # causes possible redefinition of _FORTIFY_SOURCE
423 hardeningDisable = [ "fortify3" ];
424
425 BUILD_NAMEDTENSOR = setBool true;
426 BUILD_DOCS = setBool buildDocs;
427
428 # We only do an imports check, so do not build tests either.
429 BUILD_TEST = setBool false;
430
431 # ninja hook doesn't automatically turn on ninja
432 # because pytorch setup.py is responsible for this
433 CMAKE_GENERATOR = "Ninja";
434
435 # Unlike MKL, oneDNN (née MKLDNN) is FOSS, so we enable support for
436 # it by default. PyTorch currently uses its own vendored version
437 # of oneDNN through Intel iDeep.
438 USE_MKLDNN = setBool mklDnnSupport;
439 USE_MKLDNN_CBLAS = setBool mklDnnSupport;
440
441 # Avoid using pybind11 from git submodule
442 # Also avoids pytorch exporting the headers of pybind11
443 USE_SYSTEM_PYBIND11 = true;
444
445 # Multicore CPU convnet support
446 USE_NNPACK = 1;
447
448 # Explicitly enable MPS for Darwin
449 USE_MPS = setBool stdenv.hostPlatform.isDarwin;
450
451 # building torch.distributed on Darwin is disabled by default
452 # https://pytorch.org/docs/stable/distributed.html#torch.distributed.is_available
453 USE_DISTRIBUTED = setBool true;
454
  cmakeFlags = [
    (lib.cmakeFeature "PYTHON_SIX_SOURCE_DIR" "${six.src}")
    # (lib.cmakeBool "CMAKE_FIND_DEBUG_MODE" true)
    # NOTE(review): this references cudaPackages even when cudaSupport is
    # false — presumably harmless (the variable is simply unused by cmake),
    # but consider whether it belongs in the cudaSupport optionals below.
    (lib.cmakeFeature "CUDAToolkit_VERSION" cudaPackages.cudaMajorMinorVersion)
  ]
  ++ lib.optionals cudaSupport [
    # Unbreaks version discovery in enable_language(CUDA) when wrapping nvcc with ccache
    # Cf. https://gitlab.kitware.com/cmake/cmake/-/issues/26363
    (lib.cmakeFeature "CMAKE_CUDA_COMPILER_TOOLKIT_VERSION" cudaPackages.cudaMajorMinorVersion)
  ];
465
  # Run cmake configure ourselves (setup.py --cmake-only), then invoke cmake
  # on the generated `build` directory.
  preBuild = ''
    export MAX_JOBS=$NIX_BUILD_CORES
    ${python.pythonOnBuildForHost.interpreter} setup.py build --cmake-only
    ${cmake}/bin/cmake build
  '';

  # Drop the first two rpath entries of every libcaffe2* library and prepend
  # $ORIGIN, so the split outputs resolve sibling libraries relative to
  # themselves rather than via build-time paths.
  preFixup = ''
    function join_by { local IFS="$1"; shift; echo "$*"; }
    function strip2 {
      IFS=':'
      read -ra RP <<< $(patchelf --print-rpath $1)
      IFS=' '
      RP_NEW=$(join_by : ''${RP[@]:2})
      patchelf --set-rpath \$ORIGIN:''${RP_NEW} "$1"
    }
    for f in $(find ''${out} -name 'libcaffe2*.so')
    do
      strip2 $f
    done
  '';
486
  # Override the (weirdly) wrong version set by default. See
  # https://github.com/NixOS/nixpkgs/pull/52437#issuecomment-449718038
  # https://github.com/pytorch/pytorch/blob/v1.0.0/setup.py#L267
  PYTORCH_BUILD_VERSION = version;
  PYTORCH_BUILD_NUMBER = 0;

  # In-tree builds of NCCL are not supported.
  # Use NCCL when cudaSupport is enabled and nccl is available.
  USE_NCCL = setBool useSystemNccl;
  USE_SYSTEM_NCCL = USE_NCCL;
  USE_STATIC_NCCL = USE_NCCL;

  # Set the correct Python library path, broken since
  # https://github.com/pytorch/pytorch/commit/3d617333e
  PYTHON_LIB_REL_PATH = "${placeholder "out"}/${python.sitePackages}";

  env = {
    # disable warnings as errors as they break the build on every compiler
    # bump, among other things.
    # Also of interest: pytorch ignores CXXFLAGS uses CFLAGS for both C and C++:
    # https://github.com/pytorch/pytorch/blob/v1.11.0/setup.py#L17
    NIX_CFLAGS_COMPILE = toString (
      [
        "-Wno-error"
      ]
      # fix build aarch64-linux build failure with GCC14
      ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
        "-Wno-error=incompatible-pointer-types"
      ]
    );
    USE_VULKAN = setBool vulkanSupport;
  }
  // lib.optionalAttrs vulkanSupport {
    VULKAN_SDK = shaderc.bin;
  }
  // lib.optionalAttrs rocmSupport {
    AOTRITON_INSTALLED_PREFIX = "${rocmPackages.aotriton}";
  };
525
  # Build-time tools (run on the build platform).
  nativeBuildInputs = [
    cmake
    which
    ninja
    pybind11
    pkg-config
    removeReferencesTo
  ]
  ++ lib.optionals cudaSupport (
    with cudaPackages;
    [
      autoAddDriverRunpath
      cuda_nvcc
    ]
  )
  ++ lib.optionals isCudaJetson [ cudaPackages.autoAddCudaCompatRunpath ]
  ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];

  # Libraries linked against (host platform).
  buildInputs = [
    blas
    blas.provider
  ]
  # Including openmp leads to two copies being used on ARM, which segfaults.
  # https://github.com/pytorch/pytorch/issues/149201#issuecomment-2776842320
  ++ lib.optionals (stdenv.cc.isClang && !stdenv.hostPlatform.isAarch64) [ llvmPackages.openmp ]
  ++ lib.optionals cudaSupport (
    with cudaPackages;
    [
      cuda_cccl # <thrust/*>
      cuda_cudart # cuda_runtime.h and libraries
      cuda_cupti # For kineto
      cuda_nvcc # crt/host_config.h; even though we include this in nativeBuildInputs, it's needed here too
      cuda_nvml_dev # <nvml.h>
      cuda_nvrtc
      cuda_nvtx # -llibNVToolsExt
      cusparselt
      libcublas
      libcufft
      libcufile
      libcurand
      libcusolver
      libcusparse
    ]
    ++ lists.optionals (cudaPackages ? cudnn) [ cudnn ]
    ++ lists.optionals useSystemNccl [
      # Some platforms do not support NCCL (i.e., Jetson)
      nccl # Provides nccl.h AND a static copy of NCCL!
    ]
    ++ [
      cuda_profiler_api # <cuda_profiler_api.h>
    ]
  )
  ++ lib.optionals rocmSupport [ rocmPackages.llvm.openmp ]
  ++ lib.optionals (cudaSupport || rocmSupport) [ effectiveMagma ]
  ++ lib.optionals stdenv.hostPlatform.isLinux [ numactl ]
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    apple-sdk_13
  ]
  ++ lib.optionals tritonSupport [ _tritonEffective ]
  ++ lib.optionals MPISupport [ mpi ]
  ++ lib.optionals rocmSupport [
    rocmtoolkit_joined
    rocmPackages.clr # Added separately so setup hook applies
  ];
590
  # Loosen overly-strict version pins in torch's metadata.
  pythonRelaxDeps = [
    "sympy"
  ];
  # Runtime Python dependencies.
  dependencies = [
    astunparse
    expecttest
    filelock
    fsspec
    hypothesis
    jinja2
    networkx
    ninja
    packaging
    psutil
    pyyaml
    requests
    sympy
    types-dataclasses
    typing-extensions

    # the following are required for tensorboard support
    pillow
    six
    tensorboard
    protobuf

    # torch/csrc requires `pybind11` at runtime
    pybind11
  ]
  ++ lib.optionals tritonSupport [ _tritonEffective ]
  ++ lib.optionals vulkanSupport [
    vulkan-headers
    vulkan-loader
  ];

  propagatedCxxBuildInputs =
    [ ] ++ lib.optionals MPISupport [ mpi ] ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];

  # Tests take a long time and may be flaky, so just sanity-check imports
  doCheck = false;

  pythonImportsCheck = [ "torch" ];

  nativeCheckInputs = [
    hypothesis
    ninja
    psutil
  ];

  # Only runs when doCheck is overridden to true.
  checkPhase =
    with lib.versions;
    with lib.strings;
    concatStringsSep " " [
      "runHook preCheck"
      "${python.interpreter} test/run_test.py"
      "--exclude"
      (concatStringsSep " " [
        "utils" # utils requires git, which is not allowed in the check phase

        # "dataloader" # psutils correctly finds and triggers multiprocessing, but is too sandboxed to run -- resulting in numerous errors
        # ^^^^^^^^^^^^ NOTE: while test_dataloader does return errors, these are acceptable errors and do not interfere with the build

        # tensorboard has acceptable failures for pytorch 1.3.x due to dependencies on tensorboard-plugins
        (optionalString (majorMinor version == "1.3") "tensorboard")
      ])
      "runHook postCheck"
    ];

  pythonRemoveDeps = [
    # In our dist-info the name is just "triton"
    "pytorch-triton-rocm"
  ];
663
  # Split the installed tree across the out/dev/lib outputs and scrub
  # compiler store-path references to keep closures small.
  postInstall = ''
    find "$out/${python.sitePackages}/torch/include" "$out/${python.sitePackages}/torch/lib" -type f -exec remove-references-to -t ${stdenv.cc} '{}' +

    mkdir $dev

    # CppExtension requires that include files are packaged with the main
    # python library output; which is why they are copied here.
    cp -r $out/${python.sitePackages}/torch/include $dev/include

    # Cmake files under /share are different and can be safely moved. This
    # avoids unnecessary closure blow-up due to apple sdk references when
    # USE_DISTRIBUTED is enabled.
    mv $out/${python.sitePackages}/torch/share $dev/share

    # Fix up library paths for split outputs
    substituteInPlace \
      $dev/share/cmake/Torch/TorchConfig.cmake \
      --replace-fail \''${TORCH_INSTALL_PREFIX}/lib "$lib/lib"

    substituteInPlace \
      $dev/share/cmake/Caffe2/Caffe2Targets-release.cmake \
      --replace-fail \''${_IMPORT_PREFIX}/lib "$lib/lib"

    mkdir $lib
    mv $out/${python.sitePackages}/torch/lib $lib/lib
    ln -s $lib/lib $out/${python.sitePackages}/torch/lib
  ''
  + lib.optionalString rocmSupport ''
    substituteInPlace $dev/share/cmake/Tensorpipe/TensorpipeTargets-release.cmake \
      --replace-fail "\''${_IMPORT_PREFIX}/lib64" "$lib/lib"

    substituteInPlace $dev/share/cmake/ATen/ATenConfig.cmake \
      --replace-fail "/build/${src.name}/torch/include" "$dev/include"
  '';

  # Record propagated cmake consumers' deps; on Darwin rewrite dylib install
  # names so the split $lib output is found at load time.
  postFixup = ''
    mkdir -p "$cxxdev/nix-support"
    printWords "''${propagatedCxxBuildInputs[@]}" >> "$cxxdev/nix-support/propagated-build-inputs"
  ''
  + lib.optionalString stdenv.hostPlatform.isDarwin ''
    for f in $(ls $lib/lib/*.dylib); do
      install_name_tool -id $lib/lib/$(basename $f) $f || true
    done

    install_name_tool -change @rpath/libshm.dylib $lib/lib/libshm.dylib $lib/lib/libtorch_python.dylib
    install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libtorch_python.dylib
    install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch_python.dylib

    install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libtorch.dylib

    install_name_tool -change @rpath/libtorch.dylib $lib/lib/libtorch.dylib $lib/lib/libshm.dylib
    install_name_tool -change @rpath/libc10.dylib $lib/lib/libc10.dylib $lib/lib/libshm.dylib
  '';

  # See https://github.com/NixOS/nixpkgs/issues/296179
  #
  # This is a quick hack to add `libnvrtc` to the runpath so that torch can find
  # it when it is needed at runtime.
  extraRunpaths = lib.optionals cudaSupport [ "${lib.getLib cudaPackages.cuda_nvrtc}/lib" ];
  postPhases = lib.optionals stdenv.hostPlatform.isLinux [ "postPatchelfPhase" ];
  postPatchelfPhase = ''
    while IFS= read -r -d $'\0' elf ; do
      for extra in $extraRunpaths ; do
        echo patchelf "$elf" --add-rpath "$extra" >&2
        patchelf "$elf" --add-rpath "$extra"
      done
    done < <(
      find "''${!outputLib}" "$out" -type f -iname '*.so' -print0
    )
  '';
734
  # Builds in 2+h with 2 cores, and ~15m with a big-parallel builder.
  requiredSystemFeatures = [ "big-parallel" ];

  # Expose build configuration so dependents (and debuggers) can introspect
  # how this torch was built.
  passthru = {
    inherit
      cudaSupport
      cudaPackages
      rocmSupport
      rocmPackages
      unroll-src
      ;
    cudaCapabilities = if cudaSupport then supportedCudaCapabilities else [ ];
    # At least for 1.10.2 `torch.fft` is unavailable unless BLAS provider is MKL. This attribute allows for easy detection of its availability.
    blasProvider = blas.provider;
    # To help debug when a package is broken due to CUDA support
    inherit brokenConditions;
    tests = callPackage ../tests { };
  };
753
  meta = {
    changelog = "https://github.com/pytorch/pytorch/releases/tag/v${version}";
    # keep PyTorch in the description so the package can be found under that name on search.nixos.org
    description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
    homepage = "https://pytorch.org/";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [
      GaetanLepage
      teh
      thoughtpolice
      tscholak
    ]; # tscholak esp. for darwin-related builds
    # Darwin only for CPU-only builds; CUDA/ROCm imply Linux.
    platforms =
      lib.platforms.linux ++ lib.optionals (!cudaSupport && !rocmSupport) lib.platforms.darwin;
    # Any true entry in brokenConditions marks the package broken.
    broken = builtins.any trivial.id (builtins.attrValues brokenConditions);
  };
770}