{
  lib,
  stdenv,
  config,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  numpy,
  scikit-learn,
  torch,
  tqdm,

  # optional-dependencies
  faiss,
  tensorboard,

  # tests
  pytestCheckHook,
  torchvision,
  writableTmpDirAsHomeHook,

  # Defaults to the global nixpkgs CUDA setting; only used below to skip
  # tests that crash in CUDA-enabled builds.
  cudaSupport ? config.cudaSupport,
}:

buildPythonPackage rec {
  pname = "pytorch-metric-learning";
  version = "2.9.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "KevinMusgrave";
    repo = "pytorch-metric-learning";
    tag = "v${version}";
    hash = "sha256-JKWE2wVXVx8xp2kpiX6CxvCKkrwYRW80A20K/UTxIaQ=";
  };

  build-system = [
    setuptools
  ];

  dependencies = [
    numpy
    torch
    scikit-learn
    tqdm
  ];

  # Both extras have the same contents here: nixpkgs provides a single
  # `faiss` package for both the GPU ("with-hooks") and CPU
  # ("with-hooks-cpu") variants.
  optional-dependencies = {
    with-hooks = [
      # TODO: record-keeper
      faiss
      tensorboard
    ];
    with-hooks-cpu = [
      # TODO: record-keeper
      faiss
      tensorboard
    ];
  };

  # The test suite reads its device and dtypes from the environment.
  preCheck = ''
    export TEST_DEVICE=cpu
    export TEST_DTYPES=float32,float64 # half-precision tests fail on CPU
  '';

  # Smoke-test that the installed package can be imported.
  pythonImportsCheck = [ "pytorch_metric_learning" ];

  # The package only requires `unittest`, but we use `pytest` so that
  # individual tests can be excluded. Every optional-dependency group is
  # added so the tests that use faiss/tensorboard can run.
  nativeCheckInputs = [
    pytestCheckHook
    torchvision
    writableTmpDirAsHomeHook
  ]
  ++ lib.flatten (lib.attrValues optional-dependencies);

  disabledTests = [
    # network access
    "test_tuplestoweights_sampler"
    "test_metric_loss_only"
    "test_add_to_indexer"
    "test_get_nearest_neighbors"
    "test_list_of_text"
    "test_untrained_indexer"
  ]
  ++ lib.optionals cudaSupport [
    # crashes with SIGABRT
    "test_accuracy_calculator_and_faiss_with_torch_and_numpy"
    "test_accuracy_calculator_large_k"
    "test_custom_knn"
    "test_global_embedding_space_tester"
    "test_global_two_stream_embedding_space_tester"
    "test_index_type"
    "test_k_warning"
    "test_many_tied_distances"
    "test_query_within_reference"
    "test_tied_distances"
    "test_with_same_parent_label_tester"
  ];

  disabledTestPaths = lib.optionals stdenv.hostPlatform.isDarwin [
    # Fatal Python error: Segmentation fault
    "tests/testers/"
    "tests/utils/"
  ];

  meta = {
    description = "Metric learning library for PyTorch";
    homepage = "https://github.com/KevinMusgrave/pytorch-metric-learning";
    changelog = "https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/${src.tag}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}