{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  hatchling,

  # dependencies
  equinox,
  jax,
  jaxtyping,
  optax,
  paramax,
  tqdm,

  # tests
  beartype,
  numpyro,
  pytest-xdist,
  pytestCheckHook,
}:

# Package definition for flowjax: distributions, bijections and normalizing
# flows built on Equinox and JAX.
let
  # Upstream GitHub coordinates, shared by `src` and the `meta` URLs.
  owner = "danielward27";
  repo = "flowjax";
  homepage = "https://github.com/${owner}/${repo}";
in
buildPythonPackage rec {
  pname = repo;
  version = "17.2.0";
  pyproject = true;

  # Upstream release tags carry a `v` prefix in front of the version.
  src = fetchFromGitHub {
    inherit owner repo;
    tag = "v${version}";
    hash = "sha256-gaHlXm1M41njtgQt+f77Wd7q+PQ+1ipZiLtv59z1ma4=";
  };

  build-system = [ hatchling ];

  # Runtime Python dependencies.
  dependencies = [
    equinox
    jax
    jaxtyping
    optax
    paramax
    tqdm
  ];

  # Test-time inputs; the test suite is run through pytestCheckHook.
  nativeCheckInputs = [
    beartype
    numpyro
    pytest-xdist
    pytestCheckHook
  ];

  # Smoke test: the top-level module must be importable after installation.
  pythonImportsCheck = [ "flowjax" ];

  meta = {
    description = "Distributions, bijections and normalizing flows using Equinox and JAX";
    inherit homepage;
    changelog = "${homepage}/releases/tag/${src.tag}";
    license = lib.licenses.mit;
    maintainers = [ lib.maintainers.GaetanLepage ];
  };
}