{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  causal-conv1d,
  einops,
  ninja,
  setuptools,
  torch,
  transformers,
  triton,
  cudaPackages,
  # NOTE(review): rocmPackages is accepted but not referenced anywhere in this
  # expression; it is kept so the argument set stays unchanged.
  rocmPackages,
  config,
  cudaSupport ? config.cudaSupport,
  which,
}:

let
  # CUDA toolkit components handed to buildInputs; only evaluated when
  # cudaSupport is enabled.
  cudaBuildInputs = [
    cudaPackages.cuda_cudart # cuda_runtime.h, -lcudart
    cudaPackages.cuda_cccl
    cudaPackages.libcusparse # cusparse.h
    cudaPackages.libcusolver # cusolverDn.h
    cudaPackages.cuda_nvcc
    cudaPackages.libcublas
  ];
in
buildPythonPackage rec {
  pname = "mamba";
  version = "2.2.2";
  pyproject = true;

  # Sources are fetched from the upstream GitHub tag matching `version`.
  src = fetchFromGitHub {
    owner = "state-spaces";
    repo = "mamba";
    tag = "v${version}";
    hash = "sha256-R702JjM3AGk7upN7GkNK8u1q4ekMK9fYQkpO6Re45Ng=";
  };

  build-system = [
    ninja
    setuptools
    torch
  ];

  nativeBuildInputs = [ which ];

  buildInputs = lib.optionals cudaSupport cudaBuildInputs;

  dependencies = [
    causal-conv1d
    einops
    torch
    transformers
    triton
  ];

  # MAMBA_FORCE_BUILD is always exported; CUDA_HOME is added only for CUDA
  # builds and points at the dev output of cuda_nvcc.
  # NOTE(review): presumably MAMBA_FORCE_BUILD makes upstream's setup.py
  # compile the extension locally - confirm against upstream.
  env =
    {
      MAMBA_FORCE_BUILD = "TRUE";
    }
    // lib.optionalAttrs cudaSupport {
      CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}";
    };

  # The pytest suite is not enabled because it depends on an NVIDIA GPU;
  # only an import check is performed.
  pythonImportsCheck = [ "mamba_ssm" ];

  meta = {
    description = "Linear-Time Sequence Modeling with Selective State Spaces";
    homepage = "https://github.com/state-spaces/mamba";
    license = lib.licenses.asl20;
    maintainers = [ lib.maintainers.cfhammill ];
    # Upstream needs either CUDA or ROCm. The ROCm build has not been
    # completed or tested, so the package is marked broken without CUDA.
    broken = !cudaSupport;
  };
}