1{
2 lib,
3 torch,
4 symlinkJoin,
5 buildPythonPackage,
6 fetchFromGitHub,
7 cmake,
8
9 # build-system
10 scikit-build-core,
11 setuptools,
12
13 # dependencies
14 scipy,
15}:
16
let
  pname = "bitsandbytes";
  version = "0.47.0";

  inherit (torch) cudaPackages cudaSupport;
  inherit (cudaPackages) cudaMajorMinorVersion;

  # CUDA major/minor version with the dot removed (e.g. "12.4" -> "124"),
  # used as the suffix of the CUDA library name hardcoded in postPatch.
  cudaMajorMinorVersionString = lib.replaceStrings [ "." ] [ "" ] cudaMajorMinorVersion;

  # CUDA components shared by the build-time tree (cuda-native-redist) and the
  # buildInputs tree (cuda-redist).
  # NOTE: cuDNN is intentionally not included: torch depends on it, this
  # package is not expected to.
  cuda-common-redist = with cudaPackages; [
    (lib.getDev cuda_cccl) # <thrust/*>
    (lib.getDev libcublas) # cublas_v2.h
    (lib.getLib libcublas)
    libcurand
    libcusolver # cusolverDn.h
    (lib.getDev libcusparse) # cusparse.h
    (lib.getLib libcusparse)
    (lib.getDev cuda_cudart) # cuda_runtime.h cuda_runtime_api.h
  ];

  # Build-time CUDA tree: cuda-common-redist plus the cudart libraries and nvcc.
  # Exposed to the build through CUDA_HOME and NVCC_PREPEND_FLAGS.
  cuda-native-redist = symlinkJoin {
    name = "cuda-native-redist-${cudaMajorMinorVersion}";
    paths =
      with cudaPackages;
      [
        # The cudart headers (lib.getDev cuda_cudart) already come from
        # cuda-common-redist, so they are not listed a second time here.
        (lib.getLib cuda_cudart)
        (lib.getStatic cuda_cudart)
        cuda_nvcc
      ]
      ++ cuda-common-redist;
  };

  # CUDA components added to buildInputs when cudaSupport is enabled.
  cuda-redist = symlinkJoin {
    name = "cuda-redist-${cudaMajorMinorVersion}";
    paths = cuda-common-redist;
  };
in
buildPythonPackage {
  inherit pname version;
  pyproject = true;

  src = fetchFromGitHub {
    owner = "bitsandbytes-foundation";
    repo = "bitsandbytes";
    tag = version;
    hash = "sha256-iUAeiNbPa3Q5jJ4lK2G0WvTKuipb0zO1mNe+wcRdnqs=";
  };

  # By default, which library is loaded depends on the result of `torch.cuda.is_available()`.
  # When `cudaSupport` is enabled, bypass this check and load the CUDA library unconditionally.
  # Indeed, in this case only `libbitsandbytes_cuda${cudaMajorMinorVersionString}.so` is built
  # (e.g. `libbitsandbytes_cuda124.so` for CUDA 12.4); `libbitsandbytes_cpu.so` is not.
  # Also, hardcode the path to that library instead of relying on
  # `get_cuda_bnb_library_path(cuda_specs)`, which relies on `torch.cuda` too.
  #
  # WARNING: The CUDA library is currently named `libbitsandbytes_cudaxxy` for CUDA version `xx.y`.
  # This upstream convention could change at some point and thus break the following patch.
  postPatch = lib.optionalString cudaSupport ''
    substituteInPlace bitsandbytes/cextension.py \
      --replace-fail "if cuda_specs:" "if True:" \
      --replace-fail \
        "cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)" \
        "cuda_binary_path = PACKAGE_DIR / 'libbitsandbytes_cuda${cudaMajorMinorVersionString}.so'"
  '';

  nativeBuildInputs = [
    cmake
  ]
  ++ lib.optionals cudaSupport [
    cudaPackages.cuda_nvcc
  ];

  build-system = [
    scikit-build-core
    setuptools
  ];

  buildInputs = lib.optionals cudaSupport [ cuda-redist ];

  # Select the native backend built by upstream's CMake project.
  cmakeFlags = [
    (lib.cmakeFeature "COMPUTE_BACKEND" (if cudaSupport then "cuda" else "cpu"))
  ];
  # Point the build at the joined CUDA tree; set to "" when cudaSupport is off.
  CUDA_HOME = lib.optionalString cudaSupport "${cuda-native-redist}";
  # Extra header/library search paths passed to every nvcc invocation.
  NVCC_PREPEND_FLAGS = lib.optionals cudaSupport [
    "-I${cuda-native-redist}/include"
    "-L${cuda-native-redist}/lib"
  ];

  # At this point the working directory is the CMake build directory
  # (/build/source/build): compile the native library with make, then return
  # to the source root for the Python build.
  preBuild = ''
    make -j $NIX_BUILD_CORES
    cd .. # leave /build/source/build
  '';

  dependencies = [
    scipy
    torch
  ];

  doCheck = false; # tests require CUDA and also GPU access

  pythonImportsCheck = [ "bitsandbytes" ];

  meta = {
    description = "8-bit CUDA functions for PyTorch";
    homepage = "https://github.com/bitsandbytes-foundation/bitsandbytes";
    changelog = "https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}