{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  pytestCheckHook,
  pythonOlder,
  writeText,
  catboost,
  cloudpickle,
  cython,
  ipython,
  lightgbm,
  lime,
  matplotlib,
  numba,
  numpy,
  opencv4,
  pandas,
  pyspark,
  pytest-mpl,
  scikit-learn,
  scipy,
  sentencepiece,
  setuptools,
  setuptools-scm,
  slicer,
  tqdm,
  transformers,
  xgboost,
}:

buildPythonPackage rec {
  pname = "shap";
  version = "0.48.0";
  pyproject = true;

  disabled = pythonOlder "3.8";

  src = fetchFromGitHub {
    # Upstream moved from slundberg/shap to the shap organisation; use the
    # canonical location instead of relying on GitHub's redirect. The hash is
    # computed over the unpacked tree, so it is unaffected by the owner change.
    owner = "shap";
    repo = "shap";
    tag = "v${version}";
    hash = "sha256-eWZhyrFpEFlmTFPTHZng9V+uMRMXDVzFdgrqIzRQTws=";
  };

  # Drop the minimum-version pins on the build-time cython and numpy
  # requirements so the versions provided by nixpkgs are accepted.
  postPatch = ''
    substituteInPlace pyproject.toml \
      --replace-fail "cython>=3.0.11" cython \
      --replace-fail "numpy>=2.0" "numpy"
  '';

  build-system = [
    cython
    numpy
    setuptools
    setuptools-scm
  ];

  dependencies = [
    cloudpickle
    numba
    numpy
    pandas
    scikit-learn
    scipy
    slicer
    tqdm
  ];

  optional-dependencies = {
    plots = [
      matplotlib
      ipython
    ];
    others = [ lime ];
  };

  preCheck =
    let
      # This pytest hook replaces network-accessing entry points with a stub
      # that raises. Tests that try to access the network therefore raise,
      # the error is caught in the report hook, and the test is reported as
      # skipped/xfailed instead of failed.
      conftestSkipNetworkErrors = writeText "conftest.py" ''
        from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
        # Import urllib.request explicitly: a bare `import urllib` does not
        # bind the `request` submodule, so patching it would only work if some
        # other module happened to import it first.
        import urllib.request
        import requests
        import transformers

        class NetworkAccessDeniedError(RuntimeError): pass
        def deny_network_access(*a, **kw):
            raise NetworkAccessDeniedError

        requests.head = deny_network_access
        requests.get = deny_network_access
        urllib.request.urlopen = deny_network_access
        urllib.request.Request = deny_network_access
        transformers.AutoTokenizer.from_pretrained = deny_network_access

        def pytest_runtest_makereport(item, call):
            tr = orig_pytest_runtest_makereport(item, call)
            if call.excinfo is not None and call.excinfo.errisinstance(NetworkAccessDeniedError):
                tr.outcome = 'skipped'
                tr.wasxfail = "reason: Requires network access."
            return tr
      '';
    in
    ''
      export HOME=$TMPDIR
      # Remove the in-tree `shap` sources: when the local copy is imported
      # instead of the installed package, the compiled extension is not found.
      rm -r shap

      # Add the pytest hook that skips tests which access the network.
      # Such tests are reported as "expected fail" (xfail).
      cat ${conftestSkipNetworkErrors} >> tests/conftest.py
    '';

  nativeCheckInputs = [
    ipython
    matplotlib
    pytest-mpl
    pytestCheckHook
    # Optional dependencies that only serve to enable more tests:
    catboost
    lightgbm
    opencv4
    pyspark
    sentencepiece
    #torch # we already skip all its tests due to slowness, adding it does nothing
    transformers
    xgboost
  ];

  # Test startup hangs (observed since 0.43.0) and the Hydra build ends with a timeout.
  doCheck = false;

  disabledTestPaths = [
    # The resulting plots look sane but do not match the baseline pixel-perfectly,
    # likely due to a matplotlib version mismatch, a different backend, or missing fonts.
    "tests/plots/test_summary.py" # FIXME: enable
  ];

  disabledTests = [
    # Same reason as test_summary.py above.
    "test_random_force_plot_negative_sign"
    "test_random_force_plot_positive_sign"
    "test_random_summary_layered_violin_with_data2"
    "test_random_summary_violin_with_data2"
    "test_simple_bar_with_cohorts_dict"
  ];

  pythonImportsCheck = [ "shap" ];

  meta = {
    description = "Unified approach to explain the output of any machine learning model";
    homepage = "https://github.com/shap/shap";
    changelog = "https://github.com/shap/shap/releases/tag/${src.tag}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [
      evax
      natsukium
    ];
  };
}