1{
2 lib,
3 stdenv,
4 buildPythonPackage,
5 fetchFromGitHub,
6
7 # build-system
8 setuptools,
9
10 # dependencies
11 typing-extensions,
12
13 # tests
14 cython,
15 numpy,
16 pytest-timeout,
17 pytest-xdist,
18 pytestCheckHook,
19 scikit-image,
20 scikit-learn,
21 torchtnt-nightly,
22 torchvision,
23}:
24let
25 pname = "torcheval";
26 version = "0.0.7";
27in
28buildPythonPackage {
29 inherit pname version;
30 pyproject = true;
31
32 src = fetchFromGitHub {
33 owner = "pytorch";
34 repo = "torcheval";
35 # Upstream has not created a tag for this version
36 # https://github.com/pytorch/torcheval/issues/215
37 rev = "f1bc22fc67ec2c77ee519aa4af8079f4fdaa41bb";
38 hash = "sha256-aVr4qKKE+dpBcJEi1qZJBljFLUl8d7D306Dy8uOojJE=";
39 };
40
41 # Patches are only applied to usages of numpy within tests,
42 # which are only used for testing purposes (see dev-requirements.txt)
43 postPatch =
44 # numpy's `np.NAN` was changed to `np.nan` when numpy 2 was released
45 ''
46 substituteInPlace tests/metrics/classification/test_accuracy.py tests/metrics/functional/classification/test_accuracy.py \
47 --replace-fail "np.NAN" "np.nan"
48 ''
49
50 # `unittest.TestCase.assertEquals` does not exist;
51 # the correct symbol is `unittest.TestCase.assertEqual`
52 + ''
53 substituteInPlace tests/metrics/test_synclib.py \
54 --replace-fail "tc.assertEquals" "tc.assertEqual"
55 '';
56
57 build-system = [ setuptools ];
58
59 dependencies = [ typing-extensions ];
60
61 pythonImportsCheck = [ "torcheval" ];
62
63 nativeCheckInputs = [
64 cython
65 numpy
66 pytest-timeout
67 pytest-xdist
68 pytestCheckHook
69 scikit-image
70 scikit-learn
71 torchtnt-nightly
72 torchvision
73 ];
74
75 pytestFlags = [
76 "-v"
77 ];
78
79 enabledTestPaths = [
80 "tests/"
81 ];
82
83 disabledTestPaths = [
84 # -- tests/metrics/audio/test_fad.py --
85 # Touch filesystem and require network access.
86 # torchaudio.utils.download_asset("models/vggish.pt") -> PermissionError: [Errno 13] Permission denied: '/homeless-shelter'
87 "tests/metrics/audio/test_fad.py::TestFAD::test_vggish_fad"
88 "tests/metrics/audio/test_fad.py::TestFAD::test_vggish_fad_merge"
89
90 # -- tests/metrics/image/test_fid.py --
91 # Touch filesystem and require network access.
92 # models.inception_v3(weights=weights) -> PermissionError: [Errno 13] Permission denied: '/homeless-shelter'
93 "tests/metrics/image/test_fid.py::TestFrechetInceptionDistance::test_fid_invalid_input"
94 "tests/metrics/image/test_fid.py::TestFrechetInceptionDistance::test_fid_random_data_custom_model"
95 "tests/metrics/image/test_fid.py::TestFrechetInceptionDistance::test_fid_random_data_default_model"
96 "tests/metrics/image/test_fid.py::TestFrechetInceptionDistance::test_fid_with_dissimilar_inputs"
97 "tests/metrics/image/test_fid.py::TestFrechetInceptionDistance::test_fid_with_similar_inputs"
98
99 # -- tests/metrics/functional/text/test_perplexity.py --
100 # AssertionError: Scalars are not close!
101 # Expected 3.537154912949 but got 3.53715443611145
102 "tests/metrics/functional/text/test_perplexity.py::Perplexity::test_perplexity_with_ignore_index"
103
104 # -- tests/metrics/image/test_psnr.py --
105 # AssertionError: Scalars are not close!
106 # Expected 7.781850814819336 but got 7.781772613525391
107 "tests/metrics/image/test_psnr.py::TestPeakSignalNoiseRatio::test_psnr_with_random_data"
108
109 # -- tests/metrics/regression/test_mean_squared_error.py --
110 # AssertionError: Scalars are not close!
111 # Expected -640.4547729492188 but got -640.4707641601562
112 "tests/metrics/regression/test_mean_squared_error.py::TestMeanSquaredError::test_mean_squared_error_class_update_input_shape_different"
113
114 # -- tests/metrics/window/test_mean_squared_error.py --
115 # AssertionError: Scalars are not close!
116 # Expected 0.0009198983898386359 but got 0.0009198188781738281
117 "tests/metrics/window/test_mean_squared_error.py::TestMeanSquaredError::test_mean_squared_error_class_update_input_shape_different"
118 ]
119
120 # These tests error on darwin platforms.
121 # NotImplementedError: The operator 'c10d::allgather_' is not currently implemented for the mps device
122 #
123 # Applying the suggested environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1;` causes the tests to fail,
124 # as using the CPU instead of the MPS causes the tensors to be on the wrong device:
125 # RuntimeError: ProcessGroupGloo::allgather: invalid tensor type at index 0;
126 # Expected TensorOptions(dtype=float, device=cpu, ...), got TensorOptions(dtype=float, device=mps:0, ...)
127 ++ lib.optional stdenv.hostPlatform.isDarwin [
128 # -- tests/metrics/test_synclib.py --
129 "tests/metrics/test_synclib.py::SynclibTest::test_complex_mixed_state_sync"
130 "tests/metrics/test_synclib.py::SynclibTest::test_complex_mixed_state_sync"
131 "tests/metrics/test_synclib.py::SynclibTest::test_empty_tensor_list_sync_state"
132 "tests/metrics/test_synclib.py::SynclibTest::test_sync_dtype_and_shape"
133 "tests/metrics/test_synclib.py::SynclibTest::test_tensor_list_sync_states"
134 "tests/metrics/test_synclib.py::SynclibTest::test_tensor_dict_sync_states"
135 "tests/metrics/test_synclib.py::SynclibTest::test_tensor_sync_states"
136 # -- tests/metrics/test_toolkit.py --
137 "tests/metrics/test_toolkit.py::MetricToolkitTest::test_metric_sync"
138 "tests/metrics/test_toolkit.py::MetricCollectionToolkitTest::test_metric_collection_sync"
139
140 # Cannot access local process over IPv6 (nodename nor servname provided) even with __darwinAllowLocalNetworking
141 # Will hang, or appear to hang, with an 5 minute (default) timeout per test
142 "tests/metrics/aggregation/test_auc.py"
143 "tests/metrics/aggregation/test_cat.py"
144 "tests/metrics/aggregation/test_max.py"
145 "tests/metrics/aggregation/test_mean.py"
146 "tests/metrics/aggregation/test_min.py"
147 "tests/metrics/aggregation/test_sum.py"
148 "tests/metrics/aggregation/test_throughput.py"
149 "tests/metrics/classification/test_accuracy.py"
150 "tests/metrics/classification/test_auprc.py"
151 "tests/metrics/classification/test_auroc.py"
152 "tests/metrics/classification/test_binned_auprc.py"
153 "tests/metrics/classification/test_binned_auroc.py"
154 "tests/metrics/classification/test_binned_precision_recall_curve.py"
155 "tests/metrics/classification/test_confusion_matrix.py"
156 "tests/metrics/classification/test_f1_score.py"
157 "tests/metrics/classification/test_normalized_entropy.py"
158 "tests/metrics/classification/test_precision_recall_curve.py"
159 "tests/metrics/classification/test_precision.py"
160 "tests/metrics/classification/test_recall_at_fixed_precision.py"
161 "tests/metrics/classification/test_recall.py"
162 "tests/metrics/functional/classification/test_auroc.py"
163 "tests/metrics/ranking/test_click_through_rate.py::TestClickThroughRate::test_ctr_with_valid_input"
164 "tests/metrics/ranking/test_hit_rate.py::TestHitRate::test_hitrate_with_valid_input"
165 "tests/metrics/ranking/test_reciprocal_rank.py::TestReciprocalRank::test_mrr_with_valid_input"
166 "tests/metrics/ranking/test_retrieval_precision.py::TestRetrievalPrecision::test_retrieval_precision_multiple_updates_1_query"
167 "tests/metrics/ranking/test_retrieval_precision.py::TestRetrievalPrecision::test_retrieval_precision_multiple_updates_n_queries_without_nan"
168 "tests/metrics/ranking/test_weighted_calibration.py::TestWeightedCalibration::test_weighted_calibration_with_valid_input"
169 "tests/metrics/regression/test_mean_squared_error.py"
170 "tests/metrics/regression/test_r2_score.py"
171 "tests/metrics/test_synclib.py::SynclibTest::test_gather_uneven_multidim"
172 "tests/metrics/test_synclib.py::SynclibTest::test_gather_uneven"
173 "tests/metrics/test_synclib.py::SynclibTest::test_numeric_sync_state"
174 "tests/metrics/test_synclib.py::SynclibTest::test_sync_list_length"
175 "tests/metrics/text/test_bleu.py::TestBleu::test_bleu_multiple_examples_per_update"
176 "tests/metrics/text/test_bleu.py::TestBleu::test_bleu_multiple_updates"
177 "tests/metrics/text/test_perplexity.py::TestPerplexity::test_perplexity_with_ignore_index"
178 "tests/metrics/text/test_perplexity.py::TestPerplexity::test_perplexity"
179 "tests/metrics/text/test_word_error_rate.py::TestWordErrorRate::test_word_error_rate_with_valid_input"
180 "tests/metrics/text/test_word_information_lost.py::TestWordInformationLost::test_word_information_lost"
181 "tests/metrics/text/test_word_information_preserved.py::TestWordInformationPreserved::test_word_information_preserved_with_valid_input"
182 "tests/metrics/window/test_auroc.py"
183 "tests/metrics/window/test_click_through_rate.py::TestClickThroughRate::test_ctr_with_valid_input"
184 "tests/metrics/window/test_mean_squared_error.py"
185 "tests/metrics/window/test_normalized_entropy.py::TestWindowedBinaryNormalizedEntropy::test_ne_with_valid_input"
186 "tests/metrics/window/test_weighted_calibration.py::TestWindowedWeightedCalibration::test_weighted_calibration_with_valid_input"
187 ];
188
189 meta = {
190 description = "Rich collection of performant PyTorch model metrics and tools for PyTorch model evaluations";
191 homepage = "https://pytorch.org/torcheval";
192 changelog = "https://github.com/pytorch/torcheval/releases/tag/${version}";
193
194 platforms = lib.platforms.unix;
195 license = [ lib.licenses.bsd3 ];
196 maintainers = [ lib.maintainers.bengsparks ];
197 };
198}