{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,
  pythonOlder,
  numpy,
  scikit-learn,
  scipy,
  setuptools,
  tabulate,
  torch,
  tqdm,
  flaky,
  llvmPackages,
  pandas,
  pytest-cov-stub,
  pytestCheckHook,
  safetensors,
  transformers,
  pythonAtLeast,
}:

# skorch: scikit-learn compatible neural network library built on PyTorch,
# packaged from the tagged GitHub release.
buildPythonPackage rec {
  pname = "skorch";
  version = "1.1.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "skorch-dev";
    repo = "skorch";
    tag = "v${version}";
    # SRI hash: use `hash`, not the legacy `sha256` attribute.
    hash = "sha256-f0g/kn3HhvYfGDgLpA7gAnYocJrYqHUq680KrGuoPCQ=";
  };

  # Supported on Python >= 3.9 and < 3.13.
  # On Python 3.13 the package fails with:
  #   AttributeError: 'NoneType' object has no attribute 'span'
  # https://github.com/skorch-dev/skorch/issues/1080
  disabled = pythonOlder "3.9" || pythonAtLeast "3.13";

  build-system = [ setuptools ];

  dependencies = [
    numpy
    pandas
    scikit-learn
    scipy
    tabulate
    torch # implicit dependency
    tqdm
  ];

  nativeCheckInputs = [
    flaky
    pytest-cov-stub
    pytestCheckHook
    safetensors
    transformers
  ];

  # OpenMP is added to the test inputs only when building with Clang.
  checkInputs = lib.optionals stdenv.cc.isClang [ llvmPackages.openmp ];

  disabledTests = [
    # On CPU, this expects artifacts from a previous GPU run.
    "test_load_cuda_params_to_cpu"
    # Failing test (cause not recorded here).
    "test_pickle_load"
    # Fails because of a problem with compiler selection.
    "test_fit_and_predict_with_compile"
    # Fail with "Weights only load failed".
    "test_can_be_copied"
    "test_pickle"
    "test_pickle_save_load"
    "test_train_net_after_copy"
    "test_weights_restore"
    # Reported as flaky.
    "test_fit_lbfgs_optimizer"
  ];

  disabledTestPaths = [
    # These try to download HuggingFace data, which is unavailable in the sandbox.
    "skorch/tests/test_dataset.py"
    "skorch/tests/test_hf.py"
    "skorch/tests/llm/test_llm_classifier.py"
    # When run in parallel, these tests fail on every platform with sandboxing
    # enabled, with:
    #   RuntimeError: The server socket has failed to listen on any local
    #   network address.
    # The cause is that they all use the same hardcoded port.
    "skorch/tests/test_history.py"
  ];

  pythonImportsCheck = [ "skorch" ];

  meta = {
    description = "Scikit-learn compatible neural net library using Pytorch";
    homepage = "https://skorch.readthedocs.io";
    # Pinned to the packaged release tag rather than the moving `master` branch.
    changelog = "https://github.com/skorch-dev/skorch/blob/${src.tag}/CHANGES.md";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [ bcdarwin ];
    badPlatforms = [
      # Most tests fail with:
      #   Fatal Python error: Segmentation fault
      lib.systems.inspect.patterns.isDarwin
    ];
  };
}