# Nix expression for the `causal-conv1d` Python package: causal depthwise
# conv1d kernels with a PyTorch interface (github.com/Dao-AILab/causal-conv1d).
#
# The native extension is built from source; the package is marked broken
# unless `cudaSupport` is enabled (see `meta.broken`).
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  ninja,
  setuptools,
  torch,
  cudaPackages,
  # Not referenced anywhere below. Kept in the argument set so callers that
  # pass or `.override` it keep working; the ROCm build is not wired up yet.
  rocmPackages,
  config,
  cudaSupport ? config.cudaSupport,
  which,
}:

buildPythonPackage rec {
  pname = "causal-conv1d";
  version = "1.5.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "Dao-AILab";
    repo = "causal-conv1d";
    tag = "v${version}";
    hash = "sha256-B2I5QiJl0p5d1BeQcMbJBAYUb10HzqFd88QMM8Rerm0=";
  };

  # torch is needed at build time as well as at runtime, because the
  # extension is compiled against it.
  build-system = [
    ninja
    setuptools
    torch
  ];

  # NOTE(review): presumably used by the build to locate tools on PATH — confirm.
  nativeBuildInputs = [ which ];

  # CUDA toolkit pieces, only when building with CUDA.
  buildInputs = lib.optionals cudaSupport (
    with cudaPackages;
    [
      cuda_cudart # cuda_runtime.h, -lcudart
      cuda_cccl
      libcusparse # cusparse.h
      libcusolver # cusolverDn.h
      # NOTE(review): kept as a host input, presumably for headers that
      # cuda_runtime.h pulls in from the nvcc package — confirm before moving
      # it to nativeBuildInputs.
      cuda_nvcc
      libcublas
    ]
  );

  dependencies = [
    torch
  ];

  # pytest tests not enabled due to nvidia GPU dependency
  pythonImportsCheck = [ "causal_conv1d" ];

  env = {
    # NOTE(review): presumably makes upstream's setup.py compile the extension
    # locally rather than fetch a prebuilt wheel — confirm against setup.py.
    CAUSAL_CONV1D_FORCE_BUILD = "TRUE";
  }
  # NOTE(review): presumably read by torch's cpp_extension to locate nvcc.
  // lib.optionalAttrs cudaSupport { CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}"; };

  meta = {
    description = "Causal depthwise conv1d in CUDA with a PyTorch interface";
    homepage = "https://github.com/Dao-AILab/causal-conv1d";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [ cfhammill ];
    # The package requires CUDA or ROCm. The ROCm build hasn't been completed
    # or tested, so the package is marked broken when not using CUDA.
    broken = !cudaSupport;
  };
}