# Prebuilt (binary wheel) jaxlib package.
#
# For the moment we only support the CPU and GPU backends of jaxlib. The TPU
# backend will require some additional work. Those wheels are located here:
# https://storage.googleapis.com/jax-releases/libtpu_releases.html.

# See `python3Packages.jax.passthru` for CUDA tests.

{
  absl-py,
  autoPatchelfHook,
  buildPythonPackage,
  fetchPypi,
  flatbuffers,
  lib,
  ml-dtypes,
  python,
  scipy,
  stdenv,
}:

let
  version = "0.7.2";
  inherit (python) pythonVersion;

  # As of 2023-06-06, google/jax upstream is no longer publishing CPU-only
  # wheels to their GCS bucket. Instead the official instructions recommend
  # installing CPU-only versions via PyPI.
  #
  # `srcs` maps "<pythonVersion>-<system>" keys to the matching PyPI wheel.
  # Supporting a new interpreter or platform means adding an entry here.
  srcs =
    let
      # Fetch one jaxlib wheel from PyPI.
      #   platform: wheel platform tag, e.g. "manylinux_2_27_x86_64"
      #   dist:     interpreter tag, e.g. "cp311"
      #   hash:     SRI hash of the wheel file
      getSrcFromPypi =
        {
          platform,
          dist,
          hash,
        }:
        fetchPypi {
          inherit
            version
            platform
            dist
            hash
            ;
          pname = "jaxlib";
          format = "wheel";
          # The same tag (e.g. "cp311") is passed as both the python tag and
          # the abi tag of the wheel.
          python = dist;
          abi = dist;
        };
    in
    {
      "3.11-x86_64-linux" = getSrcFromPypi {
        platform = "manylinux_2_27_x86_64";
        dist = "cp311";
        hash = "sha256-Q4IAYjXM7VnS95WsyYPBvtz7yk/qj5RhMR1hxqeTrmY=";
      };
      "3.11-aarch64-linux" = getSrcFromPypi {
        platform = "manylinux_2_27_aarch64";
        dist = "cp311";
        hash = "sha256-jKcAM1H76Mz6L6Wkk+wt+/LfkkQTBs9cO5cFCO7bkqs=";
      };
      "3.11-aarch64-darwin" = getSrcFromPypi {
        platform = "macosx_11_0_arm64";
        dist = "cp311";
        hash = "sha256-n7+Qr84w4066LqkppQb1kHvdQGI1gSLeSZzp5nGguh8=";
      };

      "3.12-x86_64-linux" = getSrcFromPypi {
        platform = "manylinux_2_27_x86_64";
        dist = "cp312";
        hash = "sha256-EfMjGeZizP9mhZ6zk3VwUNiXG9iAvE3XDexkNNiQ+1k=";
      };
      "3.12-aarch64-linux" = getSrcFromPypi {
        platform = "manylinux_2_27_aarch64";
        dist = "cp312";
        hash = "sha256-m1oNNXSXYRoRPSB/ssGZfwGrehdYcHAIEiIPC8qjGCI=";
      };
      "3.12-aarch64-darwin" = getSrcFromPypi {
        platform = "macosx_11_0_arm64";
        dist = "cp312";
        hash = "sha256-vW0cU71HXg52ilSvmLFkL7SdcwTPBVzuux0B6J04ocs=";
      };

      "3.13-x86_64-linux" = getSrcFromPypi {
        platform = "manylinux_2_27_x86_64";
        dist = "cp313";
        hash = "sha256-SdmWIEhu/9qHQAAkcjpFIwZWaZbj3nGe5jPwUiDR7nc=";
      };
      "3.13-aarch64-linux" = getSrcFromPypi {
        platform = "manylinux_2_27_aarch64";
        dist = "cp313";
        hash = "sha256-l8eT6Xvl3cc7PoXmzorTcJ6AVPdeohnMDLTwgFplrwY=";
      };
      "3.13-aarch64-darwin" = getSrcFromPypi {
        platform = "macosx_11_0_arm64";
        dist = "cp313";
        hash = "sha256-4bPf6ZFYJfzgBuoJW4U/V2gYRcW/qAl13MN4iTY3H7A=";
      };
    };
in
buildPythonPackage {
  pname = "jaxlib";
  inherit version;
  format = "wheel";

  # Select the wheel for this interpreter/platform pair; any combination
  # without an entry in `srcs` throws when `src` is evaluated.
  # See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6.
  src = (
    srcs."${pythonVersion}-${stdenv.hostPlatform.system}"
      # The lookup key includes the Python version, so report it as well:
      # a missing entry may be due to the interpreter, not the platform.
      or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system} with Python ${pythonVersion}")
  );

  # Prebuilt wheels are dynamically linked against things that nix can't find.
  # Run `autoPatchelfHook` to automagically fix them.
  nativeBuildInputs = lib.optionals stdenv.hostPlatform.isLinux [ autoPatchelfHook ];
  # Dynamic link dependencies
  buildInputs = [ (lib.getLib stdenv.cc.cc) ];

  dependencies = [
    absl-py
    flatbuffers
    ml-dtypes
    scipy
  ];

  pythonImportsCheck = [ "jaxlib" ];

  meta = {
    description = "Prebuilt jaxlib backend from PyPi";
    homepage = "https://github.com/google/jax";
    sourceProvenance = with lib.sourceTypes; [ binaryNativeCode ];
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ samuela ];
    badPlatforms = [
      # Fails at pythonImportsCheckPhase:
      # ...-python-imports-check-hook.sh/nix-support/setup-hook: line 10: 28017 Illegal instruction: 4
      # /nix/store/5qpssbvkzfh73xih07xgmpkj5r565975-python3-3.11.9/bin/python3.11 -c
      # 'import os; import importlib; list(map(lambda mod: importlib.import_module(mod), os.environ["pythonImportsCheck"].split()))'
      "x86_64-darwin"
    ];
  };
}