1{
2 stdenv,
3 #bazel_5,
4 bazel,
5 buildBazelPackage,
6 lib,
7 fetchFromGitHub,
8 symlinkJoin,
9 addDriverRunpath,
10 fetchpatch,
11 fetchzip,
12 linkFarm,
13 # Python deps
14 buildPythonPackage,
15 pythonAtLeast,
16 pythonOlder,
17 python,
18 # Python libraries
19 numpy,
20 tensorboard,
21 abseil-cpp,
22 absl-py,
23 packaging,
24 setuptools,
25 wheel,
26 google-pasta,
27 opt-einsum,
28 astunparse,
29 h5py,
30 termcolor,
31 grpcio,
32 six,
33 wrapt,
34 protobuf-python,
35 tensorflow-estimator-bin,
36 dill,
37 flatbuffers-python,
38 portpicker,
39 tblib,
40 typing-extensions,
41 # Common deps
42 git,
43 pybind11,
44 which,
45 binutils,
46 glibcLocales,
47 cython,
48 perl,
49 # Common libraries
50 jemalloc,
51 mpi,
52 gast,
53 grpc,
54 sqlite,
55 boringssl,
56 jsoncpp,
57 nsync,
58 curl,
59 snappy-cpp,
60 flatbuffers-core,
61 icu,
62 double-conversion,
63 libpng,
64 libjpeg_turbo,
65 giflib,
66 protobuf-core,
67 # Upstream by default includes cuda support since tensorflow 1.15. We could do
68 # that in nix as well. It would make some things easier and less confusing, but
69 # it would also make the default tensorflow package unfree. See
70 # https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0
71 config,
72 cudaSupport ? config.cudaSupport,
73 cudaPackages,
74 cudaCapabilities ? cudaPackages.flags.cudaCapabilities,
75 mklSupport ? false,
76 mkl,
77 tensorboardSupport ? true,
78 # XLA without CUDA is broken
79 xlaSupport ? cudaSupport,
80 sse42Support ? stdenv.hostPlatform.sse4_2Support,
81 avx2Support ? stdenv.hostPlatform.avx2Support,
82 fmaSupport ? stdenv.hostPlatform.fmaSupport,
83 cctools,
84 llvmPackages,
85}:
86
let
  # Keep a handle on the stdenv that was passed in as an argument: the inner
  # `let` below rebinds the name `stdenv`, and that new definition needs to
  # refer back to the original one.
  originalStdenv = stdenv;
in
90let
91 # Tensorflow looks at many toolchain-related variables which may diverge.
92 #
93 # Toolchain for cuda-enabled builds.
94 # We want to achieve two things:
95 # 1. NVCC should use a compatible back-end (e.g. gcc11 for cuda11)
96 # 2. Normal C++ files should be compiled with the same toolchain,
97 # to avoid potential weird dynamic linkage errors at runtime.
98 # This may not be necessary though
99 #
100 # Toolchain for Darwin:
101 # clang 7 fails to emit a symbol for
102 # __ZN4llvm11SmallPtrSetIPKNS_10AllocaInstELj8EED1Ev in any of the
103 # translation units, so the build fails at link time
104 stdenv =
105 if cudaSupport then
106 cudaPackages.backendStdenv
107 else if originalStdenv.hostPlatform.isDarwin then
108 llvmPackages.stdenv
109 else
110 originalStdenv;
111 inherit (cudaPackages) cudatoolkit nccl;
112 # use compatible cuDNN (https://www.tensorflow.org/install/source#gpu)
113 # cudaPackages.cudnn led to this:
114 # https://github.com/tensorflow/tensorflow/issues/60398
115 #cudnnAttribute = "cudnn_8_6";
116 cudnnAttribute = "cudnn";
117 cudnnMerged = symlinkJoin {
118 name = "cudnn-merged";
119 paths = [
120 (lib.getDev cudaPackages.${cudnnAttribute})
121 (lib.getLib cudaPackages.${cudnnAttribute})
122 ];
123 };
  # Tarball of TensorFlow 2.12.0 patches hosted on dev.gentoo.org; individual
  # files from it are referenced in the `patches` list of the Bazel build.
  gentoo-patches = fetchzip {
    url = "https://dev.gentoo.org/~perfinion/patches/tensorflow-patches-2.12.0.tar.bz2";
    hash = "sha256-SCRX/5/zML7LmKEPJkcM5Tebez9vv/gmE4xhT/jyqWs=";
  };
  # Link farm that exposes the protobuf source tree under the name `include`;
  # it is added to the Bazel build's nativeBuildInputs.
  protobuf-extra = linkFarm "protobuf-extra" [
    {
      name = "include";
      path = protobuf-core.src;
    }
  ];
