# Source-viewer page header (not part of the module): at 24.11-pre, 10 kB, view raw
import os
import re
import signal
import tempfile
import threading
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, ContextManager, Dict, Iterator, List, Optional, Union

from colorama import Fore, Style

from test_driver.logger import AbstractLogger
from test_driver.machine import Machine, NixStartScript, retry
from test_driver.polling_condition import PollingCondition
from test_driver.vlan import VLan

# Marker used to tell "key absent" apart from any real value in the legacy
# dictionary form of `Driver.create_machine`.
SENTINEL = object()


def get_tmp_dir() -> Path:
    """Return the temporary directory defined by TMPDIR, TEMP, TMP or CWD.

    The directory is created (mode 0o700) if it does not exist yet.

    Raises:
        NotADirectoryError: the resolved path is not a directory.
        PermissionError: the resolved directory is not writeable.

    See https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir
    """
    tmp_dir = Path(tempfile.gettempdir())
    tmp_dir.mkdir(mode=0o700, exist_ok=True)
    if not tmp_dir.is_dir():
        raise NotADirectoryError(
            f"The directory defined by TMPDIR, TEMP, TMP or CWD: {tmp_dir} is not a directory"
        )
    if not os.access(tmp_dir, os.W_OK):
        raise PermissionError(
            f"The directory defined by TMPDIR, TEMP, TMP, or CWD: {tmp_dir} is not writeable"
        )
    return tmp_dir


def pythonize_name(name: str) -> str:
    """Turn *name* into a string usable as a Python identifier.

    A leading character that is not an ASCII letter or underscore, and any
    character that is not an ASCII letter, digit or underscore, is replaced
    by ``_``.
    """
    # NOTE: the ranges are spelled `A-Za-z` on purpose. `A-z` would also
    # cover the punctuation between "Z" and "a" in ASCII ([ \ ] ^ _ `),
    # letting characters that are invalid in identifiers slip through.
    return re.sub(r"^[^A-Za-z_]|[^A-Za-z0-9_]", "_", name)


