{
  lib,
  stdenv,
  buildPythonPackage,
  fetchPypi,
  autoPatchelfHook,
  pypaInstallHook,
  wheelUnpackHook,
  cudaPackages,
  python,
  jaxlib,
  jax-cuda12-pjrt,
}:
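
# These arguments are filled in by callPackage, presumably from the enclosing
# python package set (e.g. python3Packages.jax-cuda12-plugin); that wiring is
# an assumption about the surrounding tree, not something this file encodes.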
let
  inherit (jaxlib) version;
  inherit (jax-cuda12-pjrt) cudaLibPath;

  getSrcFromPypi =
    {
      platform,
      dist,
      hash,
    }:
    fetchPypi {
      inherit
        version
        platform
        dist
        hash
        ;
      pname = "jax_cuda12_plugin";
      format = "wheel";
      python = dist;
      abi = dist;
    };
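
  # With format = "wheel", fetchPypi fetches the wheel named
  # ${pname}-${version}-${python}-${abi}-${platform}.whl, e.g. (for the
  # cp312/x86_64-linux entry below, assuming version 0.4.38):
  #   jax_cuda12_plugin-0.4.38-cp312-cp312-manylinux_2_27_x86_64.whl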

  # upstream distributes jax-cuda12-plugin 0.4.38 binaries for x86_64-linux
  # and aarch64-linux only
  srcs = {
    "3.11-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux_2_27_x86_64";
      dist = "cp311";
      hash = "sha256-rckk68ekXI00AOoBGNxwpwgrKobjVxFzjUA904FdCb8=";
    };
    "3.11-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux_2_27_aarch64";
      dist = "cp311";
      hash = "sha256-KnJ6ia5prCHB9Qk9jVrviaDmkuZrA0/JNMiszHLkApA=";
    };
    "3.12-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux_2_27_x86_64";
      dist = "cp312";
      hash = "sha256-goTnz39USQZgTxEXAqbwARqW338BE4eLOBvsCQUXJTY=";
    };
    "3.12-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux_2_27_aarch64";
      dist = "cp312";
      hash = "sha256-mKl1ZVOChY2HTWRxzpcZQxBgnQoqfEKDxuB+N5M7d2g=";
    };
    "3.13-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux_2_27_x86_64";
      dist = "cp313";
      hash = "sha256-chLBLXW33FEnXycYJ99KbTeEMMBvZQ5sMcFi/pV5/xI=";
    };
    "3.13-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux_2_27_aarch64";
      dist = "cp313";
      hash = "sha256-Xj4qpNch+wLdECgmKq6uwpWORbylxNNRKykVG1cMtCU=";
    };
  };
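
  # When updating the srcs above, one way to refresh a hash (a sketch of one
  # possible workflow, not the only one) is to set it to lib.fakeHash,
  # rebuild, and copy the correct value from the resulting hash-mismatch error.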
in
buildPythonPackage {
  pname = "jax-cuda12-plugin";
  inherit version;
  pyproject = false;

  src = (
    srcs."${python.pythonVersion}-${stdenv.hostPlatform.system}"
      or (throw "python${python.pythonVersion}Packages.jax-cuda12-plugin is not supported on ${stdenv.hostPlatform.system}")
  );

  nativeBuildInputs = [
    autoPatchelfHook
    pypaInstallHook
    wheelUnpackHook
  ];

  # jax-cuda12-plugin looks for ptxas (and nvlink) at runtime, e.g. when
  # compiling a Triton kernel. Linking them into $out is the least bad
  # solution. See
  # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
  # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211
  # * https://github.com/NixOS/nixpkgs/pull/375186
  # for more info.
  postInstall = ''
    mkdir -p $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
  '';
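
  # Sanity check (hypothetical, run outside the build): the plugin's package
  # directory should now contain the symlinked binaries, e.g.
  #   python -c 'import jax_cuda12_plugin, os; d = os.path.dirname(jax_cuda12_plugin.__file__); print(os.listdir(os.path.join(d, "cuda", "bin")))'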

  # jax-cuda12-plugin contains shared libraries that open other shared
  # libraries via dlopen, and these implicit dependencies are not recognized
  # by ldd or autoPatchelfHook. That means we need to sneak them into the
  # rpath ourselves. This step must run after autoPatchelfHook and the
  # automatic stripping of artifacts: autoPatchelfHook runs in postFixup,
  # auto-stripping runs in the fixupPhase, and preInstallCheck runs after
  # both, in the installCheckPhase.
  preInstallCheck = ''
    patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_cuda12_plugin/*.so
  '';
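
  # To inspect the resulting rpath (hypothetical, run outside the build):
  #   patchelf --print-rpath $out/${python.sitePackages}/jax_cuda12_plugin/*.so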

  dependencies = [ jax-cuda12-pjrt ];

  pythonImportsCheck = [ "jax_cuda12_plugin" ];

  # FIXME: there are no tests, but doCheck must stay enabled: buildPythonPackage
  # runs its checks in the installCheckPhase, which is also what runs
  # preInstallCheck above.
  doCheck = true;

  meta = {
    description = "JAX plugin for CUDA 12";
    homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda";
    sourceProvenance = [ lib.sourceTypes.binaryNativeCode ];
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ natsukium ];
    platforms = lib.platforms.linux;
    # see the CUDA compatibility matrix:
    # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder
    broken = !(lib.versionAtLeast cudaPackages.cudnn.version "9.1");
  };
}