134
135 withTensorboard = (pythonOlder "3.6") || tensorboardSupport;
136
  # Individual CUDA packages that get merged into `cudatoolkitDevMerged`.
  # For nvcc and nvprune the `__spliced.buildHost` variant is used when that
  # attribute exists, falling back to the plain package otherwise; the plain
  # `cuda_nvcc` is listed a second time further down (see the linked issue).
  cudaComponents = with cudaPackages; [
    (cuda_nvcc.__spliced.buildHost or cuda_nvcc)
    (cuda_nvprune.__spliced.buildHost or cuda_nvprune)
    cuda_cccl # block_load.cuh
    cuda_cudart # cuda.h
    cuda_cupti # cupti.h
    cuda_nvcc # See https://github.com/google/jax/issues/19811
    cuda_nvml_dev # nvml.h
    cuda_nvtx # nvToolsExt.h
    libcublas # cublas_api.h
    libcufft # cufft.h
    libcurand # curand.h
    libcusolver # cusolver_common.h
    libcusparse # cusparse.h
  ];
152
153 cudatoolkitDevMerged = symlinkJoin {
154 name = "cuda-${cudaPackages.cudaMajorMinorVersion}-dev-merged";
155 paths = lib.concatMap (p: [
156 (lib.getBin p)
157 (lib.getDev p)
158 (lib.getLib p)
159 (lib.getOutput "static" p) # Makes for a very fat closure
160 ]) cudaComponents;
161 };
162
163 # Tensorflow expects bintools at hard-coded paths, e.g. /usr/bin/ar
164 # The only way to overcome that is to set GCC_HOST_COMPILER_PREFIX,
165 # but that path must contain cc as well, so we merge them
166 cudatoolkit_cc_joined = symlinkJoin {
167 name = "${stdenv.cc.name}-merged";
168 paths = [
169 stdenv.cc
170 binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
171 ];
172 };
173
  # Needed for _some_ system libraries, grep INCLUDEDIR.
  # Its `include` directory is what the Bazel build exports as INCLUDEDIR;
  # at the moment jsoncpp is the only member.
  includes_joined = symlinkJoin {
    name = "tensorflow-deps-merged";
    paths = [ jsoncpp ];
  };
