import os
import re
import signal
import sys
import tempfile
import threading
import traceback
from collections.abc import Callable, Iterator
from contextlib import AbstractContextManager, contextmanager
from pathlib import Path
from typing import Any
from unittest import TestCase

from colorama import Style

from test_driver.debug import DebugAbstract, DebugNop
from test_driver.errors import MachineError, RequestedAssertionFailed
from test_driver.logger import AbstractLogger
from test_driver.machine import Machine, NixStartScript, retry
from test_driver.polling_condition import PollingCondition
from test_driver.vlan import VLan

SENTINEL = object()


class AssertionTester(TestCase):
    """
    Subclass of `unittest.TestCase` which is used in the
    `testScript` to perform assertions.

    It throws a custom exception whose parent class
    gets special treatment in the logs.
    """

    failureException = RequestedAssertionFailed


def get_tmp_dir() -> Path:
    """Return the directory used for temporary driver files.

    The directory is ``$XDG_RUNTIME_DIR`` when that variable is set; otherwise
    it is whatever :func:`tempfile.gettempdir` picks (TMPDIR, TEMP, TMP,
    platform defaults and, as a last resort, the CWD). It is created with
    mode 0o700 (subject to the umask) if it does not exist yet.

    Raises:
        NotADirectoryError: the path exists but is not a directory.
        PermissionError: the directory is not writeable.

    See https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir
    """
    tmp_dir = Path(os.environ.get("XDG_RUNTIME_DIR", tempfile.gettempdir()))
    # Check *before* mkdir(): with exist_ok=True, mkdir() raises
    # FileExistsError when the path exists as a non-directory, so a check
    # placed afterwards could never produce this descriptive error.
    if tmp_dir.exists() and not tmp_dir.is_dir():
        raise NotADirectoryError(
            f"The directory defined by XDG_RUNTIME_DIR, TMPDIR, TEMP, TMP or CWD: {tmp_dir} is not a directory"
        )
    tmp_dir.mkdir(mode=0o700, exist_ok=True)
    if not os.access(tmp_dir, os.W_OK):
        raise PermissionError(
            f"The directory defined by XDG_RUNTIME_DIR, TMPDIR, TEMP, TMP, or CWD: {tmp_dir} is not writeable"
        )
    return tmp_dir


def pythonize_name(name: str) -> str:
    """Turn a machine name into a valid Python identifier.

    A leading character that cannot start an identifier, and every other
    character outside ``[A-Za-z0-9_]``, is replaced by ``_``
    (e.g. ``"my-vm"`` -> ``"my_vm"``, ``"1st"`` -> ``"_st"``).
    """
    return re.sub(r"^[^A-Za-z_]|[^A-Za-z0-9_]", "_", name)


