# Package definition for vqgan-jax, a JAX implementation of VQGAN,
# built from a pinned upstream commit.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  flax,
  jax,
  jaxlib,
  transformers,
}:

buildPythonPackage {
  pname = "vqgan-jax";
  # Snapshot version: corresponds to the commit pinned in `src.rev` below.
  version = "unstable-2022-04-20";

  src = fetchFromGitHub {
    owner = "patil-suraj";
    repo = "vqgan-jax";
    rev = "1be20eee476e5d35c30e4ec3ed12222018af8ce4";
    hash = "sha256-OZihAXpE0UsgauQ38XDmAF+lrIgz05uK0ro8SCdVsPc=";
  };

  # Legacy setuptools-based build.
  format = "setuptools";

  # jaxlib is available at build time (e.g. for the import check below) but
  # is not propagated. NOTE(review): presumably so consumers can pick their
  # own jaxlib variant alongside jax — confirm against the jax package.
  buildInputs = [ jaxlib ];

  # Runtime Python dependencies exposed to downstream consumers.
  propagatedBuildInputs = [
    flax
    jax
    transformers
  ];

  # Test phase is disabled. NOTE(review): no reason is recorded here —
  # confirm whether upstream ships a runnable test suite.
  doCheck = false;

  # Smoke test: the installed package must be importable as `vqgan_jax`.
  pythonImportsCheck = [ "vqgan_jax" ];

  # `lib.` is spelled out explicitly instead of using `with lib;`, so every
  # name's origin is visible and nothing can be shadowed.
  meta = {
    description = "JAX implementation of VQGAN";
    homepage = "https://github.com/patil-suraj/vqgan-jax";
    # license unknown: https://github.com/patil-suraj/vqgan-jax/issues/9
    license = lib.licenses.unfree;
    maintainers = with lib.maintainers; [ r-burns ];
  };
}