# Package expression for torchao (published upstream as pytorch/ao), built
# from the GitHub release tag with setuptools and checked with pytest.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build backend
  setuptools,

  # runtime requirements
  torch,

  # needed only by the check phase
  bitsandbytes,
  expecttest,
  fire,
  parameterized,
  pytest-xdist,
  pytestCheckHook,
  tabulate,
  transformers,
  unittest-xml-reporting,
}:

let
  # Single source of truth for the release; reused by the source tag and the
  # changelog URL below.
  version = "0.13.0";

  # Host platforms on which the extra test exclusions further down apply.
  isAarch64Linux = stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64;
in
buildPythonPackage {
  pname = "ao";
  inherit version;
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "ao";
    tag = "v${version}";
    hash = "sha256-R9H4+KkKuOzsunM3A5LT8upH1TfkHrD+BZerToCHwjo=";
  };

  build-system = [ setuptools ];

  dependencies = [ torch ];

  env.USE_SYSTEM_LIBS = true;

  pythonImportsCheck = [ "torchao" ];

  nativeCheckInputs = [
    bitsandbytes
    expecttest
    fire
    parameterized
    pytest-xdist
    pytestCheckHook
    tabulate
    transformers
    unittest-xml-reporting
  ];

  # Delete the in-tree copy of the module before testing; if it is left in
  # place, the tests import it from the source checkout rather than from the
  # installed package.
  preCheck = ''
    rm -rf torchao
  '';

  disabledTests = [
    # Needs network access
    "test_on_dummy_distilbert"

    # Needs a model checkpoint that is not available:
    # FileNotFoundError: [Errno 2] No such file or directory: 'checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth'
    "test_gptq_mt"
  ]
  ++ lib.optionals isAarch64Linux [
    # AssertionError: tensor(False) is not true
    "test_quantize_per_token_cpu"

    # RuntimeError: failed to initialize QNNPACK
    "test_smooth_linear_cpu"

    # torch._inductor.exc.InductorError: LoweringException: AssertionError: Expect L1_cache_size > 0 but got 0
    "test_int8_weight_only_quant_with_freeze_0_cpu"
    "test_int8_weight_only_quant_with_freeze_1_cpu"
    "test_int8_weight_only_quant_with_freeze_2_cpu"

    # FileNotFoundError: [Errno 2] No such file or directory: 'test.pth'
    "test_save_load_int4woqtensors_2_cpu"
    "test_save_load_int8woqtensors_0_cpu"
    "test_save_load_int8woqtensors_1_cpu"
  ];

  meta = {
    description = "PyTorch native quantization and sparsity for training and inference";
    homepage = "https://github.com/pytorch/ao";
    changelog = "https://github.com/pytorch/ao/releases/tag/v${version}";
    license = lib.licenses.bsd3;
    maintainers = [ lib.maintainers.GaetanLepage ];
  };
}