{
  # nixpkgs plumbing
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,
  nix-update-script,

  # tooling needed only while building the wheel
  setuptools,
  cython,
  versioneer,

  # python packages required at runtime
  cons,
  etuples,
  filelock,
  logical-unification,
  minikanren,
  numpy,
  scipy,

  # extra inputs used only by the check phase
  jax,
  jaxlib,
  numba,
  pytest-benchmark,
  pytest-mock,
  pytestCheckHook,
  tensorflow-probability,
  writableTmpDirAsHomeHook,
}:
33
buildPythonPackage rec {
  pname = "pytensor";
  version = "2.33.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pymc-devs";
    repo = "pytensor";
    # Upstream release tags carry a "rel-" prefix.
    tag = "rel-${version}";
    # Replace versioneer's placeholder `git_refnames` in _version.py with the
    # release tag, presumably so versioneer can report the real version from a
    # source tree that has no git metadata (confirm against upstream setup).
    postFetch = ''
      sed -i 's/git_refnames = "[^"]*"/git_refnames = " (tag: ${src.tag})"/' $out/pytensor/_version.py
    '';
    hash = "sha256-ngdjFqUJnJU+krNJwAwOpz1hJzDYvyKjuR/Ti/V+B3w=";
  };

  build-system = [
    setuptools
    cython
    versioneer
  ];

  dependencies = [
    cons
    etuples
    filelock
    logical-unification
    minikanren
    numpy
    scipy
  ];

  nativeCheckInputs = [
    jax
    jaxlib
    numba
    pytest-benchmark
    pytest-mock
    pytestCheckHook
    tensorflow-probability
    # Tests need a writable $HOME.
    writableTmpDirAsHomeHook
  ];

  # Run benchmark-marked tests as plain tests, without timing/statistics.
  pytestFlags = [ "--benchmark-disable" ];

  pythonImportsCheck = [ "pytensor" ];

  # Ensure that the installed package is used instead of the source files from the current workdir
  preCheck = ''
    rm -rf pytensor
  '';

  # Darwin-only exclusions. NOTE: pytestCheckHook turns these names into
  # `-k "not (...)"` substring filters, so short entries such as "add",
  # "direct" or "multiply" deselect every test whose id contains them.
  disabledTests = lib.optionals stdenv.hostPlatform.isDarwin [
    # Numerical assertion error
    # tests.unittest_tools.WrongValue: WrongValue
    "test_op_sd"
    "test_op_ss"

    # pytensor.link.c.exceptions.CompileError: Compilation failed (return status=1)
    "OpFromGraph"
    "add"
    "cls_ofg1"
    "direct"
    "multiply"
    "test_AddDS"
    "test_AddSD"
    "test_AddSS"
    "test_MulDS"
    "test_MulSD"
    "test_MulSS"
    "test_NoOutputFromInplace"
    "test_OpFromGraph"
    "test_adv_sub1_sparse_grad"
    "test_alloc"
    "test_binary"
    "test_borrow_input"
    "test_borrow_output"
    "test_cache_race_condition"
    "test_check_for_aliased_inputs"
    "test_clinker_literal_cache"
    "test_csm_grad"
    "test_csm_unsorted"
    "test_csr_dense_grad"
    "test_debugprint"
    "test_ellipsis_einsum"
    "test_empty_elemwise"
    "test_flatten"
    "test_fprop"
    "test_get_item_list_grad"
    "test_grad"
    "test_infer_shape"
    "test_jax_pad"
    "test_kron"
    "test_masked_input"
    "test_max"
    "test_modes"
    "test_mul_s_v_grad"
    "test_multiple_outputs"
    "test_nnet"
    "test_not_inplace"
    "test_numba_Cholesky_grad"
    "test_numba_pad"
    "test_optimizations_preserved"
    "test_overided_function"
    "test_potential_output_aliasing_induced_by_updates"
    "test_profiling"
    "test_rebuild_strict"
    "test_runtime_broadcast_c"
    "test_scan_err1"
    "test_scan_err2"
    "test_shared"
    "test_size_implied_by_broadcasted_parameters"
    "test_solve_triangular_grad"
    "test_structured_add_s_v_grad"
    "test_structureddot_csc_grad"
    "test_structureddot_csr_grad"
    "test_sum"
    "test_swap_SharedVariable_with_given"
    "test_test_value_op"
    "test_unary"
    "test_unbroadcast"
    "test_update_equiv"
    "test_update_same"
  ];

  disabledTestPaths = [
    # Don't run the most compute-intense tests
    "tests/scan/"
    "tests/tensor/"
    "tests/sparse/sandbox/"
  ];

  # Strip the "rel-" tag prefix so nix-update sees a bare version string.
  passthru.updateScript = nix-update-script {
    extraArgs = [
      "--version-regex"
      "rel-(.+)"
    ];
  };

  meta = {
    description = "Python library to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays";
    mainProgram = "pytensor-cache";
    homepage = "https://github.com/pymc-devs/pytensor";
    # src.tag already includes the "rel-" prefix; prepending it again
    # produced a broken ".../tag/rel-rel-<version>" URL.
    changelog = "https://github.com/pymc-devs/pytensor/releases/tag/${src.tag}";
    license = lib.licenses.bsd3;
    maintainers = with lib.maintainers; [
      bcdarwin
      ferrine
    ];
  };
}