1{
2 lib,
3 stdenv,
4 buildPythonPackage,
5 fetchFromGitHub,
6 fetchpatch,
7 pytestCheckHook,
8 setuptools,
9 apricot-select,
10 networkx,
11 numpy,
12 scikit-learn,
13 scipy,
14 torch,
15}:
16
# Package definition for pomegranate, built from the upstream GitHub tag with
# a setuptools-based pyproject build and tested with pytest.
buildPythonPackage rec {
  pname = "pomegranate";
  version = "1.1.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "jmschrei";
    repo = "pomegranate";
    tag = "v${version}";
    hash = "sha256-p2Gn0FXnsAHvRUeAqx4M1KH0+XvDl3fmUZZ7MiMvPSs=";
  };

  patches = [
    # Fix tests for pytorch 2.6
    (fetchpatch {
      name = "pytorch-2.6.patch";
      url = "https://github.com/jmschrei/pomegranate/pull/1142/commits/9ff5d5e2c959b44e569937e777b26184d1752a7b.patch";
      hash = "sha256-BXsVhkuL27QqK/n6Fa9oJCzrzNcL3EF6FblBeKXXSts=";
    })
  ];

  build-system = [ setuptools ];

  dependencies = [
    apricot-select
    networkx
    numpy
    scikit-learn
    scipy
    torch
  ];

  pythonImportsCheck = [ "pomegranate" ];

  nativeCheckInputs = [
    pytestCheckHook
  ];

  # Extra pytest arguments, only added on x86_64-darwin where these specific
  # test nodes fail.
  pytestFlags = lib.optionals (stdenv.hostPlatform.isDarwin && stdenv.hostPlatform.isx86_64) [
    # AssertionError: Arrays are not almost equal to 6 decimals
    "--deselect=tests/distributions/test_normal_full.py::test_fit"
    "--deselect=tests/distributions/test_normal_full.py::test_from_summaries"
    "--deselect=tests/distributions/test_normal_full.py::test_serialization"
  ];

  # Disabled on every platform; matched by test name.
  disabledTests = [
    # AssertionError: Arrays are not almost equal to 6 decimals
    "test_sample"
  ];

  meta = {
    description = "Probabilistic and graphical models for Python, implemented in PyTorch";
    homepage = "https://github.com/jmschrei/pomegranate";
    changelog = "https://github.com/jmschrei/pomegranate/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ rybern ];
  };
}