# Nixpkgs derivation for the causal-conv1d Python package (Dao-AILab),
# built from source against torch with optional CUDA support.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  ninja,
  setuptools,
  torch,
  cudaPackages,
  # Not referenced anywhere below: the ROCm build path is not implemented
  # (see meta.broken). Kept in the argument set so existing callers and
  # `.override { rocmPackages = ...; }` invocations keep working.
  rocmPackages,
  config,
  cudaSupport ? config.cudaSupport,
  which,
}:

buildPythonPackage rec {
  pname = "causal-conv1d";
  version = "1.5.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "Dao-AILab";
    repo = "causal-conv1d";
    tag = "v${version}";
    hash = "sha256-B2I5QiJl0p5d1BeQcMbJBAYUb10HzqFd88QMM8Rerm0=";
  };

  # torch is required at build time as well as at runtime; the extension is
  # presumably compiled via torch's C++/CUDA extension tooling (which can use
  # ninja as its backend) — confirm against upstream setup.py when bumping.
  build-system = [
    ninja
    setuptools
    torch
  ];

  # Executables run on the build machine. nvcc is a compiler, so it belongs
  # here: buildPythonPackage enables strictDeps, under which only
  # nativeBuildInputs are put on PATH and spliced for the build platform.
  nativeBuildInputs = [ which ] ++ lib.optionals cudaSupport [ cudaPackages.cuda_nvcc ];

  # Headers and libraries the CUDA sources compile and link against.
  buildInputs = lib.optionals cudaSupport (
    with cudaPackages;
    [
      cuda_cudart # cuda_runtime.h, -lcudart
      cuda_cccl
      libcusparse # cusparse.h
      libcusolver # cusolverDn.h
      # NOTE(review): kept as a host input as well, presumably so headers that
      # ship with nvcc stay visible to the host compiler — drop once confirmed
      # unnecessary.
      cuda_nvcc
      libcublas
    ]
  );

  dependencies = [
    torch
  ];

  # pytest tests are not enabled because they need an NVIDIA GPU; only check
  # that the module imports.
  pythonImportsCheck = [ "causal_conv1d" ];

  env = {
    # NOTE(review): presumably makes upstream's setup.py compile the extension
    # from source instead of trying to fetch a prebuilt wheel — confirm.
    CAUSAL_CONV1D_FORCE_BUILD = "TRUE";
  }
  // lib.optionalAttrs cudaSupport {
    # CUDA_HOME points at the dev output of cuda_nvcc.
    CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}";
  };

  meta = {
    description = "Causal depthwise conv1d in CUDA with a PyTorch interface";
    homepage = "https://github.com/Dao-AILab/causal-conv1d";
    license = lib.licenses.bsd3;
    maintainers = [ lib.maintainers.cfhammill ];
    # The package needs either CUDA or ROCm. The ROCm build has not been
    # completed or tested, so the package is marked broken unless CUDA
    # support is enabled.
    broken = !cudaSupport;
  };
}