# Packages the `captum` Python library from its upstream GitHub tag,
# built with setuptools and checked with pytest.
{
  lib,
  stdenv,
  fetchFromGitHub,
  buildPythonPackage,
  pytestCheckHook,
  setuptools,
  matplotlib,
  numpy,
  packaging,
  torch,
  tqdm,
  flask,
  flask-compress,
  parameterized,
  scikit-learn,
}:

buildPythonPackage rec {
  pname = "captum";
  version = "0.8.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "captum";
    tag = "v${version}";
    hash = "sha256-WuKbMYZPHWaTYYhVseSSkwXQk9LBzGuWfmneDw9V2hg=";
  };

  build-system = [ setuptools ];

  dependencies = [
    matplotlib
    numpy
    packaging
    torch
    tqdm
  ];

  # Ignore captum's declared version constraint on numpy so the build accepts
  # the numpy from this package set.
  # NOTE(review): presumably this works around an upstream upper bound —
  # confirm against upstream's packaging metadata when bumping the version.
  pythonRelaxDeps = [ "numpy" ];

  pythonImportsCheck = [ "captum" ];

  # Test-only inputs; these are not propagated to users of the package.
  nativeCheckInputs = [
    pytestCheckHook
    flask
    flask-compress
    parameterized
    scikit-learn
  ];

  disabledTestPaths =
    lib.optionals stdenv.hostPlatform.isDarwin [
      # These tests may fail if multiple builds run them at the same time due
      # to a hardcoded port number used for rendezvous.
      "tests/attr/test_data_parallel.py"
    ]
    ++ lib.optionals (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isAarch64) [
      # Issue reported upstream at https://github.com/pytorch/captum/issues/1447
      "tests/concept/test_tcav.py"
    ];

  disabledTests = [
    # Known failures on all platforms.
    # TODO(review): root cause not recorded here — link an upstream issue or
    # describe the failure so these can be re-enabled once fixed.
    "test_softmax_classification_batch_zero_baseline"
    "test_tracin_identity_regression_9_check_idx_none_ArnoldiInfluenceFunction"
  ];

  meta = {
    description = "Model interpretability and understanding for PyTorch";
    homepage = "https://github.com/pytorch/captum";
    license = lib.licenses.bsd3;
    maintainers = [ ];
  };
}