class Driver:
    """A handle to the driver that sets up the environment
    and runs the tests"""

    tests: str  # Python source of the test script
    vlans: list[VLan]
    machines: list[Machine]
    polling_conditions: list[PollingCondition]  # currently active conditions
    global_timeout: int  # seconds until terminate_test fires
    race_timer: threading.Timer
    logger: AbstractLogger
    debug: DebugAbstract

    def __init__(
        self,
        start_scripts: list[str],
        vlans: list[int],
        tests: str,
        out_dir: Path,
        logger: AbstractLogger,
        keep_vm_state: bool = False,
        global_timeout: int = 24 * 60 * 60 * 7,
        debug: DebugAbstract = DebugNop(),
    ):
        """Set up the VLANs and one ``Machine`` per start script.

        Args:
            start_scripts: VM start scripts; each is wrapped in a
                ``NixStartScript`` whose ``machine_name`` names the machine.
            vlans: VLAN numbers to set up; duplicates are ignored.
            tests: Python source of the test script run by ``test_script``.
            out_dir: output directory handed to every machine.
            logger: logger shared by the driver, machines and VLANs.
            keep_vm_state: forwarded to every machine.
            global_timeout: seconds (default: one week) after which
                ``terminate_test`` fires; the timer is started by
                ``run_tests``.
            debug: hook whose ``breakpoint()`` is called when the test
                script fails.
        """
        self.tests = tests
        self.out_dir = out_dir
        self.global_timeout = global_timeout
        self.race_timer = threading.Timer(global_timeout, self.terminate_test)
        self.logger = logger
        self.debug = debug

        tmp_dir = get_tmp_dir()

        with self.logger.nested("start all VLans"):
            # set() drops duplicate VLAN numbers.
            self.vlans = [VLan(nr, tmp_dir, self.logger) for nr in set(vlans)]

        self.polling_conditions = []

        self.machines = []
        for script in start_scripts:
            start_command = NixStartScript(script)
            self.machines.append(
                Machine(
                    start_command=start_command,
                    keep_vm_state=keep_vm_state,
                    name=start_command.machine_name,
                    tmp_dir=tmp_dir,
                    # Presumably invoked by Machine so that active polling
                    # conditions are checked during machine operations.
                    callbacks=[self.check_polling_conditions],
                    out_dir=self.out_dir,
                    logger=self.logger,
                )
            )

    def __enter__(self) -> "Driver":
        return self

    def __exit__(self, *_: Any) -> None:
        """Cancel the timeout and release all machines and VLANs.

        Cleanup errors are logged per machine/VLAN so one failure does not
        prevent the rest; exceptions from the ``with`` body still propagate.
        """
        with self.logger.nested("cleanup"):
            self.race_timer.cancel()
            for machine in self.machines:
                try:
                    machine.release()
                except Exception as e:
                    self.logger.error(f"Error during cleanup of {machine.name}: {e}")

            for vlan in self.vlans:
                try:
                    vlan.stop()
                except Exception as e:
                    self.logger.error(f"Error during cleanup of vlan{vlan.nr}: {e}")

    def subtest(self, name: str) -> Iterator[None]:
        """Group logs under a given test name.

        This is a plain generator, not a context manager; ``test_symbols``
        wraps it with ``contextlib.contextmanager`` for the test script.
        An exception raised inside the subtest is logged as a test error and
        re-raised.
        """
        with self.logger.subtest(name):
            try:
                yield
            except Exception as e:
                self.logger.log_test_error(f'Test "{name}" failed with error: "{e}"')
                raise

    def test_symbols(self) -> dict[str, Any]:
        """Build the global namespace in which the test script is executed.

        Besides the general helpers, every machine is exposed under its
        pythonized name (plus ``machine`` when there is exactly one) and every
        VLAN as ``vlan<nr>``. A summary of the extra symbols is printed.
        """

        @contextmanager
        def subtest(name: str) -> Iterator[None]:
            # self.subtest already returns a generator, which is exactly what
            # contextmanager drives.
            return self.subtest(name)

        general_symbols = dict(
            start_all=self.start_all,
            test_script=self.test_script,
            machines=self.machines,
            vlans=self.vlans,
            driver=self,
            log=self.logger,
            os=os,
            create_machine=self.create_machine,
            subtest=subtest,
            run_tests=self.run_tests,
            join_all=self.join_all,
            retry=retry,
            serial_stdout_off=self.serial_stdout_off,
            serial_stdout_on=self.serial_stdout_on,
            polling_condition=self.polling_condition,
            Machine=Machine,  # for typing
            t=AssertionTester(),
            debug=self.debug,
        )
        machine_symbols = {pythonize_name(m.name): m for m in self.machines}
        # If there's exactly one machine, make it available under the name
        # "machine", even if it's not called that.
        if len(self.machines) == 1:
            machine_symbols["machine"] = self.machines[0]
        vlan_symbols = {f"vlan{v.nr}": v for v in self.vlans}
        print(
            "additionally exposed symbols:\n "
            + ", ".join(m.name for m in self.machines)
            + ",\n "
            + ", ".join(f"vlan{v.nr}" for v in self.vlans)
            + ",\n "
            + ", ".join(general_symbols)
        )
        return {**general_symbols, **machine_symbols, **vlan_symbols}

    def dump_machine_ssh(self, offset: int) -> None:
        """Print an ``ssh`` command line for every machine.

        The machine at index ``i`` of ``self.machines`` is shown as
        ``vsock/<offset + 1 + i>``.
        """
        print("SSH backdoor enabled, the machines can be accessed like this:")
        print(
            f"{Style.BRIGHT}Note:{Style.RESET_ALL} this requires {Style.BRIGHT}systemd-ssh-proxy(1){Style.RESET_ALL} to be enabled (default on NixOS 25.05 and newer)."
        )
        names = [machine.name for machine in self.machines]
        # default=0 avoids a ValueError from max() when there are no machines.
        longest_name = max((len(name) for name in names), default=0)
        for num, name in enumerate(names, start=offset + 1):
            # Pad so that the ssh commands line up in one column.
            spaces = " " * (longest_name - len(name) + 2)
            print(
                f" {name}:{spaces}{Style.BRIGHT}ssh -o User=root vsock/{num}{Style.RESET_ALL}"
            )

    def test_script(self) -> None:
        """Run the test script.

        ``self.tests`` is executed with the namespace from ``test_symbols``.
        On ``MachineError`` the full traceback is logged and the process exits
        with status 1. On ``RequestedAssertionFailed`` only the frames of the
        test script are logged, the debug breakpoint is hit, and the process
        exits with status 1. Any other exception hits the debug breakpoint and
        is re-raised.
        """
        with self.logger.nested("run the VM test script"):
            symbols = self.test_symbols()  # call eagerly
            try:
                exec(self.tests, symbols, None)
            except MachineError:
                for line in traceback.format_exc().splitlines():
                    self.logger.log_test_error(line)
                sys.exit(1)
            except RequestedAssertionFailed as exc:
                self._log_assertion_failure(exc)

                self.debug.breakpoint()

                sys.exit(1)

            except Exception:
                self.debug.breakpoint()
                raise

    def _log_assertion_failure(self, exc: BaseException) -> None:
        """Log a traceback for ``exc`` restricted to the test script's frames."""
        # We manually print the stack frames, keeping only the ones from the
        # test script (because the script is not a real file, the frame
        # filename is `<string>`).
        filtered = [
            frame
            for frame in traceback.extract_tb(exc.__traceback__)
            if frame.filename == "<string>"
        ]

        self.logger.log_test_error("Traceback (most recent call last):")

        code = self.tests.splitlines()
        for frame, line in zip(filtered, traceback.format_list(filtered)):
            self.logger.log_test_error(line.rstrip())
            lineno = frame.lineno
            # Other `<string>` frames (e.g. from a nested exec) may point past
            # the end of the test script; skip the source line then.
            if lineno is not None and 0 < lineno <= len(code):
                self.logger.log_test_error(f" {code[lineno - 1].strip()}")

        self.logger.log_test_error("")  # blank line for readability
        for line in f"{type(exc).__name__}: {exc}".splitlines():
            self.logger.log_test_error(line)

    def run_tests(self) -> None:
        """Run the test script (for non-interactive test runs).

        Starts the global timeout timer, runs the script, then executes
        ``sync`` on every machine that is still up.
        """
        self.logger.info(
            f"Test will time out and terminate in {self.global_timeout} seconds"
        )
        self.race_timer.start()
        self.test_script()
        # TODO: Collect coverage data
        for machine in self.machines:
            if machine.is_up():
                machine.execute("sync")

    def start_all(self) -> None:
        """Start all machines"""
        with self.logger.nested("start all VMs"):
            for machine in self.machines:
                machine.start()

    def join_all(self) -> None:
        """Wait for all machines to shut down"""
        with self.logger.nested("wait for all VMs to finish"):
            for machine in self.machines:
                machine.wait_for_shutdown()
            self.race_timer.cancel()

    def terminate_test(self) -> None:
        """Release all machines and send SIGTERM to this process.

        Target of ``race_timer``, i.e. called when the global timeout expires.
        """
        # This will be usually running in another thread than
        # the thread actually executing the test script.
        with self.logger.nested("timeout reached; test terminating..."):
            for machine in self.machines:
                # A failing release must not prevent the SIGTERM below,
                # otherwise the timed-out test would keep running.
                try:
                    machine.release()
                except Exception as e:
                    self.logger.error(f"Error during cleanup of {machine.name}: {e}")
            # As we cannot `sys.exit` from another thread
            # We can at least force the main thread to get SIGTERM'ed.
            # This will prevent any user who caught all the exceptions
            # to swallow them and prevent itself from terminating.
            os.kill(os.getpid(), signal.SIGTERM)

    def create_machine(
        self,
        start_command: str,
        *,
        name: str | None = None,
        keep_vm_state: bool = False,
    ) -> Machine:
        """Create an additional machine from a start script.

        Args:
            start_command: start script, wrapped in a ``NixStartScript``.
            name: machine name; defaults to the script's ``machine_name``.
            keep_vm_state: forwarded to ``Machine``.

        NOTE(review): unlike the machines built in ``__init__``, the result
        is not added to ``self.machines`` (so ``__exit__``/``terminate_test``
        do not release it) and gets no polling-condition callback. Confirm
        that this is intended.
        """
        tmp_dir = get_tmp_dir()

        cmd = NixStartScript(start_command)
        name = name or cmd.machine_name

        return Machine(
            tmp_dir=tmp_dir,
            out_dir=self.out_dir,
            start_command=cmd,
            name=name,
            keep_vm_state=keep_vm_state,
            logger=self.logger,
        )

    def serial_stdout_on(self) -> None:
        """Enable printing of serial-console logs."""
        self.logger.print_serial_logs(True)

    def serial_stdout_off(self) -> None:
        """Disable printing of serial-console logs."""
        self.logger.print_serial_logs(False)

    def check_polling_conditions(self) -> None:
        """Call ``maybe_raise()`` on every currently active polling condition."""
        for condition in self.polling_conditions:
            condition.maybe_raise()

    def polling_condition(
        self,
        fun_: Callable | None = None,
        *,
        seconds_interval: float = 2.0,
        description: str | None = None,
    ) -> Callable[[Callable], AbstractContextManager] | AbstractContextManager:
        """Wrap a check function in a context manager holding a PollingCondition.

        Usable as a bare decorator (``@polling_condition``) or with keyword
        arguments (``@polling_condition(seconds_interval=...)``). When
        ``fun_`` is omitted, the ``Poll`` class itself is returned so it can be
        applied to the function. While inside ``with cond:``, the condition is
        in ``self.polling_conditions`` and is therefore checked by
        ``check_polling_conditions``. ``cond.wait(timeout)`` hands a forced
        check to ``retry`` with the given timeout.
        """
        driver = self

        class Poll:
            def __init__(self, fun: Callable):
                self.condition = PollingCondition(
                    fun,
                    driver.logger,
                    seconds_interval,
                    description,
                )

            def __enter__(self) -> None:
                driver.polling_conditions.append(self.condition)

            def __exit__(self, a, b, c) -> None:  # type: ignore
                # Conditions are registered as a stack; they must be exited in
                # the reverse order of entering.
                res = driver.polling_conditions.pop()
                assert res is self.condition

            def wait(self, timeout: int = 900) -> None:
                def condition(last: bool) -> bool:
                    if last:
                        driver.logger.info(
                            f"Last chance for {self.condition.description}"
                        )
                    ret = self.condition.check(force=True)
                    if not ret and not last:
                        driver.logger.info(
                            f"({self.condition.description} failure not fatal yet)"
                        )
                    return ret

                with driver.logger.nested(f"waiting for {self.condition.description}"):
                    retry(condition, timeout=timeout)

        if fun_ is None:
            return Poll
        else:
            return Poll(fun_)