# nixpkgs derivation for jax-cuda12-plugin: installs the prebuilt PyPI wheel,
# links CUDA tools it needs at runtime next to it, and adds the CUDA library
# path from jax-cuda12-pjrt to the rpath of its shared objects.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchPypi,
  autoPatchelfHook,
  pypaInstallHook,
  wheelUnpackHook,
  cudaPackages,
  python,
  jaxlib,
  jax-cuda12-pjrt,
}:
let
  # The plugin version is taken from jaxlib so both are always built at the same version.
  inherit (jaxlib) version;
  # Library path exposed by jax-cuda12-pjrt; added to this plugin's rpath in preInstallCheck.
  inherit (jax-cuda12-pjrt) cudaLibPath;

  # Location of the installed Python package, relative to $out.
  pluginDir = "${python.sitePackages}/jax_cuda12_plugin";

  # Fetch a prebuilt jax_cuda12_plugin wheel from PyPI.
  #   platform: manylinux platform tag of the wheel (e.g. "manylinux_2_27_x86_64")
  #   dist:     CPython tag (e.g. "cp312"); used as both the python and the ABI tag
  #   hash:     SRI hash of the wheel file
  getSrcFromPypi =
    {
      platform,
      dist,
      hash,
    }:
    fetchPypi {
      inherit
        version
        platform
        dist
        hash
        ;
      pname = "jax_cuda12_plugin";
      format = "wheel";
      python = dist;
      abi = dist;
    };

  # Prebuilt wheels keyed by "<python version>-<nix system>".
  # Only x86_64-linux and aarch64-linux wheels are listed; any other
  # combination makes `src` below throw at evaluation time.
  srcs = {
    "3.11-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux_2_27_x86_64";
      dist = "cp311";
      hash = "sha256-rckk68ekXI00AOoBGNxwpwgrKobjVxFzjUA904FdCb8=";
    };
    "3.11-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux_2_27_aarch64";
      dist = "cp311";
      hash = "sha256-KnJ6ia5prCHB9Qk9jVrviaDmkuZrA0/JNMiszHLkApA=";
    };
    "3.12-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux_2_27_x86_64";
      dist = "cp312";
      hash = "sha256-goTnz39USQZgTxEXAqbwARqW338BE4eLOBvsCQUXJTY=";
    };
    "3.12-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux_2_27_aarch64";
      dist = "cp312";
      hash = "sha256-mKl1ZVOChY2HTWRxzpcZQxBgnQoqfEKDxuB+N5M7d2g=";
    };
    "3.13-x86_64-linux" = getSrcFromPypi {
      platform = "manylinux_2_27_x86_64";
      dist = "cp313";
      hash = "sha256-chLBLXW33FEnXycYJ99KbTeEMMBvZQ5sMcFi/pV5/xI=";
    };
    "3.13-aarch64-linux" = getSrcFromPypi {
      platform = "manylinux_2_27_aarch64";
      dist = "cp313";
      hash = "sha256-Xj4qpNch+wLdECgmKq6uwpWORbylxNNRKykVG1cMtCU=";
    };
  };
in
buildPythonPackage {
  pname = "jax-cuda12-plugin";
  inherit version;
  # The source is an already-built wheel, so there is no pyproject build step.
  pyproject = false;

  src =
    srcs."${python.pythonVersion}-${stdenv.hostPlatform.system}"
      or (throw "python${python.pythonVersion}Packages.jax-cuda12-plugin is not supported on ${stdenv.hostPlatform.system}");

  nativeBuildInputs = [
    autoPatchelfHook
    pypaInstallHook
    wheelUnpackHook
  ];

  # jax-cuda12-plugin looks for ptxas at runtime, e.g. with a triton kernel.
  # Linking into $out is the least bad solution. See
  # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
  # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211
  # * https://github.com/NixOS/nixpkgs/pull/375186
  # for more info.
  postInstall = ''
    mkdir -p $out/${pluginDir}/cuda/bin
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${pluginDir}/cuda/bin
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${pluginDir}/cuda/bin
  '';

  # jax-cuda12-plugin contains shared libraries that open other shared libraries
  # via dlopen. These implicit dependencies are not recognized by ldd or
  # autoPatchelfHook, so they have to be added to the rpath by hand.
  # This must happen after both autoPatchelfHook (a postFixup hook) and the
  # automatic stripping of artifacts (done during fixupPhase) so neither undoes
  # it. installCheckPhase runs after fixupPhase, hence preInstallCheck.
  preInstallCheck = ''
    patchelf --add-rpath "${cudaLibPath}" $out/${pluginDir}/*.so
  '';

  dependencies = [ jax-cuda12-pjrt ];

  pythonImportsCheck = [ "jax_cuda12_plugin" ];

  # FIXME: there are no tests, but checks must stay enabled so that the
  # preInstallCheck rpath patch above runs. Disabling doCheck would silently
  # ship .so files without the CUDA library path in their rpath.
  doCheck = true;

  meta = {
    description = "JAX Plugin for CUDA12";
    homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda";
    sourceProvenance = [ lib.sourceTypes.binaryNativeCode ];
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ natsukium ];
    platforms = lib.platforms.linux;
    # see CUDA compatibility matrix
    # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder
    broken = !(lib.versionAtLeast cudaPackages.cudnn.version "9.1");
  };
}