179
180 tfFeature = x: if x then "1" else "0";
181
182 version = "2.13.0";
183 format = "setuptools";
184 variant = lib.optionalString cudaSupport "-gpu";
185 pname = "tensorflow${variant}";
186
187 pythonEnv = python.withPackages (_: [
188 # python deps needed during wheel build time (not runtime, see the buildPythonPackage part for that)
189 # This list can likely be shortened, but each trial takes multiple hours so won't bother for now.
190 absl-py
191 astunparse
192 dill
193 flatbuffers-python
194 gast
195 google-pasta
196 grpcio
197 h5py
198 numpy
199 opt-einsum
200 packaging
201 protobuf-python
202 setuptools
203 six
204 tblib
205 tensorboard
206 tensorflow-estimator-bin
207 termcolor
208 typing-extensions
209 wheel
210 wrapt
211 ]);
212
  # Patched copy of the `rules_cc` directory taken from the fetched Bazel
  # dependencies (`_bazel-build.deps`). The Darwin variant of `bazel-build`
  # passes it to Bazel via `--override_repository=rules_cc=...`.
  rules_cc_darwin_patched = stdenv.mkDerivation {
    name = "rules_cc-${pname}-${version}";

    src = _bazel-build.deps;

    # The patches are relative to the rules_cc checkout, so enter it before
    # patching and leave it again afterwards (see postPatch).
    prePatch = "pushd rules_cc";
    patches = [
      # https://github.com/bazelbuild/rules_cc/issues/122
      (fetchpatch {
        name = "tensorflow-rules_cc-libtool-path.patch";
        url = "https://github.com/bazelbuild/rules_cc/commit/8c427ab30bf213630dc3bce9d2e9a0e29d1787db.diff";
        hash = "sha256-C4v6HY5+jm0ACUZ58gBPVejCYCZfuzYKlHZ0m2qDHCk=";
      })

      # https://github.com/bazelbuild/rules_cc/pull/124
      (fetchpatch {
        name = "tensorflow-rules_cc-install_name_tool-path.patch";
        url = "https://github.com/bazelbuild/rules_cc/commit/156497dc89100db8a3f57b23c63724759d431d05.diff";
        hash = "sha256-NES1KeQmMiUJQVoV6dS4YGRxxkZEjOpFSCyOq9HZYO0=";
      })
    ];
    postPatch = "popd";

    # Nothing to configure or compile; only the patched tree is wanted.
    dontConfigure = true;
    dontBuild = true;

    # The output is the patched rules_cc directory itself.
    installPhase = ''
      runHook preInstall

      mv rules_cc/ "$out"

      runHook postInstall
    '';
  };
  # Patched copy of the `llvm-raw` directory taken from the fetched Bazel
  # dependencies (`_bazel-build.deps`). The Darwin variant of `bazel-build`
  # passes it to Bazel via `--override_repository=llvm-raw=...`.
  llvm-raw_darwin_patched = stdenv.mkDerivation {
    name = "llvm-raw-${pname}-${version}";

    src = _bazel-build.deps;

    # The patch is relative to the llvm-raw checkout, so enter it first.
    prePatch = "pushd llvm-raw";
    patches = [
      # Fix a vendored config.h that requires the 10.13 SDK
      ./llvm_bazel_fix_macos_10_12_sdk.patch
    ];
    # Create (or refresh) BUILD and WORKSPACE inside llvm-raw, then leave the
    # directory again.
    postPatch = ''
      touch {BUILD,WORKSPACE}
      popd
    '';

    # Nothing to configure or compile; only the patched tree is wanted.
    dontConfigure = true;
    dontBuild = true;

    # The output is the patched llvm-raw directory itself.
    installPhase = ''
      runHook preInstall

      mv llvm-raw/ "$out"

      runHook postInstall
    '';
  };
