at master 1.6 kB view raw
# Package definition for torchinfo, built with buildPythonPackage.
# All arguments are expected to be supplied by the surrounding package set.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  fetchpatch,
  pythonOlder,
  torch,
  torchvision,
  pytestCheckHook,
  transformers,
}:

buildPythonPackage rec {
  pname = "torchinfo";
  version = "1.8.0";
  format = "setuptools";

  # Not built for interpreters older than Python 3.7.
  disabled = pythonOlder "3.7";

  # Upstream sources, pinned to the git tag matching `version`.
  src = fetchFromGitHub {
    owner = "TylerYep";
    repo = "torchinfo";
    tag = "v${version}";
    hash = "sha256-pPjg498aT8y4b4tqIzNxxKyobZX01u+66ScS/mee51Q=";
  };

  patches = [
    # Upstream commit adding support for Python 3.11 and pytorch 2.1.
    # `includes` restricts the fetched patch to the listed file.
    (fetchpatch {
      url = "https://github.com/TylerYep/torchinfo/commit/c74784c71c84e62bcf56664653b7f28d72a2ee0d.patch";
      hash = "sha256-xSSqs0tuFpdMXUsoVv4sZLCeVnkK6pDDhX/Eobvn5mw=";
      includes = [ "torchinfo/model_statistics.py" ];
    })
  ];

  # Runtime dependencies.
  propagatedBuildInputs = [
    torch
    torchvision
  ];

  # Needed only while running the test suite.
  nativeCheckInputs = [
    pytestCheckHook
    transformers
  ];

  # Point HOME at a fresh temporary directory before the tests run.
  # NOTE(review): presumably some test dependency writes under $HOME — confirm.
  preCheck = ''
    export HOME=$(mktemp -d)
  '';

  disabledTests = [
    # These download pretrained weights, which requires network access
    "test_eval_order_doesnt_matter"
    "test_flan_t5_small"
    # Fails with an AssertionError on the produced output
    "test_google"
    # "addmm_impl_cpu_" not implemented for 'Half'
    "test_input_size_half_precision"
  ];

  disabledTestPaths = [
    # This test file requires network access
    "tests/torchinfo_xl_test.py"
  ];

  pythonImportsCheck = [ "torchinfo" ];

  # Attributes are referenced through `lib` explicitly rather than via `with`.
  meta = {
    description = "API to visualize pytorch models";
    homepage = "https://github.com/TylerYep/torchinfo";
    license = lib.licenses.mit;
    maintainers = [ lib.maintainers.petterstorvik ];
  };
}