{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,
  setuptools-scm,

  # dependencies
  torch,
  triton,

  # optional-dependencies
  accelerate,
  datasets,
  fire,
  huggingface-hub,
  pandas,
  tqdm,
  transformers,

  # tests
  pytestCheckHook,
}:

buildPythonPackage {
  pname = "cut-cross-entropy";
  version = "25.7.2";
  pyproject = true;

  # The `ml-cross-entropy` package on PyPI is published by a third party, not
  # Apple. Upstream recommends installing directly from the repository's main
  # branch, so a specific commit from it is pinned here.
  src = fetchFromGitHub {
    owner = "apple";
    repo = "ml-cross-entropy";
    rev = "b19a424ed30a05b8261cfa84d83b2601a9454c67"; # upstream publishes no tags
    hash = "sha256-AwUqKiI7XjEOZ7ofjQCOsqvxHyTFD4RZ70odPyxxntc=";
  };

  # NOTE(review): the fetched source has no `.git` and upstream has no tags, so
  # setuptools-scm presumably relies on nixpkgs' setuptools-scm setup hook
  # exporting SETUPTOOLS_SCM_PRETEND_VERSION from `version` — confirm if the
  # build reports a 0.0.0 / unknown version.
  build-system = [
    setuptools
    setuptools-scm
  ];

  dependencies = [
    torch
    triton
  ];

  optional-dependencies = {
    transformers = [ transformers ];
    all = [
      accelerate
      datasets
      fire
      huggingface-hub
      pandas
      tqdm
      transformers
    ]
    # `deepspeed` is not yet packaged in nixpkgs. Once it is, uncomment the line
    # below and add `stdenv` (and `deepspeed`) to the function arguments:
    # ++ lib.optionals (!stdenv.hostPlatform.isDarwin) [ deepspeed ]
    ;
  };

  nativeCheckInputs = [ pytestCheckHook ];

  disabledTests = [
    # Requires CUDA, but does not call `pytest.skip` when CUDA is unavailable.
    "test_vocab_parallel"
  ];

  pythonImportsCheck = [ "cut_cross_entropy" ];

  meta = {
    description = "Memory-efficient cross-entropy loss implementation using Cut Cross-Entropy (CCE)";
    homepage = "https://github.com/apple/ml-cross-entropy";
    license = lib.licenses.aml;
    maintainers = with lib.maintainers; [ hoh ];
  };
}