# Smoke test for libtorch-bin: compiles test.cpp against libtorch with CMake
# and runs the resulting `test` binary during checkPhase. The output is an
# empty file; the derivation building successfully is the test result.
{
  lib,
  stdenv,
  cmake,
  libtorch-bin,
  linkFarm,
  symlinkJoin,

  # When true, also pull in cuDNN, point CMake at the CUDA toolkit and run
  # the test with a stub libcuda on LD_LIBRARY_PATH.
  cudaSupport,
  # Only dereferenced when cudaSupport is true: every use below sits behind
  # lib.optionals / lib.optionalString, and Nix evaluates lazily, so the
  # empty default is safe for CPU-only builds.
  cudaPackages ? { },
}:
let
  inherit (cudaPackages) cudatoolkit cudnn;

  # Merge the toolkit's `out` and `lib` outputs into one tree so that a single
  # path can be handed to CMake as CUDA_TOOLKIT_ROOT_DIR.
  cudatoolkit_joined = symlinkJoin {
    name = "${cudatoolkit.name}-unsplit";
    paths = [
      cudatoolkit.out
      cudatoolkit.lib
    ];
  };

  # We do not have access to /run/opengl-driver/lib in the sandbox,
  # so expose the toolkit's stub libcuda.so under the soname libcuda.so.1
  # instead.
  cudaStub = linkFarm "cuda-stub" [
    {
      name = "libcuda.so.1";
      path = "${cudatoolkit}/lib/stubs/libcuda.so";
    }
  ];

in
stdenv.mkDerivation {
  pname = "libtorch-test";
  inherit (libtorch-bin) version;

  # Only the files the test build needs, so unrelated edits in this directory
  # do not change the source hash.
  src = lib.fileset.toSource {
    root = ./.;
    fileset = lib.fileset.unions [
      ./CMakeLists.txt
      ./test.cpp
    ];
  };

  nativeBuildInputs = [ cmake ];

  buildInputs = [ libtorch-bin ] ++ lib.optionals cudaSupport [ cudnn ];

  cmakeFlags = lib.optionals cudaSupport [ "-DCUDA_TOOLKIT_ROOT_DIR=${cudatoolkit_joined}" ];

  doCheck = true;

  # Nothing to install: an empty $out records that build and check succeeded.
  # runHook keeps pre/postInstall hooks working despite the phase override.
  installPhase = ''
    runHook preInstall
    touch "$out"
    runHook postInstall
  '';

  # With CUDA, the LD_LIBRARY_PATH assignment ends in a line continuation so
  # that it prefixes only the ./test invocation. `''${` escapes `${` so the
  # shell, not Nix, expands the :+ parameter expansion.
  checkPhase =
    ''
      runHook preCheck
    ''
    + lib.optionalString cudaSupport ''
      LD_LIBRARY_PATH=${cudaStub}''${LD_LIBRARY_PATH:+:}$LD_LIBRARY_PATH \
    ''
    + ''
      ./test
      runHook postCheck
    '';
}