at master 3.9 kB view raw
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  filelock,
  huggingface-hub,
  importlib-metadata,
  numpy,
  pillow,
  regex,
  requests,
  safetensors,

  # optional dependencies
  accelerate,
  datasets,
  flax,
  jax,
  jaxlib,
  jinja2,
  peft,
  protobuf,
  tensorboard,
  torch,

  # tests
  writeText,
  parameterized,
  pytest-timeout,
  pytest-xdist,
  pytestCheckHook,
  requests-mock,
  scipy,
  sentencepiece,
  torchsde,
  transformers,
  pythonAtLeast,
  # self-reference, used only to build the opt-in test derivation in passthru
  diffusers,
}:

buildPythonPackage rec {
  pname = "diffusers";
  version = "0.35.1";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "huggingface";
    repo = "diffusers";
    tag = "v${version}";
    hash = "sha256-VZXf1YCIFtzuBWaeYG3A+AyqnMEAKEI2nStjuPJ8ZTk=";
  };

  build-system = [ setuptools ];

  dependencies = [
    filelock
    huggingface-hub
    importlib-metadata
    numpy
    pillow
    regex
    requests
    safetensors
  ];

  optional-dependencies = {
    flax = [
      flax
      jax
      jaxlib
    ];
    torch = [
      accelerate
      torch
    ];
    training = [
      accelerate
      datasets
      jinja2
      peft
      protobuf
      tensorboard
    ];
  };

  pythonImportsCheck = [ "diffusers" ];

  # The full test suite takes a few hours, so it is disabled in the regular
  # build. It is still runnable on demand via passthru.tests.pytest, which
  # re-enables doCheck on this same derivation.
  doCheck = false;

  # Every optional-dependency group is pulled in so that the tests for the
  # torch, flax and training code paths can import their requirements.
  nativeCheckInputs = [
    parameterized
    pytest-timeout
    pytest-xdist
    pytestCheckHook
    requests-mock
    scipy
    sentencepiece
    torchsde
    transformers
  ]
  ++ lib.flatten (builtins.attrValues optional-dependencies);

  preCheck =
    let
      # This pytest hook mocks and catches attempts at accessing the network:
      # tests that try to access the network raise, the error is caught, and
      # the test is marked as skipped and tagged as xfailed.
      # Only HTTPS connections opened through urllib3 are intercepted; other
      # transports are not covered by this shim.
      # cf. python3Packages.shap
      conftestSkipNetworkErrors = writeText "conftest.py" ''
        from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
        import urllib3

        class NetworkAccessDeniedError(RuntimeError): pass
        def deny_network_access(*a, **kw):
          raise NetworkAccessDeniedError

        urllib3.connection.HTTPSConnection._new_conn = deny_network_access

        def pytest_runtest_makereport(item, call):
          tr = orig_pytest_runtest_makereport(item, call)
          if call.excinfo is not None and call.excinfo.type is NetworkAccessDeniedError:
            tr.outcome = 'skipped'
            tr.wasxfail = "reason: Requires network access."
          return tr
      '';
    in
    ''
      # tests write caches under $HOME, which is not writable in the sandbox
      export HOME=$(mktemp -d)
      # append (rather than overwrite) so upstream's own conftest hooks are kept
      cat ${conftestSkipNetworkErrors} >> tests/conftest.py
    '';

  enabledTestPaths = [ "tests/" ];

  disabledTests = [
    # depends on current working directory
    "test_deprecate_stacklevel"
    # fails due to precision of floating point numbers
    "test_full_loop_no_noise"
    "test_model_cpu_offload_forward_pass"
    # tries to run ruff which we have intentionally removed from nativeCheckInputs
    "test_is_copy_consistent"

    # Require unpackaged torchao:
    # importlib.metadata.PackageNotFoundError: No package metadata was found for torchao
    "test_load_attn_procs_raise_warning"
    "test_save_attn_procs_raise_warning"
    "test_save_load_lora_adapter_0"
    "test_save_load_lora_adapter_1"
    "test_wrong_adapter_name_raises_error"
  ]
  ++ lib.optionals (pythonAtLeast "3.13") [
    # RuntimeError: Dynamo is not supported on Python 3.13+
    # NOTE(review): keep this message and the version guard above in sync; the
    # previous comment quoted "3.12+" while the guard starts at 3.13.
    "test_from_save_pretrained_dynamo"
  ];

  passthru.tests.pytest = diffusers.overridePythonAttrs { doCheck = true; };

  meta = {
    description = "State-of-the-art diffusion models for image and audio generation in PyTorch";
    mainProgram = "diffusers-cli";
    homepage = "https://github.com/huggingface/diffusers";
    changelog = "https://github.com/huggingface/diffusers/releases/tag/${src.tag}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ natsukium ];
  };
}