# at master · 1.6 kB · view raw
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  causal-conv1d,
  einops,
  ninja,
  setuptools,
  torch,
  transformers,
  triton,
  cudaPackages,
  # Not referenced anywhere below; kept in the argument set so callers that
  # pass it (e.g. via `.override`) keep working until ROCm support is added.
  rocmPackages,
  config,
  cudaSupport ? config.cudaSupport,
  which,
}:

# Python package `mamba_ssm` built from the state-spaces/mamba sources,
# with CUDA toolkit inputs added when `cudaSupport` is enabled.
buildPythonPackage rec {
  pname = "mamba";
  version = "2.2.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "state-spaces";
    repo = "mamba";
    tag = "v${version}";
    hash = "sha256-R702JjM3AGk7upN7GkNK8u1q4ekMK9fYQkpO6Re45Ng=";
  };

  # torch is needed at build time as well as at run time.
  build-system = [
    ninja
    setuptools
    torch
  ];

  # NOTE(review): presumably needed by the build scripts to locate tools
  # on PATH — confirm against upstream setup.py / torch cpp_extension.
  nativeBuildInputs = [ which ];

  # CUDA headers and libraries, only when building with CUDA.
  # NOTE(review): cuda_nvcc is a build-platform tool and the nixpkgs CUDA
  # conventions place it in nativeBuildInputs; it is left here unchanged
  # because CUDA_HOME below points at it — revisit for cross builds.
  buildInputs = lib.optionals cudaSupport (
    with cudaPackages;
    [
      cuda_cudart # cuda_runtime.h, -lcudart
      cuda_cccl
      libcusparse # cusparse.h
      libcusolver # cusolverDn.h
      cuda_nvcc
      libcublas
    ]
  );

  dependencies = [
    causal-conv1d
    einops
    torch
    transformers
    triton
  ];

  env = {
    # Per upstream setup.py, this forces a local build of the extensions
    # instead of attempting to download a prebuilt wheel (the build sandbox
    # has no network) — re-check on version bumps.
    MAMBA_FORCE_BUILD = "TRUE";
  }
  // lib.optionalAttrs cudaSupport {
    # torch's C++/CUDA extension builder locates the CUDA toolkit via CUDA_HOME.
    CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}";
  };

  # The pytest suite is not run: it depends on an NVIDIA GPU.
  pythonImportsCheck = [ "mamba_ssm" ];

  meta = {
    description = "Linear-Time Sequence Modeling with Selective State Spaces";
    homepage = "https://github.com/state-spaces/mamba";
    changelog = "https://github.com/state-spaces/mamba/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ cfhammill ];
    # The package requires CUDA or ROCm; the ROCm build has not been
    # completed or tested, so it is marked broken unless CUDA is enabled.
    broken = !cudaSupport;
  };
}