{
  lib,
  pkgs,
  stdenv,

  # Build-time dependencies:
  addDriverRunpath,
  autoAddDriverRunpath,
  bazel_7,
  binutils,
  buildBazelPackage,
  buildPythonPackage,
  cctools,
  curl,
  cython,
  fetchFromGitHub,
  git,
  jsoncpp,
  nsync,
  openssl,
  pybind11,
  setuptools,
  symlinkJoin,
  wheel,
  build,
  which,

  # Python dependencies:
  absl-py,
  flatbuffers,
  ml-dtypes,
  numpy,
  scipy,
  six,

  # Runtime dependencies:
  double-conversion,
  giflib,
  libjpeg_turbo,
  python,
  snappy-cpp,
  zlib,

  config,
  # CUDA flags:
  cudaSupport ? config.cudaSupport,
  cudaPackages,

  # MKL:
  mklSupport ? true,
}@inputs:

let
  inherit (cudaPackages)
    cudaMajorMinorVersion
    flags
    nccl
    ;

  pname = "jaxlib";
  version = "0.4.28";

  # It's necessary to consistently use backendStdenv when building with CUDA
  # support, otherwise we get libstdc++ errors downstream.
  stdenv = throw "Use effectiveStdenv instead";
  effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv;
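  # A note on the shadowing trick above: because `stdenv` is rebound to a
  # `throw`, any accidental reference to it below fails at evaluation time
  # (e.g. `stdenv.mkDerivation { }` aborts with "Use effectiveStdenv instead")
  # rather than silently building with the wrong toolchain.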

  meta = with lib; {
    description = "Source-built JAX backend. JAX is Autograd and XLA, brought together for high-performance machine learning research";
    homepage = "https://github.com/google/jax";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];

    # Make this platforms.unix once Darwin is supported.
    # The top-level jaxlib now falls back to jaxlib-bin on unsupported platforms.
    # aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136
    # however even with that fix applied, it doesn't work for everyone:
    # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
    platforms = platforms.linux;

    # Needs update for Bazel 7.
    broken = true;
  };

  # Bazel wants a merged cudnn at configuration time
  cudnnMerged = symlinkJoin {
    name = "cudnn-merged";
    paths = with cudaPackages; [
      (lib.getDev cudnn)
      (lib.getLib cudnn)
    ];
  };
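  # The merged tree is a plain symlinkJoin: headers from the dev output and
  # shared objects from the lib output land under one prefix, roughly
  #   cudnn-merged/include/cudnn.h -> ${cudnn.dev}/include/cudnn.h
  #   cudnn-merged/lib/libcudnn.so -> ${cudnn.lib}/lib/libcudnn.so
  # (illustrative paths; the exact layout depends on the cudnn outputs).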

  # These are necessary at build time and run time.
  cuda_libs_joined = symlinkJoin {
    name = "cuda-joined";
    paths = with cudaPackages; [
      (lib.getLib cuda_cudart) # libcudart.so
      (lib.getLib cuda_cupti) # libcupti.so
      (lib.getLib libcublas) # libcublas.so
      (lib.getLib libcufft) # libcufft.so
      (lib.getLib libcurand) # libcurand.so
      (lib.getLib libcusolver) # libcusolver.so
      (lib.getLib libcusparse) # libcusparse.so
    ];
  };
  # These are only necessary at build time.
  cuda_build_deps_joined = symlinkJoin {
    name = "cuda-build-deps-joined";
    paths = with cudaPackages; [
      cuda_libs_joined

      # Binaries
      (lib.getBin cuda_nvcc) # nvcc

      # Archives
      (lib.getOutput "static" cuda_cudart) # libcudart_static.a

      # Headers
      (lib.getDev cuda_cccl) # block_load.cuh
      (lib.getDev cuda_cudart) # cuda.h
      (lib.getDev cuda_cupti) # cupti.h
      (lib.getDev cuda_nvcc) # See https://github.com/google/jax/issues/19811
      (lib.getDev cuda_nvml_dev) # nvml.h
      (lib.getDev cuda_nvtx) # nvToolsExt.h
      (lib.getDev libcublas) # cublas_api.h
      (lib.getDev libcufft) # cufft.h
      (lib.getDev libcurand) # curand.h
      (lib.getDev libcusolver) # cusolver_common.h
      (lib.getDev libcusparse) # cusparse.h
    ];
  };

  backend_cc_joined = symlinkJoin {
    name = "cuda-cc-joined";
    paths = [
      effectiveStdenv.cc
      binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
    ];
  };
  # Copied from the TensorFlow derivation. Most of these are not actually used
  # when compiling jaxlib, but keeping the list as-is makes it easier to stay
  # in sync with the TF derivation.
  tf_system_libs = [
    "absl_py"
    "astor_archive"
    "astunparse_archive"
    # Not packaged in nixpkgs
    # "com_github_googleapis_googleapis"
    # "com_github_googlecloudplatform_google_cloud_cpp"
    # Issue with transitive dependencies after https://github.com/grpc/grpc/commit/f1d14f7f0b661bd200b7f269ef55dec870e7c108
    # "com_github_grpc_grpc"
    # ERROR: /build/output/external/bazel_tools/tools/proto/BUILD:25:6: no such target '@com_google_protobuf//:cc_toolchain':
    # target 'cc_toolchain' not declared in package '' defined by /build/output/external/com_google_protobuf/BUILD.bazel
    # "com_google_protobuf"
    # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
    # "com_googlesource_code_re2"
    "curl"
    "cython"
    "dill_archive"
    "double_conversion"
    "flatbuffers"
    "functools32_archive"
    "gast_archive"
    "gif"
    "hwloc"
    "icu"
    "jsoncpp_git"
    "libjpeg_turbo"
    "lmdb"
    "nasm"
    "opt_einsum_archive"
    "org_sqlite"
    "pasta"
    "png"
    # ERROR: /build/output/external/pybind11/BUILD.bazel: no such target '@pybind11//:osx':
    # target 'osx' not declared in package '' defined by /build/output/external/pybind11/BUILD.bazel
    # "pybind11"
    "six_archive"
    "snappy"
    "tblib_archive"
    "termcolor_archive"
    "typing_extensions_archive"
    "wrapt"
    "zlib"
  ];
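  # These names get joined into the TF_SYSTEM_LIBS environment variable (see
  # fetchAttrs/buildAttrs below), i.e. roughly
  #   TF_SYSTEM_LIBS="absl_py,astor_archive,astunparse_archive,curl,...,zlib"
  # which tells the TF/XLA workspace rules to use system-provided copies of
  # these dependencies instead of fetching them.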

  arch =
    # KeyError: ('Linux', 'arm64')
    if effectiveStdenv.hostPlatform.isLinux && effectiveStdenv.hostPlatform.linuxArch == "arm64" then
      "aarch64"
    else
      effectiveStdenv.hostPlatform.linuxArch;
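  # For example: on x86_64-linux, `linuxArch` is "x86_64" and is used as-is,
  # while on aarch64-linux it is "arm64", which jax's wheel-building tooling
  # does not recognize (the KeyError above), hence the rename to "aarch64".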

  xla = effectiveStdenv.mkDerivation {
    pname = "xla-src";
    version = "unstable";

    src = fetchFromGitHub {
      owner = "openxla";
      repo = "xla";
      # Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl.
      rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4";
      hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E=";
    };

    dontBuild = true;

    # This is necessary for patchShebangs to know the right path to use.
    nativeBuildInputs = [ python ];

    # Main culprits we're targeting are third_party/tsl/third_party/gpus/crosstool/clang/bin/*.tpl
    postPatch = ''
      patchShebangs .
    '';

    installPhase = ''
      cp -r . $out
    '';
  };

  bazel-build = buildBazelPackage rec {
    name = "bazel-build-${pname}-${version}";

    # See https://github.com/google/jax/blob/main/.bazelversion for the latest.
    bazel = bazel_7;

    src = fetchFromGitHub {
      owner = "google";
      repo = "jax";
      # google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
      rev = "refs/tags/${pname}-v${version}";
      hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
    };

    nativeBuildInputs = [
      cython
      pkgs.flatbuffers
      git
      setuptools
      wheel
      build
      which
    ]
    ++ lib.optionals effectiveStdenv.hostPlatform.isDarwin [ cctools ];

    buildInputs = [
      curl
      double-conversion
      giflib
      jsoncpp
      libjpeg_turbo
      numpy
      openssl
      pkgs.flatbuffers
      pkgs.protobuf
      pybind11
      scipy
      six
      snappy-cpp
      zlib
    ]
    ++ lib.optionals (!effectiveStdenv.hostPlatform.isDarwin) [ nsync ];

    # We don't want to be quite so picky regarding bazel version
    postPatch = ''
      rm -f .bazelversion
    '';

    bazelRunTarget = "//jaxlib/tools:build_wheel";
    runTargetFlags = [
      "--output_path=$out"
      "--cpu=${arch}"
      # This appears to have no effect on the built wheel.
      "--jaxlib_git_hash='12345678'"
    ];

    removeRulesCC = false;

    GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${backend_cc_joined}/bin";
    GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${backend_cc_joined}/bin/gcc";

    # The version is automatically set to ".dev" if this variable is not set.
    # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
    JAXLIB_RELEASE = "1";

    preConfigure =
      # Dummy ldconfig to work around "Can't open cache file /nix/store/<hash>-glibc-2.38-44/etc/ld.so.cache" error
      ''
        mkdir dummy-ldconfig
        echo "#!${effectiveStdenv.shell}" > dummy-ldconfig/ldconfig
        chmod +x dummy-ldconfig/ldconfig
        export PATH="$PWD/dummy-ldconfig:$PATH"
      ''
      +

      # Construct .jax_configure.bazelrc. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L259-L345
      # for more info. We assume
      # * `cpu = None`
      # * `enable_nccl = True`
      # * `target_cpu_features = "release"`
      # * `rocm_amdgpu_targets = None`
      # * `enable_rocm = False`
      # * `build_gpu_plugin = False`
      # * `use_clang = False` (Should we use `effectiveStdenv.cc.isClang` instead?)
      #
      # Note: We should try just running https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L259-L266
      # instead of duplicating the logic here. Perhaps we can leverage the
      # `--configure_only` flag (https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L544-L548)?
      ''
        cat <<CFG > ./.jax_configure.bazelrc
        build --strategy=Genrule=standalone
        build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
        build --action_env=PYENV_ROOT
        build --python_path="${python}/bin/python"
        build --distinct_host_configuration=false
        build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
      ''
      + lib.optionalString cudaSupport ''
        build --config=cuda
        build --action_env CUDA_TOOLKIT_PATH="${cuda_build_deps_joined}"
        build --action_env CUDNN_INSTALL_PATH="${cudnnMerged}"
        build --action_env TF_CUDA_PATHS="${cuda_build_deps_joined},${cudnnMerged},${lib.getDev nccl}"
        build --action_env TF_CUDA_VERSION="${cudaMajorMinorVersion}"
        build --action_env TF_CUDNN_VERSION="${lib.versions.major cudaPackages.cudnn.version}"
        build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," flags.realArches}"
      ''
      +
      # Note that upstream conditions this on `wheel_cpu == "x86_64"`. We just
      # rely on `effectiveStdenv.hostPlatform.avxSupport` instead. So far so
      # good. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L322
      # for upstream's version.
      lib.optionalString (effectiveStdenv.hostPlatform.avxSupport && effectiveStdenv.hostPlatform.isUnix)
        ''
          build --config=avx_posix
        ''
      + lib.optionalString mklSupport ''
        build --config=mkl_open_source_only
      ''
      + ''
        CFG
      '';
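    # For reference, on an x86_64-linux, non-CUDA, MKL-enabled build the
    # generated .jax_configure.bazelrc should come out roughly as:
    #   build --strategy=Genrule=standalone
    #   build --repo_env PYTHON_BIN_PATH="/nix/store/...-python3-.../bin/python"
    #   build --action_env=PYENV_ROOT
    #   build --python_path="/nix/store/...-python3-.../bin/python"
    #   build --distinct_host_configuration=false
    #   build --define PROTOBUF_INCLUDE_PATH="/nix/store/...-protobuf-.../include"
    #   build --config=avx_posix
    #   build --config=mkl_open_source_only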

    # Make sure Bazel knows about our configuration flags during fetching so
    # that the relevant dependencies can be downloaded.
    bazelFlags = [
      "-c opt"
      # See https://bazel.build/external/advanced#overriding-repositories for
      # information on the --override_repository flag.
      "--override_repository=xla=${xla}"
    ]
    ++ lib.optionals effectiveStdenv.cc.isClang [
      # Bazel relies on the compiler frontend picking the language from the
      # file extension, but our clang doesn't do that, so force C++ explicitly.
      # https://github.com/NixOS/nixpkgs/issues/150655
      "--cxxopt=-x"
      "--cxxopt=c++"
      "--host_cxxopt=-x"
      "--host_cxxopt=c++"
    ];
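    # Net effect of the clang flags: every compile action behaves as if it were
    # invoked as `clang -x c++ <file>`, selecting the C++ frontend regardless
    # of the file extension.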

    # We intentionally overfetch so we can share the fetch derivation across all the different configurations
    fetchAttrs = {
      TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
      # Force fetching @mkl_dnn_v1: it isn't needed on darwin, so it wouldn't be pulled in there otherwise.
      bazelTargets = [
        bazelRunTarget
        "@mkl_dnn_v1//:mkl_dnn"
      ];
      bazelFlags =
        bazelFlags
        ++ [
          "--config=avx_posix"
          "--config=mkl_open_source_only"
        ]
        ++ lib.optionals cudaSupport [
          # Ideally we'd add this unconditionally too, but it doesn't work on darwin.
          # We make this conditional on `cudaSupport` instead of the system so that the hashes
          # for both the cuda and the non-cuda deps can be computed on linux, since a lot of
          # contributors don't have access to darwin machines.
          "--config=cuda"
        ];

      sha256 =
        (
          if cudaSupport then
            { x86_64-linux = "sha256-Uf0VMRE0jgaWEYiuphWkWloZ5jMeqaWBl3lSvk2y1HI="; }
          else
            {
              x86_64-linux = "sha256-NzJJg6NlrPGMiR8Fn8u4+fu0m+AulfmN5Xqk63Um6sw=";
              aarch64-linux = "sha256-Ro3qzrUxSR+3TH6ROoJTq+dLSufrDN/9oEo2MRkx7wM=";
            }
        ).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}");

      # Non-reproducible fetch https://github.com/NixOS/nixpkgs/issues/321920#issuecomment-2184940546
      preInstall = ''
        cat << \EOF > "$bazelOut/external/go_sdk/versions.json"
        []
        EOF
      '';
    };
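    # To update the fixed-output hashes above, one common workflow is to set
    # the relevant entry to `lib.fakeSha256`, rebuild, and copy the expected
    # hash from the resulting mismatch error; the CUDA and non-CUDA variants
    # have to be refetched (and hashed) separately.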

    buildAttrs = {
      outputs = [ "out" ];

      TF_SYSTEM_LIBS = lib.concatStringsSep "," (
        tf_system_libs
        ++ lib.optionals (!effectiveStdenv.hostPlatform.isDarwin) [
          "nsync" # fails to build on darwin
        ]
      );

      # Note: we cannot do most of this patching in the `patch` phase, as the
      # deps are not available yet.
      preBuild = lib.optionalString effectiveStdenv.hostPlatform.isDarwin ''
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
          --replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
          --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
      '';
    };

    inherit meta;
  };

  platformTag =
    if effectiveStdenv.hostPlatform.isLinux then
      "manylinux2014_${arch}"
    else if effectiveStdenv.system == "x86_64-darwin" then
      "macosx_10_9_${arch}"
    else if effectiveStdenv.system == "aarch64-darwin" then
      "macosx_11_0_${arch}"
    else
      throw "Unsupported target platform: ${effectiveStdenv.hostPlatform}";
in
buildPythonPackage {
  inherit pname version;
  format = "wheel";

  src =
    let
      cp = "cp${builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion}";
    in
    "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl";
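  # For example, with python 3.11 on x86_64-linux this resolves to something
  # like ${bazel-build}/jaxlib-0.4.28-cp311-cp311-manylinux2014_x86_64.whl
  # (`cp` is the CPython ABI tag; platformTag is defined above).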

  # Note that jaxlib looks for "ptxas" in $PATH. See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
  # for more info.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/bin
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/bin/ptxas

    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
      patchelf --add-rpath "${
        lib.makeLibraryPath [
          cuda_libs_joined
          (lib.getLib cudaPackages.cudnn)
          nccl
        ]
      }" "$lib"
    done
  '';
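  # The loop above should leave each shared object with the joined CUDA, cudnn
  # and nccl library directories on its RUNPATH; something like
  #   patchelf --print-rpath $out/lib/python*/site-packages/jaxlib/xla_extension.so
  # can verify this (illustrative path; the exact .so names depend on the wheel).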

  nativeBuildInputs = lib.optionals cudaSupport [ autoAddDriverRunpath ];

  dependencies = [
    absl-py
    curl
    double-conversion
    flatbuffers
    giflib
    jsoncpp
    libjpeg_turbo
    ml-dtypes
    numpy
    scipy
    six
  ];

  buildInputs = [
    snappy-cpp
  ];

  pythonImportsCheck = [
    "jaxlib"
    # `import jaxlib` loads surprisingly little. These deeper imports caught actual bugs during the 0.4.11 upgrade.
    "jaxlib.cpu_feature_guard"
    "jaxlib.xla_client"
  ];

  # Without this, there are complaints about libcudart.so.11.0 not being found,
  # because the RPATH entries added above would otherwise be stripped.
  dontPatchELF = cudaSupport;

  passthru = {
    # Note: "bazel.*.tar.gz" can be accessed as `jaxlib.bazel-build.deps`.
    inherit bazel-build;
  };

  inherit meta;
}