# Package recipe for pytorch-tabnet, built from the dreamquark-ai/tabnet
# GitHub repository using a pyproject (PEP 517) build.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  poetry-core,

  # dependencies
  numpy,
  scikit-learn,
  scipy,
  torch,
  tqdm,
  typing-extensions,

  # tests
  fsspec,
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "pytorch-tabnet";
  version = "4.1.0";
  pyproject = true;

  # Source is fetched by release tag ("v<version>") and pinned by hash.
  src = fetchFromGitHub {
    owner = "dreamquark-ai";
    repo = "tabnet";
    tag = "v${version}";
    hash = "sha256-WyNGgAkNn5CaEuHWQ6Fjnvnrp+KONnxUQudd5ckvcsM=";
  };

  build-system = [ poetry-core ];

  # pyproject.toml declares the `poetry` package and `poetry.masonry.api` as
  # its build backend; rewrite both to the poetry-core equivalents so the
  # build-system above is what actually gets used. `--replace-fail` aborts the
  # build if either pattern is no longer present in the file.
  postPatch = ''
    substituteInPlace pyproject.toml \
      --replace-fail 'requires = ["poetry>=0.12"]' 'requires = ["poetry-core"]' \
      --replace-fail 'build-backend = "poetry.masonry.api"' 'build-backend = "poetry.core.masonry.api"'
  '';

  # Runtime requirements of the library.
  dependencies = [
    numpy
    scikit-learn
    scipy
    torch
    tqdm
    typing-extensions
  ];

  # Only needed while running the test suite.
  nativeCheckInputs = [
    pytestCheckHook
    fsspec
  ];

  # Smoke test: the top-level module must be importable after installation.
  pythonImportsCheck = [ "pytorch_tabnet" ];

  meta = {
    description = "PyTorch implementation of TabNet";
    homepage = "https://github.com/dreamquark-ai/tabnet";
    changelog = "https://github.com/dreamquark-ai/tabnet/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = [ lib.maintainers.jherland ];
  };
}