1{
2 lib,
3 buildPythonPackage,
4 fetchFromGitHub,
5 setuptools,
6 wheel,
7 torch,
8 iopath,
9 cudaPackages,
10 config,
11 cudaSupport ? config.cudaSupport,
12}:
13
14assert cudaSupport -> torch.cudaSupport;
15
16buildPythonPackage rec {
17 pname = "pytorch3d";
18 version = "0.7.8";
19 pyproject = true;
20
21 src = fetchFromGitHub {
22 owner = "facebookresearch";
23 repo = "pytorch3d";
24 rev = "V${version}";
25 hash = "sha256-DEEWWfjwjuXGc0WQInDTmtnWSIDUifyByxdg7hpdHlo=";
26 };
27
28 nativeBuildInputs = lib.optionals cudaSupport [ cudaPackages.cuda_nvcc ];
29 build-system = [
30 setuptools
31 wheel
32 ];
33 dependencies = [
34 torch
35 iopath
36 ];
37 buildInputs = [ (lib.getOutput "cxxdev" torch) ];
38
39 env = {
40 FORCE_CUDA = cudaSupport;
41 }
42 // lib.optionalAttrs cudaSupport {
43 TORCH_CUDA_ARCH_LIST = "${lib.concatStringsSep ";" torch.cudaCapabilities}";
44 };
45
46 pythonImportsCheck = [ "pytorch3d" ];
47
48 passthru.tests.rotations-cuda =
49 cudaPackages.writeGpuTestPython { libraries = ps: [ ps.pytorch3d ]; }
50 ''
51 import pytorch3d.transforms as p3dt
52
53 M = p3dt.random_rotations(n=10, device="cuda")
54 assert "cuda" in M.device.type
55 angles = p3dt.matrix_to_euler_angles(M, "XYZ")
56 assert "cuda" in angles.device.type
57 assert angles.shape == (10, 3), angles.shape
58 print(angles)
59 '';
60
61 meta = {
62 description = "FAIR's library of reusable components for deep learning with 3D data";
63 homepage = "https://github.com/facebookresearch/pytorch3d";
64 license = lib.licenses.bsd3;
65 maintainers = with lib.maintainers; [
66 pbsds
67 SomeoneSerge
68 ];
69 };
70}