{
  lib,
  stdenv,
  torch,
  apple-sdk_13,
  buildPythonPackage,
  darwinMinVersionHook,
  fetchFromGitHub,

  # build-system
  setuptools,

  # nativeBuildInputs
  libpng,
  ninja,
  which,

  # buildInputs
  libjpeg_turbo,

  # dependencies
  numpy,
  pillow,
  scipy,

  # tests
  pytest,
  writableTmpDirAsHomeHook,
}:

let
  # Reuse torch's CUDA configuration so torchvision is built with the same
  # CUDA package set, target capabilities and CUDA on/off switch as torch.
  inherit (torch) cudaCapabilities cudaPackages cudaSupport;

  pname = "torchvision";
  version = "0.23.0";
in
buildPythonPackage {
  inherit pname version;
  pyproject = true;

  # Build with torch's stdenv (toolchain), since the native parts of
  # torchvision are compiled against `torch.cxxdev` (see buildInputs).
  stdenv = torch.stdenv;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "vision";
    tag = "v${version}";
    hash = "sha256-BfGTq9BsmO5TtQrDED35aaT9quleZ9rcr/81ShfvCbQ=";
  };

  build-system = [ setuptools ];

  nativeBuildInputs = [
    # NOTE(review): libpng is listed here in addition to buildInputs,
    # presumably so its `libpng-config` helper is on PATH for setup.py's
    # libpng detection at build time — confirm against setup.py.
    libpng
    ninja
    which
  ]
  ++ lib.optionals cudaSupport [ cudaPackages.cuda_nvcc ];

  buildInputs = [
    libjpeg_turbo
    libpng
    torch.cxxdev
  ]
  ++ lib.optionals stdenv.hostPlatform.isDarwin [
    # This should match the SDK used by `torch` above
    apple-sdk_13

    # Raise the macOS deployment target to 12.0; without it the build fails with:
    # error: unknown type name 'MPSGraphCompilationDescriptor'; did you mean 'MPSGraphExecutionDescriptor'?
    # https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphcompilationdescriptor/
    (darwinMinVersionHook "12.0")
  ];

  dependencies = [
    numpy
    pillow
    torch
    scipy
  ];

  env = {
    # Extra include / library search paths pointing at libjpeg-turbo,
    # presumably read by torchvision's setup.py to locate the JPEG codec.
    TORCHVISION_INCLUDE = "${libjpeg_turbo.dev}/include/";
    TORCHVISION_LIBRARY = "${libjpeg_turbo}/lib/";
  }
  // lib.optionalAttrs cudaSupport {
    # ";"-separated list of the CUDA capabilities torch was configured for.
    TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" cudaCapabilities;
    # NOTE(review): presumably forces setup.py to build the CUDA extensions even
    # though no GPU is visible inside the build sandbox — confirm.
    FORCE_CUDA = "1";
  };

  # tests download big datasets, models, require internet connection, etc.
  doCheck = false;

  pythonImportsCheck = [ "torchvision" ];

  nativeCheckInputs = [
    pytest
    writableTmpDirAsHomeHook
  ];

  # Overridden phases must run their pre/post hooks explicitly.
  checkPhase = ''
    runHook preCheck

    pytest test --ignore=test/test_datasets_download.py

    runHook postCheck
  '';

  meta = {
    description = "PyTorch vision library";
    homepage = "https://pytorch.org/";
    changelog = "https://github.com/pytorch/vision/releases/tag/v${version}";
    license = lib.licenses.bsd3;
    platforms = with lib.platforms; linux ++ lib.optionals (!cudaSupport) darwin;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}