at master 1.1 kB view raw
# Package expression for flowjax, a Python library of distributions,
# bijections and normalizing flows built on Equinox and JAX.
# Every argument below is supplied by the caller (callPackage / the Python
# package set); the groups mirror how each input is used further down.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # PEP 517 build backend (goes into `build-system`)
  hatchling,

  # runtime Python requirements (go into `dependencies`)
  equinox,
  jax,
  jaxtyping,
  optax,
  paramax,
  tqdm,

  # needed only while running the upstream test suite
  beartype,
  numpyro,
  pytest-xdist,
  pytestCheckHook,
}:

let
  # Single source of truth for the release; the fetched tag is derived from it.
  version = "17.2.0";

  # Upstream repository URL, reused for both `homepage` and `changelog`.
  homepage = "https://github.com/danielward27/flowjax";

  # Sources come from the `v<version>` release tag on GitHub; `hash` pins the
  # exact tree, so it must be refreshed whenever `version` changes.
  src = fetchFromGitHub {
    owner = "danielward27";
    repo = "flowjax";
    tag = "v${version}";
    hash = "sha256-gaHlXm1M41njtgQt+f77Wd7q+PQ+1ipZiLtv59z1ma4=";
  };
in
buildPythonPackage {
  pname = "flowjax";
  inherit version src;

  # Build through the project's pyproject.toml rather than setup.py.
  pyproject = true;

  build-system = [ hatchling ];

  dependencies = [
    equinox
    jax
    jaxtyping
    optax
    paramax
    tqdm
  ];

  # Test-time inputs: pytestCheckHook runs the suite, pytest-xdist is
  # available for parallel execution, beartype/numpyro are test dependencies.
  nativeCheckInputs = [
    beartype
    numpyro
    pytest-xdist
    pytestCheckHook
  ];

  # Smoke test: the installed package must be importable as `flowjax`.
  pythonImportsCheck = [ "flowjax" ];

  meta = {
    description = "Distributions, bijections and normalizing flows using Equinox and JAX";
    inherit homepage;
    # Release notes live on the GitHub release page for the fetched tag.
    changelog = "${homepage}/releases/tag/${src.tag}";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ GaetanLepage ];
  };
}