1import os
2import re
3import signal
4import sys
5import tempfile
6import threading
7import traceback
8from collections.abc import Callable, Iterator
9from contextlib import AbstractContextManager, contextmanager
10from pathlib import Path
11from typing import Any
12from unittest import TestCase
13
14from colorama import Style
15
16from test_driver.debug import DebugAbstract, DebugNop
17from test_driver.errors import MachineError, RequestedAssertionFailed
18from test_driver.logger import AbstractLogger
19from test_driver.machine import Machine, NixStartScript, retry
20from test_driver.polling_condition import PollingCondition
21from test_driver.vlan import VLan
22
# Unique module-level marker object. Nothing in the visible part of this file
# references it -- presumably a "no value given" marker; confirm before removing.
SENTINEL = object()
24
25
class AssertionTester(TestCase):
    """
    Subclass of `unittest.TestCase` which is used in the
    `testScript` to perform assertions.

    It throws a custom exception whose parent class
    gets special treatment in the logs.

    The only customisation is `failureException`: the `assert*` helpers
    inherited from `TestCase` raise `RequestedAssertionFailed` instead of
    the default `AssertionError`.
    """

    # Exception type raised by the inherited assertion helpers on failure.
    failureException = RequestedAssertionFailed
36
37
def get_tmp_dir() -> Path:
    """Returns a temporary directory that is defined by XDG_RUNTIME_DIR, TMPDIR, TEMP, TMP or CWD

    XDG_RUNTIME_DIR takes precedence when it is set; otherwise the result of
    `tempfile.gettempdir()` is used. The directory is created with mode 0o700
    if it does not exist yet (missing parent directories are not created).

    Raises NotADirectoryError if the path exists but is not a directory, and
    PermissionError if the directory is not writeable.
    See https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir
    """
    tmp_dir = Path(os.environ.get("XDG_RUNTIME_DIR", tempfile.gettempdir()))
    try:
        tmp_dir.mkdir(mode=0o700, exist_ok=True)
    except FileExistsError:
        # `exist_ok=True` only tolerates an existing *directory*; for any other
        # kind of file mkdir still raises. Fall through so that the dedicated
        # NotADirectoryError below is what the caller sees.
        pass
    if not tmp_dir.is_dir():
        raise NotADirectoryError(
            f"The directory defined by XDG_RUNTIME_DIR, TMPDIR, TEMP, TMP or CWD: {tmp_dir} is not a directory"
        )
    if not os.access(tmp_dir, os.W_OK):
        raise PermissionError(
            f"The directory defined by XDG_RUNTIME_DIR, TMPDIR, TEMP, TMP, or CWD: {tmp_dir} is not writeable"
        )
    return tmp_dir
54
55
def pythonize_name(name: str) -> str:
    """Make `name` usable as a Python symbol name.

    A first character that is not an ASCII letter or underscore, and every
    character that is not an ASCII letter, digit or underscore, is replaced
    by an underscore. The length of the string is preserved.
    """
    invalid_char = re.compile(r"^[^A-Za-z_]|[^A-Za-z0-9_]")
    return invalid_char.sub("_", name)
