# torchvision, built against (and configured from) the `torch` argument.
{
  lib,
  stdenv,
  torch,
  apple-sdk_13,
  buildPythonPackage,
  darwinMinVersionHook,
  fetchFromGitHub,

  # nativeBuildInputs
  libpng,
  ninja,
  which,

  # buildInputs
  libjpeg_turbo,

  # dependencies
  numpy,
  pillow,
  scipy,

  # tests
  pytest,
  writableTmpDirAsHomeHook,
}:

let
  # CUDA settings come from `torch` itself, so torchvision is built with the
  # same CUDA toolkit and for the same GPU capabilities as the torch it uses.
  inherit (torch) cudaCapabilities cudaPackages cudaSupport;

  pname = "torchvision";
  version = "0.23.0";
in
buildPythonPackage {
  format = "setuptools";
  inherit pname version;

  # Build with the same stdenv as `torch`.
  stdenv = torch.stdenv;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "vision";
    tag = "v${version}";
    hash = "sha256-BfGTq9BsmO5TtQrDED35aaT9quleZ9rcr/81ShfvCbQ=";
  };

  nativeBuildInputs = [
    libpng
    ninja
    which
  ]
  ++ lib.optionals cudaSupport [ cudaPackages.cuda_nvcc ];

  buildInputs = [
    libjpeg_turbo
    libpng
    torch.cxxdev
  ]
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    # This should match the SDK used by `torch` above
    apple-sdk_13

    # error: unknown type name 'MPSGraphCompilationDescriptor'; did you mean 'MPSGraphExecutionDescriptor'?
    # https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphcompilationdescriptor/
    (darwinMinVersionHook "12.0")
  ];

  dependencies = [
    numpy
    pillow
    torch
    scipy
  ];

  env = {
    # Point the build at libjpeg-turbo's headers and libraries.
    TORCHVISION_INCLUDE = "${libjpeg_turbo.dev}/include/";
    TORCHVISION_LIBRARY = "${libjpeg_turbo}/lib/";
  }
  // lib.optionalAttrs cudaSupport {
    # Semicolon-separated list of the CUDA capabilities inherited from torch.
    TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" cudaCapabilities;
    FORCE_CUDA = 1;
  };

  # tests download big datasets, models, require internet connection, etc.
  doCheck = false;

  pythonImportsCheck = [ "torchvision" ];

  nativeCheckInputs = [
    pytest
    writableTmpDirAsHomeHook
  ];

  # Only takes effect if doCheck is re-enabled. The pre/post check hooks are run
  # explicitly because this phase replaces the default checkPhase.
  checkPhase = ''
    runHook preCheck

    pytest test --ignore=test/test_datasets_download.py

    runHook postCheck
  '';

  meta = {
    description = "PyTorch vision library";
    homepage = "https://pytorch.org/";
    changelog = "https://github.com/pytorch/vision/releases/tag/v${version}";
    license = lib.licenses.bsd3;
    platforms = with lib.platforms; linux ++ lib.optionals (!cudaSupport) darwin;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}