273 bazel-build =
274 if stdenv.hostPlatform.isDarwin then
275 _bazel-build.overrideAttrs (prev: {
276 bazelFlags = prev.bazelFlags ++ [
277 "--override_repository=rules_cc=${rules_cc_darwin_patched}"
278 "--override_repository=llvm-raw=${llvm-raw_darwin_patched}"
279 ];
280 preBuild = ''
281 export AR="${cctools}/bin/libtool"
282 '';
283 })
284 else
285 _bazel-build;
286
287 _bazel-build = buildBazelPackage.override { inherit stdenv; } {
288 name = "${pname}-${version}";
289 #bazel = bazel_5;
290 bazel = bazel;
291
292 src = fetchFromGitHub {
293 owner = "tensorflow";
294 repo = "tensorflow";
295 tag = "v${version}";
296 hash = "sha256-Rq5pAVmxlWBVnph20fkAwbfy+iuBNlfFy14poDPd5h0=";
297 };
298
299 # On update, it can be useful to steal the changes from gentoo
300 # https://gitweb.gentoo.org/repo/gentoo.git/tree/sci-libs/tensorflow
301
302 nativeBuildInputs = [
303 which
304 pythonEnv
305 cython
306 perl
307 protobuf-core
308 protobuf-extra
309 ]
310 ++ lib.optional cudaSupport addDriverRunpath;
311
312 buildInputs = [
313 jemalloc
314 mpi
315 glibcLocales
316 git
317
318 # libs taken from system through the TF_SYS_LIBS mechanism
319 abseil-cpp
320 boringssl
321 curl
322 double-conversion
323 flatbuffers-core
324 giflib
325 grpc
326 # Necessary to fix the "`GLIBCXX_3.4.30' not found" error
327 (icu.override { inherit stdenv; })
328 jsoncpp
329 libjpeg_turbo
330 libpng
331 (pybind11.overridePythonAttrs (_: {
332 inherit stdenv;
333 }))
334 snappy-cpp
335 sqlite
336 ]
337 ++ lib.optionals cudaSupport [
338 cudatoolkit
339 cudnnMerged
340 ]
341 ++ lib.optionals mklSupport [ mkl ]
342 ++ lib.optionals (!stdenv.hostPlatform.isDarwin) [ nsync ];
343
    # arbitrarily set to the current latest bazel version, overly careful
    TF_IGNORE_MAX_BAZEL_VERSION = true;

    # Empty string everywhere except on Darwin, where it points at the libtool
    # shipped with cctools.
    LIBTOOL = lib.optionalString stdenv.hostPlatform.isDarwin "${cctools}/bin/libtool";
348
349 # Take as many libraries from the system as possible. Keep in sync with
350 # list of valid syslibs in
351 # https://github.com/tensorflow/tensorflow/blob/master/third_party/systemlibs/syslibs_configure.bzl
352 TF_SYSTEM_LIBS = lib.concatStringsSep "," (
353 [
354 "absl_py"
355 "astor_archive"
356 "astunparse_archive"
357 "boringssl"
358 "com_google_absl"
359 # Not packaged in nixpkgs
360 # "com_github_googleapis_googleapis"
361 # "com_github_googlecloudplatform_google_cloud_cpp"
362 "com_github_grpc_grpc"
363 "com_google_protobuf"
364 # Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
365 # "com_googlesource_code_re2"
366 "curl"
367 "cython"
368 "dill_archive"
369 "double_conversion"
370 "flatbuffers"
371 "functools32_archive"
372 "gast_archive"
373 "gif"
374 "hwloc"
375 "icu"
376 "jsoncpp_git"
377 "libjpeg_turbo"
378 "nasm"
379 "opt_einsum_archive"
380 "org_sqlite"
381 "pasta"
382 "png"
383 "pybind11"
384 "six_archive"
385 "snappy"
386 "tblib_archive"
387 "termcolor_archive"
388 "typing_extensions_archive"
389 "wrapt"
390 "zlib"
391 ]
392 ++ lib.optionals (!stdenv.hostPlatform.isDarwin) [
393 "nsync" # fails to build on darwin
394 ]
395 );
396
    # Include directory of the merged system-library tree (`includes_joined`).
    INCLUDEDIR = "${includes_joined}/include";

    # This is needed for the Nix-provided protobuf dependency to work,
    # as otherwise the rule `link_proto_files` tries to create the links
    # to `/usr/include/...` which results in build failures.
    PROTOBUF_INCLUDE_PATH = "${protobuf-core}/include";

    # Interpreter of the build-time Python environment defined above.
    PYTHON_BIN_PATH = pythonEnv.interpreter;

    TF_NEED_GCP = true;
    TF_NEED_HDFS = true;
    TF_ENABLE_XLA = tfFeature xlaSupport;

    # Placeholder value; preConfigure exports the real CC_OPT_FLAGS computed
    # from the sse42/avx2/fma switches.
    CC_OPT_FLAGS = " ";

    # https://github.com/tensorflow/tensorflow/issues/14454
    # Note that MPI is switched on together with CUDA support.
    TF_NEED_MPI = tfFeature cudaSupport;

    TF_NEED_CUDA = tfFeature cudaSupport;
    # Comma-separated CUDA locations: merged toolkit components, merged cuDNN
    # and the NCCL library output. Empty string for non-CUDA builds.
    TF_CUDA_PATHS = lib.optionalString cudaSupport "${cudatoolkitDevMerged},${cudnnMerged},${lib.getLib nccl}";
    TF_CUDA_COMPUTE_CAPABILITIES = lib.concatStringsSep "," cudaCapabilities;

    # Needed even when we override stdenv: e.g. for ar
    # Both point into the merged compiler + bintools tree and are empty
    # strings for non-CUDA builds.
    GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin";
    GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin/cc";
422
423 patches = [
424 "${gentoo-patches}/0002-systemlib-Latest-absl-LTS-has-split-cord-libs.patch"
425 "${gentoo-patches}/0005-systemlib-Updates-for-Abseil-20220623-LTS.patch"
426 "${gentoo-patches}/0007-systemlibs-Add-well_known_types_py_pb2-target.patch"
427 # https://github.com/conda-forge/tensorflow-feedstock/pull/329/commits/0a63c5a962451b4da99a9948323d8b3ed462f461
428 (fetchpatch {
429 name = "fix-layout-proto-duplicate-loading.patch";
430 url = "https://raw.githubusercontent.com/conda-forge/tensorflow-feedstock/0a63c5a962451b4da99a9948323d8b3ed462f461/recipe/patches/0001-Omit-linking-to-layout_proto_cc-if-protobuf-linkage-.patch";
431 hash = "sha256-/7buV6DinKnrgfqbe7KKSh9rCebeQdXv2Uj+Xg/083w=";
432 })
433 ./com_google_absl_add_log.patch
434 ./absl_py_argparse_flags.patch
435 ./protobuf_python.patch
436 ./pybind11_protobuf_python_runtime_dep.patch
437 ./pybind11_protobuf_newer_version.patch
438 ]
439 ++ lib.optionals (stdenv.hostPlatform.system == "aarch64-darwin") [ ./absl_to_std.patch ];
440
    # After patching: drop the pinned .bazelversion file and fix up shebangs
    # across the tree. When tensorboard is not wanted, additionally delete the
    # line containing 'tensorboard ~=' from the pip package's setup.py.
    postPatch = ''
      # bazel 3.3 should work just as well as bazel 3.1
      rm -f .bazelversion
      patchShebangs .
    ''
    + lib.optionalString (!withTensorboard) ''
      # Tensorboard pulls in a bunch of dependencies, some of which may
      # include security vulnerabilities. So we make it optional.
      # https://github.com/tensorflow/tensorflow/issues/20280#issuecomment-400230560
      sed -i '/tensorboard ~=/d' tensorflow/tools/pip_package/setup.py
    '';
452
453 # https://github.com/tensorflow/tensorflow/pull/39470
454 env.NIX_CFLAGS_COMPILE = toString [ "-Wno-stringop-truncation" ];
455
456 preConfigure =
457 let
458 opt_flags =
459 [ ]
460 ++ lib.optionals sse42Support [ "-msse4.2" ]
461 ++ lib.optionals avx2Support [ "-mavx2" ]
462 ++ lib.optionals fmaSupport [ "-mfma" ];
463 in
464 ''
465 patchShebangs configure
466
467 # dummy ldconfig
468 mkdir dummy-ldconfig
469 echo "#!${stdenv.shell}" > dummy-ldconfig/ldconfig
470 chmod +x dummy-ldconfig/ldconfig
471 export PATH="$PWD/dummy-ldconfig:$PATH"
472
473 export PYTHON_LIB_PATH="$NIX_BUILD_TOP/site-packages"
474 export CC_OPT_FLAGS="${lib.concatStringsSep " " opt_flags}"
475 mkdir -p "$PYTHON_LIB_PATH"
476
477 # To avoid mixing Python 2 and Python 3
478 unset PYTHONPATH
479 '';
480
    # Run TensorFlow's own ./configure script, wrapped in the standard
    # pre/post hooks; it is driven by the TF_* / CC_OPT_FLAGS / PYTHON_*
    # variables set above and in preConfigure.
    configurePhase = ''
      runHook preConfigure
      ./configure
      runHook postConfigure
    '';

    # Turn off the "format" hardening flag for this build.
    hardeningDisable = [ "format" ];
488
489 bazelBuildFlags = [
490 "--config=opt" # optimize using the flags set in the configure phase
491 ]
492 ++ lib.optionals stdenv.cc.isClang [
493 "--cxxopt=-x"
494 "--cxxopt=c++"
495 "--host_cxxopt=-x"
496 "--host_cxxopt=c++"
497
498 # workaround for https://github.com/bazelbuild/bazel/issues/15359
499 "--spawn_strategy=sandboxed"
500 ]
501 ++ lib.optionals (mklSupport) [ "--config=mkl" ];
502
503 bazelTargets = [
504 "//tensorflow/tools/pip_package:build_pip_package //tensorflow/tools/lib_package:libtensorflow"
505 ];
506
    # NOTE(review): presumably tells buildBazelPackage to keep the fetched
    # rules_cc repository (the Darwin build overrides it) — confirm against
    # buildBazelPackage.
    removeRulesCC = false;
    # Without this Bazel complains about sandbox violations.
    dontAddBazelOpts = true;
510
511 fetchAttrs = {
512 sha256 =
513 {
514 x86_64-linux =
515 if cudaSupport then
516 "sha256-5VFMNHeLrUxW5RTr6EhT3pay9nWJ5JkZTGirDds5QkU="
517 else
518 "sha256-KzgWV69Btr84FdwQ5JI2nQEsqiPg1/+TWdbw5bmxXOE=";
519 aarch64-linux =
520 if cudaSupport then
521 "sha256-ty5+51BwHWE1xR4/0WcWTp608NzSAS/iiyN+9zx7/wI="
522 else
523 "sha256-9btXrNHqd720oXTPDhSmFidv5iaZRLjCVX8opmrMjXk=";
524 x86_64-darwin = "sha256-gqb03kB0z2pZQ6m1fyRp1/Nbt8AVVHWpOJSeZNCLc4w=";
525 aarch64-darwin = "sha256-WdgAaFZU+ePwWkVBhLzjlNT7ELfGHOTaMdafcAMD5yo=";
526 }
527 .${stdenv.hostPlatform.system} or (throw "unsupported system ${stdenv.hostPlatform.system}");
528 };
529
530 buildAttrs = {
531 outputs = [
532 "out"
533 "python"
534 ];
535
536 # need to rebuild schemas since we use a different flatbuffers version
537 preBuild = ''
538 (cd tensorflow/lite/schema;${flatbuffers-core}/bin/flatc --gen-object-api -c schema.fbs)
539 (cd tensorflow/lite/schema;${flatbuffers-core}/bin/flatc --gen-object-api -c conversion_metadata.fbs)
540 (cd tensorflow/lite/acceleration/configuration;${flatbuffers-core}/bin/flatc -o configuration.fbs --proto configuration.proto)
541 sed -i s,tflite.proto,tflite,g tensorflow/lite/acceleration/configuration/configuration.fbs/configuration.fbs
542 (cd tensorflow/lite/acceleration/configuration;${flatbuffers-core}/bin/flatc --gen-compare --gen-object-api -c configuration.fbs/configuration.fbs)
543 cp -r tensorflow/lite/acceleration/configuration/configuration.fbs tensorflow/lite/experimental/acceleration/configuration
544 (cd tensorflow/lite/experimental/acceleration/configuration;${flatbuffers-core}/bin/flatc -c configuration.fbs/configuration.fbs)
545 (cd tensorflow/lite/delegates/gpu/cl;${flatbuffers-core}/bin/flatc -c compiled_program_cache.fbs)
546 (cd tensorflow/lite/delegates/gpu/cl;${flatbuffers-core}/bin/flatc -I $NIX_BUILD_TOP/source -c serialization.fbs)
547 (cd tensorflow/lite/delegates/gpu/common;${flatbuffers-core}/bin/flatc -I $NIX_BUILD_TOP/source -c gpu_model.fbs)
548 (cd tensorflow/lite/delegates/gpu/common/task;${flatbuffers-core}/bin/flatc -c serialization_base.fbs)
549 patchShebangs .
550 '';
551
552 installPhase = ''
553 mkdir -p "$out"
554 tar -xf bazel-bin/tensorflow/tools/lib_package/libtensorflow.tar.gz -C "$out"
555 # Write pkgconfig file.
556 mkdir "$out/lib/pkgconfig"
557 cat > "$out/lib/pkgconfig/tensorflow.pc" << EOF
558 Name: TensorFlow
559 Version: ${version}
560 Description: Library for computation using data flow graphs for scalable machine learning
561 Requires:
562 Libs: -L$out/lib -ltensorflow
563 Cflags: -I$out/include/tensorflow
564 EOF
565
566 # build the source code, then copy it to $python (build_pip_package
567 # actually builds a symlink farm so we must dereference them).
568 bazel-bin/tensorflow/tools/pip_package/build_pip_package --src "$PWD/dist"
569 cp -Lr "$PWD/dist" "$python"
570 '';
571
572 postFixup = lib.optionalString cudaSupport ''
573 find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
574 addDriverRunpath "$lib"
575 done
576 '';
577
578 requiredSystemFeatures = [ "big-parallel" ];
579 };
580
581 meta = {
582 badPlatforms = lib.optionals cudaSupport lib.platforms.darwin;
583 changelog = "https://github.com/tensorflow/tensorflow/releases/tag/v${version}";
584 description = "Computation using data flow graphs for scalable machine learning";
585 homepage = "http://tensorflow.org";
586 license = lib.licenses.asl20;
587 maintainers = [ ];
588 platforms = with lib.platforms; linux ++ darwin;
589 broken =
590 # Dependencies are EOL and have been removed; an update
591 # to a newer TensorFlow version will be required to fix the
592 # source build.
593 true
594 || stdenv.hostPlatform.isDarwin
595 || !(xlaSupport -> cudaSupport)
596 || !(cudaSupport -> builtins.hasAttr cudnnAttribute cudaPackages)
597 || !(cudaSupport -> cudaPackages ? cudatoolkit);
598 }
599 // lib.optionalAttrs stdenv.hostPlatform.isDarwin {
600 timeout = 86400; # 24 hours
601 maxSilent = 14400; # 4h, double the default of 7200s
602 };
603 };
604in
605buildPythonPackage {
  __structuredAttrs = true;
  inherit version pname format;
  # Not offered for Python 3.13 and newer.
  disabled = pythonAtLeast "3.13";

  # The pip package source tree produced as the `python` output of the Bazel
  # build.
  src = bazel-build.python;
611
612 # Adjust dependency requirements:
613 # - Drop tensorflow-io dependency until we get it to build
614 # - Relax flatbuffers and gast version requirements
615 # - The purpose of python3Packages.libclang is not clear at the moment and we don't have it packaged yet
616 # - keras will be considered as optional for now.
617 postPatch = ''
618 sed -i setup.py \
619 -e '/tensorflow-io-gcs-filesystem/,+1d' \
620 -e "s/'flatbuffers[^']*',/'flatbuffers',/" \
621 -e "s/'gast[^']*',/'gast',/" \
622 -e "/'libclang[^']*',/d" \
623 -e "/'keras[^']*')\?,/d" \
624 -e "s/'protobuf[^']*',/'protobuf',/" \
625 '';
626
  # Upstream has a pip hack that results in bin/tensorboard being in both tensorflow
  # and the propagated input tensorboard, which causes environment collisions.
  # Another possibility would be to have tensorboard only in the buildInputs
  # https://github.com/tensorflow/tensorflow/blob/v1.7.1/tensorflow/tools/pip_package/setup.py#L79
  postInstall = ''
    rm $out/bin/tensorboard
  '';

  # Pass `--project_name <pname>` to setup.py, so the built package is named
  # after `pname` (which carries the "-gpu" suffix for CUDA builds).
  setupPyGlobalFlags = [
    "--project_name"
    pname
  ];
639
640 # tensorflow/tools/pip_package/setup.py
641 propagatedBuildInputs = [
642 absl-py
643 abseil-cpp
644 astunparse
645 flatbuffers-python
646 gast
647 google-pasta
648 grpcio
649 h5py
650 numpy
651 opt-einsum
652 packaging
653 protobuf-python
654 six
655 tensorflow-estimator-bin
656 termcolor
657 typing-extensions
658 wrapt
659 ]
660 ++ lib.optionals withTensorboard [ tensorboard ];
661
  # addDriverRunpath is only needed for the CUDA-only postFixup below.
  nativeBuildInputs = lib.optionals cudaSupport [ addDriverRunpath ];

  # CUDA builds only: for every shared object in $out, run addDriverRunpath
  # and then prepend the cudatoolkit, cuDNN and NCCL library directories to
  # the rpath it already has.
  postFixup = lib.optionalString cudaSupport ''
    find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
      addDriverRunpath "$lib"

      patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnnMerged}/lib:${lib.getLib nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
    done
  '';
671
672 # Actual tests are slow and impure.
673 # TODO try to run them anyway
674 # TODO better test (files in tensorflow/tools/ci_build/builds/*test)
675 # TEST_PACKAGES in tensorflow/tools/pip_package/setup.py
676 nativeCheckInputs = [
677 dill
678 portpicker
679 tblib
680 ];
681 checkPhase = ''
682 ${python.interpreter} <<EOF
683 # A simple "Hello world"
684 import tensorflow as tf
685 hello = tf.constant("Hello, world!")
686 tf.print(hello)
687
688 tf.random.set_seed(0)
689 width = 512
690 choice = 48
691 t_in = tf.Variable(tf.random.uniform(shape=[width]))
692 with tf.GradientTape() as tape:
693 t_out = tf.slice(tf.nn.softmax(t_in), [choice], [1])
694 diff = tape.gradient(t_out, t_in)
695 assert(0 < tf.reduce_min(tf.slice(diff, [choice], [1])))
696 assert(0 > tf.reduce_max(tf.slice(diff, [1], [choice - 1])))
697 EOF
698 '';
699 # Regression test for #77626 removed because not more `tensorflow.contrib`.
700
701 passthru = {
702 deps = bazel-build.deps;
703 libtensorflow = bazel-build.out;
704 };
705
706 inherit (bazel-build) meta;
707}