# Nix expression packaging the `torchinfo` Python library from its upstream
# GitHub repository. All arguments are expected to be supplied by the caller
# (typically via `callPackage` from the Python package set).
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,
  pythonOlder,
  torch,
  torchvision,
  pytestCheckHook,
  transformers,
}:

buildPythonPackage rec {
  pname = "torchinfo";
  version = "1.8.0";
  format = "setuptools";

  # Refuse to build for interpreters older than Python 3.7.
  disabled = pythonOlder "3.7";

  src = fetchFromGitHub {
    owner = "TylerYep";
    repo = "torchinfo";
    # The upstream tag is derived from `version` above.
    tag = "v${version}";
    hash = "sha256-pPjg498aT8y4b4tqIzNxxKyobZX01u+66ScS/mee51Q=";
  };

  patches = [
    (fetchpatch {
      # Add support for Python 3.11 and pytorch 2.1
      url = "https://github.com/TylerYep/torchinfo/commit/c74784c71c84e62bcf56664653b7f28d72a2ee0d.patch";
      hash = "sha256-xSSqs0tuFpdMXUsoVv4sZLCeVnkK6pDDhX/Eobvn5mw=";
      # Keep only the part of the upstream commit that touches this file.
      includes = [ "torchinfo/model_statistics.py" ];
    })
  ];

  # NOTE(review): torchvision may only be needed by the test suite; confirm
  # before moving it to nativeCheckInputs.
  propagatedBuildInputs = [
    torch
    torchvision
  ];

  nativeCheckInputs = [
    pytestCheckHook
    transformers
  ];

  # Point HOME at a fresh temporary directory for the test run; presumably
  # some tests or their dependencies write under $HOME (verify).
  preCheck = ''
    export HOME=$(mktemp -d)
  '';

  disabledTests = [
    # Skip as it downloads pretrained weights (requires network access)
    "test_eval_order_doesnt_matter"
    "test_flan_t5_small"
    # AssertionError in output
    "test_google"
    # "addmm_impl_cpu_" not implemented for 'Half'
    "test_input_size_half_precision"
  ];

  disabledTestPaths = [
    # Test requires network access
    "tests/torchinfo_xl_test.py"
  ];

  pythonImportsCheck = [ "torchinfo" ];

  # `lib` is referenced explicitly rather than through a blanket `with lib;`
  # so that only the maintainer list is resolved through an opened scope.
  meta = {
    description = "API to visualize pytorch models";
    homepage = "https://github.com/TylerYep/torchinfo";
    license = lib.licenses.mit;
    maintainers = with lib.maintainers; [ petterstorvik ];
  };
}