at master 9.3 kB view raw
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  typing-extensions,

  # tests
  cython,
  numpy,
  pytest-timeout,
  pytest-xdist,
  pytestCheckHook,
  scikit-image,
  scikit-learn,
  torchtnt-nightly,
  torchvision,
}:
let
  pname = "torcheval";
  version = "0.0.7";
in
buildPythonPackage {
  inherit pname version;
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "torcheval";
    # Upstream has not created a tag for this version
    # https://github.com/pytorch/torcheval/issues/215
    rev = "f1bc22fc67ec2c77ee519aa4af8079f4fdaa41bb";
    hash = "sha256-aVr4qKKE+dpBcJEi1qZJBljFLUl8d7D306Dy8uOojJE=";
  };

  # Patches are only applied to usages of numpy within tests,
  # which are only used for testing purposes (see dev-requirements.txt)
  postPatch =
    # numpy's `np.NAN` was changed to `np.nan` when numpy 2 was released
    ''
      substituteInPlace tests/metrics/classification/test_accuracy.py tests/metrics/functional/classification/test_accuracy.py \
        --replace-fail "np.NAN" "np.nan"
    ''

    # `unittest.TestCase.assertEquals` does not exist;
    # the correct symbol is `unittest.TestCase.assertEqual`
    + ''
      substituteInPlace tests/metrics/test_synclib.py \
        --replace-fail "tc.assertEquals" "tc.assertEqual"
    '';

  build-system = [ setuptools ];

  dependencies = [ typing-extensions ];

  pythonImportsCheck = [ "torcheval" ];

  nativeCheckInputs = [
    cython
    numpy
    pytest-timeout
    pytest-xdist
    pytestCheckHook
    scikit-image
    scikit-learn
    torchtnt-nightly
    torchvision
  ];

  pytestFlags = [
    "-v"
  ];

  enabledTestPaths = [
    "tests/"
  ];

  disabledTestPaths =
    [
      # -- tests/metrics/audio/test_fad.py --
      # Touch filesystem and require network access.
      # torchaudio.utils.download_asset("models/vggish.pt") -> PermissionError: [Errno 13] Permission denied: '/homeless-shelter'
      "tests/metrics/audio/test_fad.py::TestFAD::test_vggish_fad"
      "tests/metrics/audio/test_fad.py::TestFAD::test_vggish_fad_merge"

      # -- tests/metrics/image/test_fid.py --
      # Touch filesystem and require network access.
      # models.inception_v3(weights=weights) -> PermissionError: [Errno 13] Permission denied: '/homeless-shelter'
      "tests/metrics/image/test_fid.py::TestFrechetInceptionDistance::test_fid_invalid_input"
      "tests/metrics/image/test_fid.py::TestFrechetInceptionDistance::test_fid_random_data_custom_model"
      "tests/metrics/image/test_fid.py::TestFrechetInceptionDistance::test_fid_random_data_default_model"
      "tests/metrics/image/test_fid.py::TestFrechetInceptionDistance::test_fid_with_dissimilar_inputs"
      "tests/metrics/image/test_fid.py::TestFrechetInceptionDistance::test_fid_with_similar_inputs"

      # -- tests/metrics/functional/text/test_perplexity.py --
      # AssertionError: Scalars are not close!
      # Expected 3.537154912949 but got 3.53715443611145
      "tests/metrics/functional/text/test_perplexity.py::Perplexity::test_perplexity_with_ignore_index"

      # -- tests/metrics/image/test_psnr.py --
      # AssertionError: Scalars are not close!
      # Expected 7.781850814819336 but got 7.781772613525391
      "tests/metrics/image/test_psnr.py::TestPeakSignalNoiseRatio::test_psnr_with_random_data"

      # -- tests/metrics/regression/test_mean_squared_error.py --
      # AssertionError: Scalars are not close!
      # Expected -640.4547729492188 but got -640.4707641601562
      "tests/metrics/regression/test_mean_squared_error.py::TestMeanSquaredError::test_mean_squared_error_class_update_input_shape_different"

      # -- tests/metrics/window/test_mean_squared_error.py --
      # AssertionError: Scalars are not close!
      # Expected 0.0009198983898386359 but got 0.0009198188781738281
      "tests/metrics/window/test_mean_squared_error.py::TestMeanSquaredError::test_mean_squared_error_class_update_input_shape_different"
    ]

    # These tests error on darwin platforms.
    # NotImplementedError: The operator 'c10d::allgather_' is not currently implemented for the mps device
    #
    # Applying the suggested environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1;` causes the tests to fail,
    # as using the CPU instead of the MPS causes the tensors to be on the wrong device:
    # RuntimeError: ProcessGroupGloo::allgather: invalid tensor type at index 0;
    # Expected TensorOptions(dtype=float, device=cpu, ...), got TensorOptions(dtype=float, device=mps:0, ...)
    #
    # `lib.optionals` (not `lib.optional`) is required here: `lib.optional cond list`
    # would produce a nested list `[ list ]` inside disabledTestPaths.
    ++ lib.optionals stdenv.hostPlatform.isDarwin [
      # -- tests/metrics/test_synclib.py --
      "tests/metrics/test_synclib.py::SynclibTest::test_complex_mixed_state_sync"
      "tests/metrics/test_synclib.py::SynclibTest::test_empty_tensor_list_sync_state"
      "tests/metrics/test_synclib.py::SynclibTest::test_sync_dtype_and_shape"
      "tests/metrics/test_synclib.py::SynclibTest::test_tensor_list_sync_states"
      "tests/metrics/test_synclib.py::SynclibTest::test_tensor_dict_sync_states"
      "tests/metrics/test_synclib.py::SynclibTest::test_tensor_sync_states"
      # -- tests/metrics/test_toolkit.py --
      "tests/metrics/test_toolkit.py::MetricToolkitTest::test_metric_sync"
      "tests/metrics/test_toolkit.py::MetricCollectionToolkitTest::test_metric_collection_sync"

      # Cannot access local process over IPv6 (nodename nor servname provided) even with __darwinAllowLocalNetworking
      # Will hang, or appear to hang, with an 5 minute (default) timeout per test
      "tests/metrics/aggregation/test_auc.py"
      "tests/metrics/aggregation/test_cat.py"
      "tests/metrics/aggregation/test_max.py"
      "tests/metrics/aggregation/test_mean.py"
      "tests/metrics/aggregation/test_min.py"
      "tests/metrics/aggregation/test_sum.py"
      "tests/metrics/aggregation/test_throughput.py"
      "tests/metrics/classification/test_accuracy.py"
      "tests/metrics/classification/test_auprc.py"
      "tests/metrics/classification/test_auroc.py"
      "tests/metrics/classification/test_binned_auprc.py"
      "tests/metrics/classification/test_binned_auroc.py"
      "tests/metrics/classification/test_binned_precision_recall_curve.py"
      "tests/metrics/classification/test_confusion_matrix.py"
      "tests/metrics/classification/test_f1_score.py"
      "tests/metrics/classification/test_normalized_entropy.py"
      "tests/metrics/classification/test_precision_recall_curve.py"
      "tests/metrics/classification/test_precision.py"
      "tests/metrics/classification/test_recall_at_fixed_precision.py"
      "tests/metrics/classification/test_recall.py"
      "tests/metrics/functional/classification/test_auroc.py"
      "tests/metrics/ranking/test_click_through_rate.py::TestClickThroughRate::test_ctr_with_valid_input"
      "tests/metrics/ranking/test_hit_rate.py::TestHitRate::test_hitrate_with_valid_input"
      "tests/metrics/ranking/test_reciprocal_rank.py::TestReciprocalRank::test_mrr_with_valid_input"
      "tests/metrics/ranking/test_retrieval_precision.py::TestRetrievalPrecision::test_retrieval_precision_multiple_updates_1_query"
      "tests/metrics/ranking/test_retrieval_precision.py::TestRetrievalPrecision::test_retrieval_precision_multiple_updates_n_queries_without_nan"
      "tests/metrics/ranking/test_weighted_calibration.py::TestWeightedCalibration::test_weighted_calibration_with_valid_input"
      "tests/metrics/regression/test_mean_squared_error.py"
      "tests/metrics/regression/test_r2_score.py"
      "tests/metrics/test_synclib.py::SynclibTest::test_gather_uneven_multidim"
      "tests/metrics/test_synclib.py::SynclibTest::test_gather_uneven"
      "tests/metrics/test_synclib.py::SynclibTest::test_numeric_sync_state"
      "tests/metrics/test_synclib.py::SynclibTest::test_sync_list_length"
      "tests/metrics/text/test_bleu.py::TestBleu::test_bleu_multiple_examples_per_update"
      "tests/metrics/text/test_bleu.py::TestBleu::test_bleu_multiple_updates"
      "tests/metrics/text/test_perplexity.py::TestPerplexity::test_perplexity_with_ignore_index"
      "tests/metrics/text/test_perplexity.py::TestPerplexity::test_perplexity"
      "tests/metrics/text/test_word_error_rate.py::TestWordErrorRate::test_word_error_rate_with_valid_input"
      "tests/metrics/text/test_word_information_lost.py::TestWordInformationLost::test_word_information_lost"
      "tests/metrics/text/test_word_information_preserved.py::TestWordInformationPreserved::test_word_information_preserved_with_valid_input"
      "tests/metrics/window/test_auroc.py"
      "tests/metrics/window/test_click_through_rate.py::TestClickThroughRate::test_ctr_with_valid_input"
      "tests/metrics/window/test_mean_squared_error.py"
      "tests/metrics/window/test_normalized_entropy.py::TestWindowedBinaryNormalizedEntropy::test_ne_with_valid_input"
      "tests/metrics/window/test_weighted_calibration.py::TestWindowedWeightedCalibration::test_weighted_calibration_with_valid_input"
    ];

  meta = {
    description = "Rich collection of performant PyTorch model metrics and tools for PyTorch model evaluations";
    homepage = "https://pytorch.org/torcheval";
    # NOTE(review): upstream has not tagged this release (see the src.rev comment),
    # so this URL may 404 until a tag for ${version} exists — confirm once upstream tags it.
    changelog = "https://github.com/pytorch/torcheval/releases/tag/${version}";

    platforms = lib.platforms.unix;
    license = [ lib.licenses.bsd3 ];
    maintainers = [ lib.maintainers.bengsparks ];
  };
}