at master 1.6 kB view raw
# nixpkgs package expression for pytorch3d, optionally built with CUDA kernels.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  setuptools,
  wheel,
  torch,
  iopath,
  cudaPackages,
  config,
  # Defaults to the global nixpkgs `config.cudaSupport` setting.
  cudaSupport ? config.cudaSupport,
}:

# Compiling the CUDA extension requires a CUDA-enabled torch; fail at
# evaluation time instead of deep inside the build.
assert cudaSupport -> torch.cudaSupport;

buildPythonPackage rec {
  pname = "pytorch3d";
  version = "0.7.8";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "facebookresearch";
    repo = "pytorch3d";
    # Upstream release tags use an upper-case "V" prefix.
    rev = "V${version}";
    hash = "sha256-DEEWWfjwjuXGc0WQInDTmtnWSIDUifyByxdg7hpdHlo=";
  };

  # nvcc is a build-platform tool, needed only when compiling CUDA code.
  nativeBuildInputs = lib.optionals cudaSupport [ cudaPackages.cuda_nvcc ];

  build-system = [
    setuptools
    wheel
  ];

  dependencies = [
    torch
    iopath
  ];

  # The native extension is compiled against torch's "cxxdev" output.
  # NOTE(review): presumably this output carries torch's C++ headers/CMake
  # files needed by the extension build — confirm.
  buildInputs = [ (lib.getOutput "cxxdev" torch) ];

  env = {
    # Nix coerces booleans in derivation attributes: true -> "1", false -> "".
    # NOTE(review): presumably read by pytorch3d's setup.py to force building
    # the CUDA extension even without a visible GPU in the sandbox — confirm.
    FORCE_CUDA = cudaSupport;
  }
  // lib.optionalAttrs cudaSupport {
    # Reuse torch's own list of CUDA capabilities, ";"-separated.
    # concatStringsSep already returns a string, so no interpolation is needed.
    TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" torch.cudaCapabilities;
  };

  pythonImportsCheck = [ "pytorch3d" ];

  # Smoke test that runs rotation ops on a "cuda" device.
  # NOTE(review): writeGpuTestPython presumably needs a GPU-capable builder.
  passthru.tests.rotations-cuda =
    cudaPackages.writeGpuTestPython { libraries = ps: [ ps.pytorch3d ]; }
      ''
        import pytorch3d.transforms as p3dt

        M = p3dt.random_rotations(n=10, device="cuda")
        assert "cuda" in M.device.type
        angles = p3dt.matrix_to_euler_angles(M, "XYZ")
        assert "cuda" in angles.device.type
        assert angles.shape == (10, 3), angles.shape
        print(angles)
      '';

  meta = {
    description = "FAIR's library of reusable components for deep learning with 3D data";
    homepage = "https://github.com/facebookresearch/pytorch3d";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [
      pbsds
      SomeoneSerge
    ];
  };
}