# Package expression for pytorch-metric-learning, built from the upstream
# GitHub tag with setuptools via `buildPythonPackage`.
#
# Arguments:
#   cudaSupport  Defaults to `config.cudaSupport`. In this expression it only
#                controls which additional tests are disabled (see
#                `disabledTests`); the check phase itself is pinned to the CPU
#                through `TEST_DEVICE` in `preCheck`.
{
  lib,
  stdenv,
  config,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  numpy,
  scikit-learn,
  torch,
  tqdm,

  # optional-dependencies
  faiss,
  tensorboard,

  # tests
  pytestCheckHook,
  torchvision,
  writableTmpDirAsHomeHook,

  cudaSupport ? config.cudaSupport,
}:

buildPythonPackage rec {
  pname = "pytorch-metric-learning";
  version = "2.9.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "KevinMusgrave";
    repo = "pytorch-metric-learning";
    tag = "v${version}";
    hash = "sha256-JKWE2wVXVx8xp2kpiX6CxvCKkrwYRW80A20K/UTxIaQ=";
  };

  build-system = [
    setuptools
  ];

  dependencies = [
    numpy
    torch
    scikit-learn
    tqdm
  ];

  # Both extras currently resolve to the same inputs here; `record-keeper` is
  # still missing from each of them (see the TODOs).
  optional-dependencies = {
    with-hooks = [
      # TODO: record-keeper
      faiss
      tensorboard
    ];
    with-hooks-cpu = [
      # TODO: record-keeper
      faiss
      tensorboard
    ];
  };

  # Pin the test suite to the CPU and to the dtypes that work there.
  preCheck = ''
    export TEST_DEVICE=cpu
    export TEST_DTYPES=float32,float64 # half-precision tests fail on CPU
  '';

  # package only requires `unittest`, but use `pytest` to exclude tests
  nativeCheckInputs = [
    pytestCheckHook
    torchvision
    writableTmpDirAsHomeHook
  ]
  # Pull in every optional dependency so the tests covering the extras can run.
  ++ lib.flatten (lib.attrValues optional-dependencies);

  disabledTests = [
    # network access
    "test_tuplestoweights_sampler"
    "test_metric_loss_only"
    "test_add_to_indexer"
    "test_get_nearest_neighbors"
    "test_list_of_text"
    "test_untrained_indexer"
  ]
  ++ lib.optionals cudaSupport [
    # crashes with SIGABRT
    "test_accuracy_calculator_and_faiss_with_torch_and_numpy"
    "test_accuracy_calculator_large_k"
    "test_custom_knn"
    "test_global_embedding_space_tester"
    "test_global_two_stream_embedding_space_tester"
    "test_index_type"
    "test_k_warning"
    "test_many_tied_distances"
    "test_query_within_reference"
    "test_tied_distances"
    "test_with_same_parent_label_tester"
  ];

  # Whole test directories are skipped on Darwin because of the crash below.
  disabledTestPaths = lib.optionals stdenv.hostPlatform.isDarwin [
    # Fatal Python error: Segmentation fault
    "tests/testers/"
    "tests/utils/"
  ];

  meta = {
    description = "Metric learning library for PyTorch";
    homepage = "https://github.com/KevinMusgrave/pytorch-metric-learning";
    changelog = "https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/${src.tag}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ bcdarwin ];
  };
}