{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  filelock,
  huggingface-hub,
  importlib-metadata,
  numpy,
  pillow,
  regex,
  requests,
  safetensors,

  # optional dependencies
  accelerate,
  datasets,
  flax,
  jax,
  jaxlib,
  jinja2,
  peft,
  protobuf,
  tensorboard,
  torch,

  # tests
  writeText,
  parameterized,
  pytest-timeout,
  pytest-xdist,
  pytestCheckHook,
  requests-mock,
  scipy,
  sentencepiece,
  torchsde,
  transformers,
  pythonAtLeast,
  diffusers,
}:

buildPythonPackage rec {
  pname = "diffusers";
  version = "0.35.1";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "huggingface";
    repo = "diffusers";
    tag = "v${version}";
    hash = "sha256-VZXf1YCIFtzuBWaeYG3A+AyqnMEAKEI2nStjuPJ8ZTk=";
  };

  build-system = [ setuptools ];

  dependencies = [
    filelock
    huggingface-hub
    importlib-metadata
    numpy
    pillow
    regex
    requests
    safetensors
  ];

  # Every extra listed here is also added to nativeCheckInputs below,
  # so the test suite exercises all of them.
  optional-dependencies = {
    flax = [
      flax
      jax
      jaxlib
    ];
    torch = [
      accelerate
      torch
    ];
    training = [
      accelerate
      datasets
      jinja2
      peft
      protobuf
      tensorboard
    ];
  };

  pythonImportsCheck = [ "diffusers" ];

  # The test suite takes a few hours, so it is disabled for the regular build.
  # It can be run on demand through `passthru.tests.pytest` (defined below).
  doCheck = false;

  nativeCheckInputs = [
    parameterized
    pytest-timeout
    pytest-xdist
    pytestCheckHook
    requests-mock
    scipy
    sentencepiece
    torchsde
    transformers
  ]
  # Flatten all extras (flax, torch, training) into a single list of inputs.
  ++ lib.flatten (builtins.attrValues optional-dependencies);

  preCheck =
    let
      # This pytest hook replaces urllib3's HTTPS connection setup with a
      # function that raises NetworkAccessDeniedError. Tests that hit it are
      # reported as skipped and tagged as xfailed, instead of failing.
      # cf. python3Packages.shap
      conftestSkipNetworkErrors = writeText "conftest.py" ''
        from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
        import urllib3

        class NetworkAccessDeniedError(RuntimeError): pass
        def deny_network_access(*a, **kw):
            raise NetworkAccessDeniedError

        urllib3.connection.HTTPSConnection._new_conn = deny_network_access

        def pytest_runtest_makereport(item, call):
            tr = orig_pytest_runtest_makereport(item, call)
            if call.excinfo is not None and call.excinfo.type is NetworkAccessDeniedError:
                tr.outcome = 'skipped'
                tr.wasxfail = "reason: Requires network access."
            return tr
      '';
    in
    ''
      # Give the tests a fresh, writable HOME (presumably for caches/config
      # the tests write; the sandbox HOME is not usable for that).
      export HOME=$(mktemp -d)
      # Append the network-denial hook to the project's existing conftest.
      cat ${conftestSkipNetworkErrors} >> tests/conftest.py
    '';

  enabledTestPaths = [ "tests/" ];

  disabledTests = [
    # depends on current working directory
    "test_deprecate_stacklevel"
    # fails due to precision of floating point numbers
    "test_full_loop_no_noise"
    "test_model_cpu_offload_forward_pass"
    # tries to run ruff, which we have intentionally removed from nativeCheckInputs
    "test_is_copy_consistent"

    # Require unpackaged torchao:
    # importlib.metadata.PackageNotFoundError: No package metadata was found for torchao
    "test_load_attn_procs_raise_warning"
    "test_save_attn_procs_raise_warning"
    "test_save_load_lora_adapter_0"
    "test_save_load_lora_adapter_1"
    "test_wrong_adapter_name_raises_error"
  ]
  # The guard starts at 3.12 to match the failure quoted below; a 3.13 guard
  # would still run (and fail) this test on Python 3.12.
  ++ lib.optionals (pythonAtLeast "3.12") [
    # RuntimeError: Dynamo is not supported on Python 3.12+
    "test_from_save_pretrained_dynamo"
  ];

  # Rebuilds this package with doCheck enabled, so the slow test suite can be
  # run on demand without slowing down the regular build.
  passthru.tests.pytest = diffusers.overridePythonAttrs { doCheck = true; };

  meta = {
    description = "State-of-the-art diffusion models for image and audio generation in PyTorch";
    mainProgram = "diffusers-cli";
    homepage = "https://github.com/huggingface/diffusers";
    changelog = "https://github.com/huggingface/diffusers/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ natsukium ];
  };
}