class Driver:
    """A handle to the driver that sets up the environment
    and runs the tests"""

    tests: str
    vlans: List[VLan]
    machines: List[Machine]
    polling_conditions: List[PollingCondition]
    global_timeout: int
    race_timer: threading.Timer
    logger: AbstractLogger

    def __init__(
        self,
        start_scripts: List[str],
        vlans: List[int],
        tests: str,
        out_dir: Path,
        logger: AbstractLogger,
        keep_vm_state: bool = False,
        global_timeout: int = 24 * 60 * 60 * 7,
    ):
        """Create the VLANs and one machine per start script.

        Args:
            start_scripts: start scripts, each wrapped in a `NixStartScript`.
            vlans: VLAN numbers; duplicates are collapsed.
            tests: source code of the test script, run by `test_script`.
            out_dir: output directory handed to every machine.
            logger: logger used by the driver, the VLANs and the machines.
            keep_vm_state: forwarded to every machine.
            global_timeout: seconds after which `terminate_test` fires once
                the timer has been started by `run_tests`.
        """
        self.tests = tests
        self.out_dir = out_dir
        self.global_timeout = global_timeout
        # The timer is only armed here; `run_tests` starts it.
        self.race_timer = threading.Timer(global_timeout, self.terminate_test)
        self.logger = logger

        tmp_dir = get_tmp_dir()

        with self.logger.nested("start all VLans"):
            # Deduplicate so that each VLAN number is set up only once.
            self.vlans = [VLan(nr, tmp_dir, self.logger) for nr in set(vlans)]

        # Must exist before the machines, which receive
        # `check_polling_conditions` as a callback.
        self.polling_conditions = []

        self.machines = [
            Machine(
                start_command=script,
                keep_vm_state=keep_vm_state,
                name=script.machine_name,
                tmp_dir=tmp_dir,
                callbacks=[self.check_polling_conditions],
                out_dir=self.out_dir,
                logger=self.logger,
            )
            for script in map(NixStartScript, start_scripts)
        ]

    def __enter__(self) -> "Driver":
        return self

    def __exit__(self, *_: Any) -> None:
        """Cancel the global timeout timer and release every machine."""
        with self.logger.nested("cleanup"):
            self.race_timer.cancel()
            for machine in self.machines:
                machine.release()

    def subtest(self, name: str) -> Iterator[None]:
        """Group logs under a given test name

        Generator meant to be driven by `contextlib.contextmanager` (see
        `test_symbols`). An exception raised inside the subtest body is
        logged as an error and then re-raised unchanged.
        """
        with self.logger.subtest(name):
            try:
                yield
            except Exception as e:
                self.logger.error(f'Test "{name}" failed with error: "{e}"')
                raise

    def test_symbols(self) -> Dict[str, Any]:
        """Build the global namespace in which the test script is executed.

        The namespace contains the driver helpers, every machine under its
        pythonized name, and every VLAN as ``vlan<nr>``. The exposed names
        are also printed to stdout.
        """

        @contextmanager
        def subtest(name: str) -> Iterator[None]:
            # Delegate to the method's generator; `yield from` also forwards
            # exceptions thrown in by the context manager machinery.
            yield from self.subtest(name)

        general_symbols = dict(
            start_all=self.start_all,
            test_script=self.test_script,
            machines=self.machines,
            vlans=self.vlans,
            driver=self,
            log=self.logger,
            os=os,
            create_machine=self.create_machine,
            subtest=subtest,
            run_tests=self.run_tests,
            join_all=self.join_all,
            retry=retry,
            serial_stdout_off=self.serial_stdout_off,
            serial_stdout_on=self.serial_stdout_on,
            polling_condition=self.polling_condition,
            Machine=Machine,  # for typing
        )
        machine_symbols = {pythonize_name(m.name): m for m in self.machines}
        # If there's exactly one machine, make it available under the name
        # "machine", even if it's not called that.
        if len(self.machines) == 1:
            (machine_symbols["machine"],) = self.machines
        vlan_symbols = {f"vlan{v.nr}": v for v in self.vlans}
        print(
            "additionally exposed symbols:\n "
            + ", ".join(m.name for m in self.machines)
            + ",\n "
            + ", ".join(f"vlan{v.nr}" for v in self.vlans)
            + ",\n "
            + ", ".join(general_symbols)
        )
        # Later dicts win on key clashes: machines over helpers, VLANs over both.
        return {**general_symbols, **machine_symbols, **vlan_symbols}

    def test_script(self) -> None:
        """Run the test script"""
        with self.logger.nested("run the VM test script"):
            symbols = self.test_symbols()  # call eagerly
            # The script is the test author's own code, supplied to the
            # driver at construction time.
            exec(self.tests, symbols, None)

    def run_tests(self) -> None:
        """Run the test script (for non-interactive test runs)"""
        self.logger.info(
            f"Test will time out and terminate in {self.global_timeout} seconds"
        )
        self.race_timer.start()
        self.test_script()
        # TODO: Collect coverage data
        for machine in self.machines:
            if machine.is_up():
                machine.execute("sync")

    def start_all(self) -> None:
        """Start all machines"""
        with self.logger.nested("start all VMs"):
            for machine in self.machines:
                machine.start()

    def join_all(self) -> None:
        """Wait for all machines to shut down"""
        with self.logger.nested("wait for all VMs to finish"):
            for machine in self.machines:
                machine.wait_for_shutdown()
            self.race_timer.cancel()

    def terminate_test(self) -> None:
        """Timer callback: release all machines and SIGTERM this process."""
        # This will be usually running in another thread than
        # the thread actually executing the test script.
        with self.logger.nested("timeout reached; test terminating..."):
            for machine in self.machines:
                machine.release()
            # As we cannot `sys.exit` from another thread
            # We can at least force the main thread to get SIGTERM'ed.
            # This will prevent any user who caught all the exceptions
            # to swallow them and prevent itself from terminating.
            os.kill(os.getpid(), signal.SIGTERM)

    def create_machine(
        self,
        start_command: str | dict,
        *,
        name: Optional[str] = None,
        keep_vm_state: bool = False,
    ) -> Machine:
        """Create an additional machine from a start command.

        Args:
            start_command: the start command string. A dictionary with the
                keys ``startCommand`` (required), ``name`` and
                ``keep_vm_state`` is still accepted but deprecated.
            name: machine name; defaults to the name derived from the start
                script.
            keep_vm_state: forwarded to the machine.

        Raises:
            TypeError: the legacy dictionary is combined with other
                arguments, lacks ``startCommand``, has a non-string
                ``startCommand``, or contains unsupported keys.
        """
        # Legacy args handling
        # FIXME: remove after 24.05
        if isinstance(start_command, dict):
            if name is not None or keep_vm_state:
                raise TypeError(
                    "Dictionary passed to create_machine must be the only argument"
                )

            # Work on a shallow copy so the caller's dictionary is not
            # emptied by the pops below.
            args = dict(start_command)
            start_command = args.pop("startCommand", SENTINEL)

            if start_command is SENTINEL:
                raise TypeError(
                    "Dictionary passed to create_machine must contain startCommand"
                )

            if not isinstance(start_command, str):
                raise TypeError(
                    f"startCommand must be a string, got: {repr(start_command)}"
                )

            name = args.pop("name", None)
            keep_vm_state = args.pop("keep_vm_state", False)

            # Anything left over is a key this function does not understand.
            if args:
                raise TypeError(
                    f"Unsupported arguments passed to create_machine: {args}"
                )

            self.logger.warning(
                Fore.YELLOW
                + Style.BRIGHT
                + "WARNING: Using create_machine with a single dictionary argument is deprecated and will be removed in NixOS 24.11"
                + Style.RESET_ALL
            )
        # End legacy args handling

        tmp_dir = get_tmp_dir()

        cmd = NixStartScript(start_command)
        name = name or cmd.machine_name

        return Machine(
            tmp_dir=tmp_dir,
            out_dir=self.out_dir,
            start_command=cmd,
            name=name,
            keep_vm_state=keep_vm_state,
            logger=self.logger,
        )

    def serial_stdout_on(self) -> None:
        """Ask the logger to print serial logs."""
        self.logger.print_serial_logs(True)

    def serial_stdout_off(self) -> None:
        """Ask the logger to stop printing serial logs."""
        self.logger.print_serial_logs(False)

    def check_polling_conditions(self) -> None:
        """Call `maybe_raise` on every currently registered polling condition."""
        for condition in self.polling_conditions:
            condition.maybe_raise()

    def polling_condition(
        self,
        fun_: Optional[Callable] = None,
        *,
        seconds_interval: float = 2.0,
        description: Optional[str] = None,
    ) -> Union[Callable[[Callable], ContextManager], ContextManager]:
        """Wrap a callable in a polling condition context manager.

        Usable both directly (``polling_condition(fun)``) and with keyword
        options as a decorator factory
        (``@polling_condition(seconds_interval=...)``): when *fun_* is
        omitted, the wrapper class itself is returned so that it can be
        applied to the callable afterwards.

        Args:
            fun_: the callable to check.
            seconds_interval: forwarded to `PollingCondition`.
            description: forwarded to `PollingCondition`.
        """
        driver = self

        class Poll:
            def __init__(self, fun: Callable):
                self.condition = PollingCondition(
                    fun,
                    driver.logger,
                    seconds_interval,
                    description,
                )

            def __enter__(self) -> None:
                driver.polling_conditions.append(self.condition)

            def __exit__(self, a, b, c) -> None:  # type: ignore
                # Conditions are expected to be entered and left in strict
                # LIFO order, so the last registered one must be ours.
                res = driver.polling_conditions.pop()
                assert res is self.condition

            def wait(self, timeout: int = 900) -> None:
                """Retry the condition until it holds or *timeout* is hit."""

                def condition(last: bool) -> bool:
                    if last:
                        driver.logger.info(
                            f"Last chance for {self.condition.description}"
                        )
                    ret = self.condition.check(force=True)
                    if not ret and not last:
                        driver.logger.info(
                            f"({self.condition.description} failure not fatal yet)"
                        )
                    return ret

                with driver.logger.nested(f"waiting for {self.condition.description}"):
                    retry(condition, timeout=timeout)

        if fun_ is None:
            return Poll
        else:
            return Poll(fun_)