# PyTorch Profiler TensorBoard plugin, packaged from the tb_plugin/
# subdirectory of the pytorch/kineto repository.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  pandas,
  pytestCheckHook,
  setuptools,
  tensorboard,
  torch,
  torchvision,
}:

buildPythonPackage rec {
  pname = "torch_tb_profiler";
  version = "0.4.0";
  pyproject = true;

  # The whole kineto repository is fetched; only tb_plugin/ is built (see
  # sourceRoot below).
  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "kineto";
    rev = "refs/tags/v${version}";
    hash = "sha256-nAtqGCv8q3Tati3NOGWWLb+gXdvO3qmECeC1WG2Mt3M=";
  };

  # Build from the plugin subdirectory of the unpacked source. This keeps
  # `src` a fetcher derivation instead of a string with an interpolated
  # store path.
  sourceRoot = "${src.name}/tb_plugin";

  build-system = [ setuptools ];

  dependencies = [
    pandas
    tensorboard
  ];

  # torch and torchvision are only needed to run the test suite.
  nativeCheckInputs = [
    pytestCheckHook
    torch
    torchvision
  ];

  disabledTests = [
    # Tests that access the filesystem in ways the build sandbox does not
    # allow.
    "test_profiler_api_without_gpu"
    "test_tensorboard_end2end"
    "test_tensorboard_with_path_prefix"
    "test_tensorboard_with_symlinks"
    "test_autograd_api"
    "test_profiler_api_with_record_shapes_memory_stack"
    "test_profiler_api_without_record_shapes_memory_stack"
    "test_profiler_api_without_step"
  ];

  pythonImportsCheck = [ "torch_tb_profiler" ];

  meta = {
    description = "PyTorch Profiler TensorBoard Plugin";
    homepage = "https://github.com/pytorch/kineto";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [ samuela ];
  };
}