{
  lib,
  pkgs,
  stdenv,

  # Build-time dependencies:
  addDriverRunpath,
  autoAddDriverRunpath,
  bazel_7,
  binutils,
  buildBazelPackage,
  buildPythonPackage,
  cctools,
  curl,
  cython,
  fetchFromGitHub,
  git,
  jsoncpp,
  nsync,
  openssl,
  pybind11,
  setuptools,
  symlinkJoin,
  wheel,
  build,
  which,

  # Python dependencies:
  absl-py,
  flatbuffers,
  ml-dtypes,
  numpy,
  scipy,
  six,

  # Runtime dependencies:
  double-conversion,
  giflib,
  libjpeg_turbo,
  python,
  snappy-cpp,
  zlib,

  config,
  # CUDA flags:
  cudaSupport ? config.cudaSupport,
  cudaPackages,

  # MKL:
  mklSupport ? true,
}@inputs:

let
  inherit (cudaPackages)
    cudaMajorMinorVersion
    flags
    nccl
    ;

  pname = "jaxlib";
  version = "0.4.28";

  # It's necessary to consistently use backendStdenv when building with CUDA
  # support, otherwise we get libstdc++ errors downstream.
  stdenv = throw "Use effectiveStdenv instead";
  effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv;

  meta = with lib; {
    description = "Source-built JAX backend. JAX is Autograd and XLA, brought together for high-performance machine learning research";
    homepage = "https://github.com/google/jax";
    license = licenses.asl20;
    maintainers = with maintainers; [ ndl ];

    # Make this platforms.unix once Darwin is supported.
    # The top-level jaxlib now falls back to jaxlib-bin on unsupported platforms.
    # aarch64-darwin is broken because of https://github.com/bazelbuild/rules_cc/pull/136;
    # however, even with that fix applied, it doesn't work for everyone:
    # https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
    platforms = platforms.linux;

    # Needs update for Bazel 7.
    broken = true;
  };

  # Bazel wants a merged cudnn at configuration time.
  cudnnMerged = symlinkJoin {
    name = "cudnn-merged";
    paths = with cudaPackages; [
      (lib.getDev cudnn)
      (lib.getLib cudnn)
    ];
  };

  # These are necessary at build time and run time.
  cuda_libs_joined = symlinkJoin {
    name = "cuda-joined";
    paths = with cudaPackages; [
      (lib.getLib cuda_cudart) # libcudart.so
      (lib.getLib cuda_cupti) # libcupti.so
      (lib.getLib libcublas) # libcublas.so
      (lib.getLib libcufft) # libcufft.so
      (lib.getLib libcurand) # libcurand.so
      (lib.getLib libcusolver) # libcusolver.so
      (lib.getLib libcusparse) # libcusparse.so
    ];
  };
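  # For orientation: symlinkJoin produces a single store path whose bin/, lib/
  # and include/ entries are symlinks into the joined packages, so the CUDA
  # pieces can be handed to Bazel (and later patchelf) as one prefix.
  # Illustrative layout (hypothetical store hashes, not checked at build time):
  #   ${cuda_libs_joined}/lib/libcudart.so -> /nix/store/...-cuda_cudart-lib/lib/libcudart.so
  #   ${cuda_libs_joined}/lib/libcublas.so -> /nix/store/...-libcublas-lib/lib/libcublas.so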
  # These are only necessary at build time.
  cuda_build_deps_joined = symlinkJoin {
    name = "cuda-build-deps-joined";
    paths = with cudaPackages; [
      cuda_libs_joined

      # Binaries
      (lib.getBin cuda_nvcc) # nvcc

      # Archives
      (lib.getOutput "static" cuda_cudart) # libcudart_static.a

      # Headers
      (lib.getDev cuda_cccl) # block_load.cuh
      (lib.getDev cuda_cudart) # cuda.h
      (lib.getDev cuda_cupti) # cupti.h
      (lib.getDev cuda_nvcc) # See https://github.com/google/jax/issues/19811
      (lib.getDev cuda_nvml_dev) # nvml.h
      (lib.getDev cuda_nvtx) # nvToolsExt.h
      (lib.getDev libcublas) # cublas_api.h
      (lib.getDev libcufft) # cufft.h
      (lib.getDev libcurand) # curand.h
      (lib.getDev libcusolver) # cusolver_common.h
      (lib.getDev libcusparse) # cusparse.h
    ];
  };

  backend_cc_joined = symlinkJoin {
    name = "cuda-cc-joined";
    paths = [
      effectiveStdenv.cc
      binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
    ];
  };

  # Copied from the TensorFlow derivation. Most of these are not actually used
  # when compiling jaxlib, but keeping the list as-is makes it easier to stay
  # in sync with the TF derivation.
  tf_system_libs = [
    "absl_py"
    "astor_archive"
    "astunparse_archive"
    # Not packaged in nixpkgs
    # "com_github_googleapis_googleapis"
    # "com_github_googlecloudplatform_google_cloud_cpp"
    # Issue with transitive dependencies after https://github.com/grpc/grpc/commit/f1d14f7f0b661bd200b7f269ef55dec870e7c108
    # "com_github_grpc_grpc"
    # ERROR: /build/output/external/bazel_tools/tools/proto/BUILD:25:6: no such target '@com_google_protobuf//:cc_toolchain':
    # target 'cc_toolchain' not declared in package '' defined by /build/output/external/com_google_protobuf/BUILD.bazel
    # "com_google_protobuf"
    # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
    # "com_googlesource_code_re2"
    "curl"
    "cython"
    "dill_archive"
    "double_conversion"
    "flatbuffers"
    "functools32_archive"
    "gast_archive"
    "gif"
    "hwloc"
    "icu"
    "jsoncpp_git"
    "libjpeg_turbo"
    "lmdb"
    "nasm"
    "opt_einsum_archive"
    "org_sqlite"
    "pasta"
    "png"
    # ERROR: /build/output/external/pybind11/BUILD.bazel: no such target '@pybind11//:osx':
    # target 'osx' not declared in package '' defined by /build/output/external/pybind11/BUILD.bazel
    # "pybind11"
    "six_archive"
    "snappy"
    "tblib_archive"
    "termcolor_archive"
    "typing_extensions_archive"
    "wrapt"
    "zlib"
  ];

  arch =
    # Normalize "arm64" to "aarch64": jaxlib's build scripts otherwise fail
    # with KeyError: ('Linux', 'arm64').
    if effectiveStdenv.hostPlatform.isLinux && effectiveStdenv.hostPlatform.linuxArch == "arm64" then
      "aarch64"
    else
      effectiveStdenv.hostPlatform.linuxArch;
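  # Worked example: on aarch64-linux, hostPlatform.linuxArch is "arm64", so
  # arch becomes "aarch64"; on x86_64-linux it is "x86_64" and passes through
  # unchanged.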
  xla = effectiveStdenv.mkDerivation {
    pname = "xla-src";
    version = "unstable";

    src = fetchFromGitHub {
      owner = "openxla";
      repo = "xla";
      # Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl.
      rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4";
      hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E=";
    };

    dontBuild = true;

    # This is necessary for patchShebangs to know the right path to use.
    nativeBuildInputs = [ python ];

    # Main culprits we're targeting are third_party/tsl/third_party/gpus/crosstool/clang/bin/*.tpl
    postPatch = ''
      patchShebangs .
    '';

    installPhase = ''
      cp -r . $out
    '';
  };
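  # To bump the XLA pin above: take `rev` from third_party/xla/workspace.bzl at
  # the matching jaxlib tag and refresh `hash`, e.g. (illustrative invocation)
  #   nix-prefetch-github openxla xla --rev <new rev>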
  bazel-build = buildBazelPackage rec {
    name = "bazel-build-${pname}-${version}";

    # See https://github.com/google/jax/blob/main/.bazelversion for the latest.
    # bazel = bazel_6;
    bazel = bazel_7;

    src = fetchFromGitHub {
      owner = "google";
      repo = "jax";
      # google/jax contains tags for both jax and jaxlib. Only use jaxlib tags!
      rev = "refs/tags/${pname}-v${version}";
      hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
    };

    nativeBuildInputs = [
      cython
      pkgs.flatbuffers
      git
      setuptools
      wheel
      build
      which
    ]
    ++ lib.optionals effectiveStdenv.hostPlatform.isDarwin [ cctools ];

    buildInputs = [
      curl
      double-conversion
      giflib
      jsoncpp
      libjpeg_turbo
      numpy
      openssl
      pkgs.flatbuffers
      pkgs.protobuf
      pybind11
      scipy
      six
      snappy-cpp
      zlib
    ]
    ++ lib.optionals (!effectiveStdenv.hostPlatform.isDarwin) [ nsync ];

    # We don't want to be quite so picky regarding the Bazel version.
    postPatch = ''
      rm -f .bazelversion
    '';

    bazelRunTarget = "//jaxlib/tools:build_wheel";
    runTargetFlags = [
      "--output_path=$out"
      "--cpu=${arch}"
      # This has no impact whatsoever...
      "--jaxlib_git_hash='12345678'"
    ];

    # Keep the fetched rules_cc (buildBazelPackage removes it by default).
    removeRulesCC = false;

    GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${backend_cc_joined}/bin";
    GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${backend_cc_joined}/bin/gcc";

    # The version is automatically set to ".dev" if this variable is not set.
    # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
    JAXLIB_RELEASE = "1";

    preConfigure =
      # Dummy ldconfig to work around the "Can't open cache file
      # /nix/store/<hash>-glibc-2.38-44/etc/ld.so.cache" error.
      ''
        mkdir dummy-ldconfig
        echo "#!${effectiveStdenv.shell}" > dummy-ldconfig/ldconfig
        chmod +x dummy-ldconfig/ldconfig
        export PATH="$PWD/dummy-ldconfig:$PATH"
      ''
      +
      # Construct .jax_configure.bazelrc. See
      # https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L259-L345
      # for more info. We assume
      # * `cpu = None`
      # * `enable_nccl = True`
      # * `target_cpu_features = "release"`
      # * `rocm_amdgpu_targets = None`
      # * `enable_rocm = False`
      # * `build_gpu_plugin = False`
      # * `use_clang = False` (should we use `effectiveStdenv.cc.isClang` instead?)
      #
      # Note: we should try just running
      # https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L259-L266
      # instead of duplicating the logic here. Perhaps we can leverage the
      # `--configure_only` flag (https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L544-L548)?
      ''
        cat <<CFG > ./.jax_configure.bazelrc
        build --strategy=Genrule=standalone
        build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
        build --action_env=PYENV_ROOT
        build --python_path="${python}/bin/python"
        build --distinct_host_configuration=false
        build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
      ''
      + lib.optionalString cudaSupport ''
        build --config=cuda
        build --action_env CUDA_TOOLKIT_PATH="${cuda_build_deps_joined}"
        build --action_env CUDNN_INSTALL_PATH="${cudnnMerged}"
        build --action_env TF_CUDA_PATHS="${cuda_build_deps_joined},${cudnnMerged},${lib.getDev nccl}"
        build --action_env TF_CUDA_VERSION="${cudaMajorMinorVersion}"
        build --action_env TF_CUDNN_VERSION="${lib.versions.major cudaPackages.cudnn.version}"
        build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," flags.realArches}"
      ''
      +
      # Note that upstream conditions this on `wheel_cpu == "x86_64"`. We just
      # rely on `effectiveStdenv.hostPlatform.avxSupport` instead. So far so
      # good. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L322
      # for upstream's version.
      lib.optionalString (effectiveStdenv.hostPlatform.avxSupport && effectiveStdenv.hostPlatform.isUnix) ''
        build --config=avx_posix
      ''
      + lib.optionalString mklSupport ''
        build --config=mkl_open_source_only
      ''
      + ''
        CFG
      '';

    # Make sure Bazel knows about our configuration flags during fetching so
    # that the relevant dependencies can be downloaded.
    bazelFlags = [
      "-c opt"
      # See https://bazel.build/external/advanced#overriding-repositories for
      # information on the --override_repository flag.
      "--override_repository=xla=${xla}"
    ]
    ++ lib.optionals effectiveStdenv.cc.isClang [
      # Bazel depends on the compiler frontend automatically selecting these
      # flags based on file extension, but our clang doesn't.
      # https://github.com/NixOS/nixpkgs/issues/150655
      "--cxxopt=-x"
      "--cxxopt=c++"
      "--host_cxxopt=-x"
      "--host_cxxopt=c++"
    ];
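    # Background on --override_repository above: it points Bazel's external
    # repository `xla` at our pinned openxla/xla checkout, so Bazel does not
    # download the revision referenced in third_party/xla/workspace.bzl.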
    # We intentionally overfetch so we can share the fetch derivation across
    # all the different configurations.
    fetchAttrs = {
      TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
      # We have to force-fetch @mkl_dnn_v1: darwin doesn't need it, but we want
      # the same fetch derivation on every platform.
      bazelTargets = [
        bazelRunTarget
        "@mkl_dnn_v1//:mkl_dnn"
      ];
      bazelFlags =
        bazelFlags
        ++ [
          "--config=avx_posix"
          "--config=mkl_open_source_only"
        ]
        ++ lib.optionals cudaSupport [
          # Ideally we'd add this unconditionally too, but it doesn't work on
          # darwin. We make this conditional on `cudaSupport` instead of the
          # system so that the hashes for both the cuda and the non-cuda deps
          # can be computed on linux, since a lot of contributors don't have
          # access to darwin machines.
          "--config=cuda"
        ];

      sha256 =
        (
          if cudaSupport then
            { x86_64-linux = "sha256-Uf0VMRE0jgaWEYiuphWkWloZ5jMeqaWBl3lSvk2y1HI="; }
          else
            {
              x86_64-linux = "sha256-NzJJg6NlrPGMiR8Fn8u4+fu0m+AulfmN5Xqk63Um6sw=";
              aarch64-linux = "sha256-Ro3qzrUxSR+3TH6ROoJTq+dLSufrDN/9oEo2MRkx7wM=";
            }
        ).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}");

      # Non-reproducible fetch: https://github.com/NixOS/nixpkgs/issues/321920#issuecomment-2184940546
      preInstall = ''
        cat << \EOF > "$bazelOut/external/go_sdk/versions.json"
        []
        EOF
      '';
    };

    buildAttrs = {
      outputs = [ "out" ];

      TF_SYSTEM_LIBS = lib.concatStringsSep "," (
        tf_system_libs
        ++ lib.optionals (!effectiveStdenv.hostPlatform.isDarwin) [
          "nsync" # fails to build on darwin
        ]
      );

      # Note: we cannot do most of this patching in the `patch` phase, as the
      # dependencies are not available yet.
      preBuild = lib.optionalString effectiveStdenv.hostPlatform.isDarwin ''
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
          --replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
        substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
          --replace "/usr/bin/libtool" "${cctools}/bin/libtool"
      '';
    };

    inherit meta;
  };

  platformTag =
    if effectiveStdenv.hostPlatform.isLinux then
      "manylinux2014_${arch}"
    else if effectiveStdenv.system == "x86_64-darwin" then
      "macosx_10_9_${arch}"
    else if effectiveStdenv.system == "aarch64-darwin" then
      "macosx_11_0_${arch}"
    else
      throw "Unsupported target platform: ${effectiveStdenv.hostPlatform}";
in
buildPythonPackage {
  inherit pname version;
  format = "wheel";

  src =
    let
      cp = "cp${builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion}";
    in
    "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl";
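  # Worked example of the path above: for Python 3.11 on x86_64-linux, cp is
  # "cp311" and platformTag is "manylinux2014_x86_64", so the wheel picked up
  # from the bazel-build output is
  #   jaxlib-0.4.28-cp311-cp311-manylinux2014_x86_64.whl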
  # Note that jaxlib looks for "ptxas" in $PATH. See
  # https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for more info.
  postInstall = lib.optionalString cudaSupport ''
    mkdir -p $out/bin
    ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/bin/ptxas

    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
      patchelf --add-rpath "${
        lib.makeLibraryPath [
          cuda_libs_joined
          (lib.getLib cudaPackages.cudnn)
          nccl
        ]
      }" "$lib"
    done
  '';

  nativeBuildInputs = lib.optionals cudaSupport [ autoAddDriverRunpath ];

  dependencies = [
    absl-py
    curl
    double-conversion
    flatbuffers
    giflib
    jsoncpp
    libjpeg_turbo
    ml-dtypes
    numpy
    scipy
    six
  ];

  buildInputs = [ snappy-cpp ];

  pythonImportsCheck = [
    "jaxlib"
    # `import jaxlib` loads surprisingly little. The extra imports below catch
    # breakage like the import bugs that appeared in the 0.4.11 upgrade.
    "jaxlib.cpu_feature_guard"
    "jaxlib.xla_client"
  ];

  # Without this, there are complaints about libcudart.so.11.0 not being found,
  # because the RPATH entries added above would be stripped.
  dontPatchELF = cudaSupport;

  passthru = {
    # Note: the "bazel.*.tar.gz" deps can be accessed as `jaxlib.bazel-build.deps`.
    inherit bazel-build;
  };

  inherit meta;
}
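# When updating the fetchAttrs sha256 values above, the fixed-output deps
# derivation exposed via passthru can be built on its own, e.g. (illustrative
# invocation; the exact attribute path depends on how this file is wired in):
#   nix-build -A python3Packages.jaxlib.bazel-build.deps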