58
59
class Driver:
    """A handle to the driver that sets up the environment
    and runs the tests

    The driver is a context manager: leaving the `with` block cancels the
    global timeout timer, releases every machine and stops every VLan.
    """

    tests: str
    out_dir: Path
    vlans: list[VLan]
    machines: list[Machine]
    polling_conditions: list[PollingCondition]
    global_timeout: int
    race_timer: threading.Timer
    logger: AbstractLogger
    debug: DebugAbstract

    def __init__(
        self,
        start_scripts: list[str],
        vlans: list[int],
        tests: str,
        out_dir: Path,
        logger: AbstractLogger,
        keep_vm_state: bool = False,
        global_timeout: int = 24 * 60 * 60 * 7,
        debug: DebugAbstract = DebugNop(),
    ):
        """Create the VLans and one `Machine` per start script.

        Args:
            start_scripts: start scripts, each wrapped in a `NixStartScript`.
            vlans: VLan numbers; duplicates are dropped.
            tests: source code of the test script, run by `test_script`.
            out_dir: output directory handed to every machine.
            logger: logger used by the driver, the VLans and the machines.
            keep_vm_state: forwarded to every `Machine`.
            global_timeout: seconds after which `terminate_test` fires once
                `run_tests` has armed the timer (default: one week).
            debug: debugging hook whose `breakpoint()` is called when the
                test script fails.
        """
        self.tests = tests
        self.out_dir = out_dir
        self.global_timeout = global_timeout
        # Only created here; the timer is started by `run_tests`.
        self.race_timer = threading.Timer(global_timeout, self.terminate_test)
        self.logger = logger
        self.debug = debug

        tmp_dir = get_tmp_dir()

        with self.logger.nested("start all VLans"):
            # Going through a set drops duplicate VLan numbers.
            self.vlans = [VLan(nr, tmp_dir, self.logger) for nr in set(vlans)]

        self.polling_conditions = []

        # `map` is lazy, so each start script is wrapped right before the
        # machine that uses it is constructed.
        self.machines = [
            Machine(
                start_command=start_command,
                keep_vm_state=keep_vm_state,
                name=start_command.machine_name,
                tmp_dir=tmp_dir,
                callbacks=[self.check_polling_conditions],
                out_dir=self.out_dir,
                logger=self.logger,
            )
            for start_command in map(NixStartScript, start_scripts)
        ]

    def __enter__(self) -> "Driver":
        """Return the driver itself; all setup already happened in `__init__`."""
        return self

    def __exit__(self, *_: Any) -> None:
        """Cancel the timeout timer, release all machines and stop all VLans.

        Cleanup is best effort: a failure for one machine or VLan is logged
        and does not prevent the remaining ones from being cleaned up.
        """
        with self.logger.nested("cleanup"):
            self.race_timer.cancel()
            for machine in self.machines:
                try:
                    machine.release()
                except Exception as e:
                    self.logger.error(f"Error during cleanup of {machine.name}: {e}")

            for vlan in self.vlans:
                try:
                    vlan.stop()
                except Exception as e:
                    self.logger.error(f"Error during cleanup of vlan{vlan.nr}: {e}")

    def subtest(self, name: str) -> Iterator[None]:
        """Group logs under a given test name

        This is a plain generator that yields exactly once; `test_symbols`
        wraps it with `contextmanager` before exposing it to the test script.
        Exceptions raised inside the subtest are logged and re-raised.
        """
        with self.logger.subtest(name):
            try:
                yield
            except Exception as e:
                self.logger.log_test_error(f'Test "{name}" failed with error: "{e}"')
                raise

    def test_symbols(self) -> dict[str, Any]:
        """Build the global namespace in which the test script is executed.

        Besides the general helpers, every machine is exposed under its
        pythonized name and every VLan as `vlan<nr>`. The exposed names are
        also printed to stdout.
        """

        @contextmanager
        def subtest(name: str) -> Iterator[None]:
            return self.subtest(name)

        general_symbols = dict(
            start_all=self.start_all,
            test_script=self.test_script,
            machines=self.machines,
            vlans=self.vlans,
            driver=self,
            log=self.logger,
            os=os,
            create_machine=self.create_machine,
            subtest=subtest,
            run_tests=self.run_tests,
            join_all=self.join_all,
            retry=retry,
            serial_stdout_off=self.serial_stdout_off,
            serial_stdout_on=self.serial_stdout_on,
            polling_condition=self.polling_condition,
            Machine=Machine,  # for typing
            t=AssertionTester(),
            debug=self.debug,
        )
        machine_symbols = {pythonize_name(m.name): m for m in self.machines}
        # If there's exactly one machine, make it available under the name
        # "machine", even if it's not called that.
        if len(self.machines) == 1:
            (machine_symbols["machine"],) = self.machines
        vlan_symbols = {f"vlan{v.nr}": v for v in self.vlans}

        machine_names = ", ".join(m.name for m in self.machines)
        vlan_names = ", ".join(f"vlan{v.nr}" for v in self.vlans)
        print(
            "additionally exposed symbols:\n "
            + machine_names
            + ",\n "
            + vlan_names
            + ",\n "
            + ", ".join(general_symbols)
        )
        # Later mappings win on a name clash: VLans over machines over helpers.
        return {**general_symbols, **machine_symbols, **vlan_symbols}

    def dump_machine_ssh(self, offset: int) -> None:
        """Print, for every machine, the ssh command reaching it over vsock.

        Machines are numbered in list order starting at `offset + 1`.
        """
        print("SSH backdoor enabled, the machines can be accessed like this:")
        print(
            f"{Style.BRIGHT}Note:{Style.RESET_ALL} this requires {Style.BRIGHT}systemd-ssh-proxy(1){Style.RESET_ALL} to be enabled (default on NixOS 25.05 and newer)."
        )
        names = [machine.name for machine in self.machines]
        # `default=0` keeps this from raising when there are no machines.
        longest_name = max((len(name) for name in names), default=0)
        for num, name in enumerate(names, start=offset + 1):
            # Pad so that the ssh commands line up in one column.
            spaces = " " * (longest_name - len(name) + 2)
            print(
                f" {name}:{spaces}{Style.BRIGHT}ssh -o User=root vsock/{num}{Style.RESET_ALL}"
            )

    def test_script(self) -> None:
        """Run the test script

        A `MachineError` is logged with its full traceback and exits with
        status 1. A `RequestedAssertionFailed` is logged with a traceback
        reduced to the frames of the test script itself, then the debug hook
        runs and the process exits with status 1. Any other exception runs
        the debug hook and propagates.
        """
        with self.logger.nested("run the VM test script"):
            symbols = self.test_symbols()  # call eagerly
            try:
                exec(self.tests, symbols, None)
            except MachineError:
                for line in traceback.format_exc().splitlines():
                    self.logger.log_test_error(line)
                sys.exit(1)
            except RequestedAssertionFailed:
                exc_type, exc, tb = sys.exc_info()
                # We manually print the stack frames, keeping only the ones from the test script
                # (note: because the script is not a real file, the frame filename is `<string>`)
                filtered = [
                    frame
                    for frame in traceback.extract_tb(tb)
                    if frame.filename == "<string>"
                ]

                self.logger.log_test_error("Traceback (most recent call last):")

                code = self.tests.splitlines()
                for frame, line in zip(filtered, traceback.format_list(filtered)):
                    self.logger.log_test_error(line.rstrip())
                    # Other dynamically compiled code also reports `<string>`
                    # as its filename, so the line number may not belong to
                    # the test script; only echo a source line that exists.
                    if (lineno := frame.lineno) and lineno <= len(code):
                        self.logger.log_test_error(f" {code[lineno - 1].strip()}")

                self.logger.log_test_error("")  # blank line for readability
                exc_prefix = exc_type.__name__ if exc_type is not None else "Error"
                for line in f"{exc_prefix}: {exc}".splitlines():
                    self.logger.log_test_error(line)

                self.debug.breakpoint()

                sys.exit(1)

            except Exception:
                self.debug.breakpoint()
                raise

    def run_tests(self) -> None:
        """Run the test script (for non-interactive test runs)

        Arms the global timeout timer before running the script and runs
        `sync` on every machine that is still up afterwards.
        """
        self.logger.info(
            f"Test will time out and terminate in {self.global_timeout} seconds"
        )
        self.race_timer.start()
        self.test_script()
        # TODO: Collect coverage data
        for machine in self.machines:
            if machine.is_up():
                machine.execute("sync")

    def start_all(self) -> None:
        """Start all machines"""
        with self.logger.nested("start all VMs"):
            for machine in self.machines:
                machine.start()

    def join_all(self) -> None:
        """Wait for all machines to shut down"""
        with self.logger.nested("wait for all VMs to finish"):
            for machine in self.machines:
                machine.wait_for_shutdown()
            self.race_timer.cancel()

    def terminate_test(self) -> None:
        """Timeout handler: release all machines, then SIGTERM this process."""
        # This will be usually running in another thread than
        # the thread actually executing the test script.
        with self.logger.nested("timeout reached; test terminating..."):
            for machine in self.machines:
                # A failing release must not keep us from sending the signal
                # below, otherwise a timed-out test could hang forever.
                try:
                    machine.release()
                except Exception as e:
                    self.logger.error(f"Error during cleanup of {machine.name}: {e}")
            # As we cannot `sys.exit` from another thread
            # We can at least force the main thread to get SIGTERM'ed.
            # This will prevent any user who caught all the exceptions
            # to swallow them and prevent itself from terminating.
            os.kill(os.getpid(), signal.SIGTERM)

    def create_machine(
        self,
        start_command: str,
        *,
        name: str | None = None,
        keep_vm_state: bool = False,
    ) -> Machine:
        """Create an additional machine from a start command.

        Args:
            start_command: start script, wrapped in a `NixStartScript`.
            name: machine name; falls back to the name derived from the
                start script when empty or `None`.
            keep_vm_state: forwarded to `Machine`.
        """
        # NOTE(review): unlike the machines built in `__init__`, the result is
        # not added to `self.machines` and gets no polling-condition callback
        # -- confirm this is intended.
        tmp_dir = get_tmp_dir()

        cmd = NixStartScript(start_command)
        name = name or cmd.machine_name

        return Machine(
            tmp_dir=tmp_dir,
            out_dir=self.out_dir,
            start_command=cmd,
            name=name,
            keep_vm_state=keep_vm_state,
            logger=self.logger,
        )

    def serial_stdout_on(self) -> None:
        """Ask the logger to print serial logs."""
        self.logger.print_serial_logs(True)

    def serial_stdout_off(self) -> None:
        """Ask the logger to stop printing serial logs."""
        self.logger.print_serial_logs(False)

    def check_polling_conditions(self) -> None:
        """Machine callback: let every active polling condition raise if it wants to."""
        for condition in self.polling_conditions:
            condition.maybe_raise()

    def polling_condition(
        self,
        fun_: Callable | None = None,
        *,
        seconds_interval: float = 2.0,
        description: str | None = None,
    ) -> Callable[[Callable], AbstractContextManager] | AbstractContextManager:
        """Wrap a function into a polling condition.

        Works both as a bare decorator (`fun_` given: returns a `Poll`
        instance) and as a decorator factory (`fun_` omitted: returns the
        `Poll` class, to be called with the function). While a `Poll` is
        entered as a context manager its condition is registered with the
        driver; `Poll.wait()` retries the condition until it holds.
        """
        driver = self

        class Poll:
            def __init__(self, fun: Callable):
                self.condition = PollingCondition(
                    fun,
                    driver.logger,
                    seconds_interval,
                    description,
                )

            def __enter__(self) -> None:
                driver.polling_conditions.append(self.condition)

            def __exit__(self, a, b, c) -> None:  # type: ignore
                # Conditions are expected to be left in reverse order of entry.
                res = driver.polling_conditions.pop()
                assert res is self.condition

            def wait(self, timeout: int = 900) -> None:
                # `last` is supplied by `retry`; judging by the log messages it
                # marks the final attempt before giving up.
                def condition(last: bool) -> bool:
                    if last:
                        driver.logger.info(
                            f"Last chance for {self.condition.description}"
                        )
                    ret = self.condition.check(force=True)
                    if not ret and not last:
                        driver.logger.info(
                            f"({self.condition.description} failure not fatal yet)"
                        )
                    return ret

                with driver.logger.nested(f"waiting for {self.condition.description}"):
                    retry(condition, timeout=timeout)

        if fun_ is None:
            return Poll
        else:
            return Poll(fun_)