# at 23.11-beta 9.0 kB view raw
import os
import re
import signal
import tempfile
import threading
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, ContextManager, Dict, Iterator, List, Optional, Union

from test_driver.logger import rootlog
from test_driver.machine import Machine, NixStartScript, retry
from test_driver.polling_condition import PollingCondition
from test_driver.vlan import VLan


def get_tmp_dir() -> Path:
    """Returns a temporary directory that is defined by TMPDIR, TEMP, TMP or CWD
    Raises an exception in case the retrieved temporary directory is not writeable
    See https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir

    The directory is created (mode 0o700) when it does not exist yet.

    Raises:
        NotADirectoryError: the resolved path is not a directory.
        PermissionError: the resolved directory is not writeable.
    """
    tmp_dir = Path(tempfile.gettempdir())
    tmp_dir.mkdir(mode=0o700, exist_ok=True)
    if not tmp_dir.is_dir():
        raise NotADirectoryError(
            f"The directory defined by TMPDIR, TEMP, TMP or CWD: {tmp_dir} is not a directory"
        )
    if not os.access(tmp_dir, os.W_OK):
        raise PermissionError(
            f"The directory defined by TMPDIR, TEMP, TMP, or CWD: {tmp_dir} is not writeable"
        )
    return tmp_dir


def pythonize_name(name: str) -> str:
    """Replace characters of *name* that cannot appear in a Python identifier.

    A first character that is not an ASCII letter or underscore, and any
    character that is not an ASCII letter, digit or underscore, becomes "_".
    """
    # The ranges are spelled out as A-Za-z on purpose: the range "A-z" also
    # covers the ASCII punctuation between "Z" and "a" ("[", "\\", "]", "^"
    # and "`"), which would then slip through unreplaced.
    return re.sub(r"^[^A-Za-z_]|[^A-Za-z0-9_]", "_", name)


class Driver:
    """A handle to the driver that sets up the environment
    and runs the tests"""

    tests: str
    out_dir: Path
    vlans: List[VLan]
    machines: List[Machine]
    polling_conditions: List[PollingCondition]
    global_timeout: int
    race_timer: threading.Timer

    def __init__(
        self,
        start_scripts: List[str],
        vlans: List[int],
        tests: str,
        out_dir: Path,
        keep_vm_state: bool = False,
        global_timeout: int = 24 * 60 * 60 * 7,
    ):
        """Create the VLans and one Machine per start script.

        Args:
            start_scripts: start scripts, each wrapped in a NixStartScript.
            vlans: VLan numbers; duplicates are dropped.
            tests: source code of the test script run by `test_script`.
            out_dir: output directory handed to every Machine.
            keep_vm_state: forwarded to every Machine.
            global_timeout: seconds after `run_tests` starts until
                `terminate_test` fires.
        """
        self.tests = tests
        self.out_dir = out_dir
        self.global_timeout = global_timeout
        # The timer is only armed in `run_tests`; here it is merely created.
        self.race_timer = threading.Timer(global_timeout, self.terminate_test)

        tmp_dir = get_tmp_dir()

        with rootlog.nested("start all VLans"):
            vlans = list(set(vlans))
            self.vlans = [VLan(nr, tmp_dir) for nr in vlans]

        def start_commands(scripts: List[str]) -> Iterator[NixStartScript]:
            for script in scripts:
                yield NixStartScript(script)

        self.polling_conditions = []

        self.machines = [
            Machine(
                start_command=start_command,
                keep_vm_state=keep_vm_state,
                name=start_command.machine_name,
                tmp_dir=tmp_dir,
                callbacks=[self.check_polling_conditions],
                out_dir=self.out_dir,
            )
            for start_command in start_commands(start_scripts)
        ]

    def __enter__(self) -> "Driver":
        return self

    def __exit__(self, *_: Any) -> None:
        """Cancel the timeout timer and release every machine."""
        with rootlog.nested("cleanup"):
            self.race_timer.cancel()
            for machine in self.machines:
                machine.release()

    def subtest(self, name: str) -> Iterator[None]:
        """Group logs under a given test name

        This is a plain generator; `test_symbols` wraps it with
        `contextmanager` before exposing it to the test script.  An exception
        raised inside the subtest is logged and then re-raised.
        """
        with rootlog.nested("subtest: " + name):
            try:
                yield
            except Exception as e:
                rootlog.error(f'Test "{name}" failed with error: "{e}"')
                raise

    def test_symbols(self) -> Dict[str, Any]:
        """Build the globals dictionary the test script is executed with.

        It contains the general helpers, every machine under its pythonized
        name (plus "machine" when there is exactly one) and every VLan as
        "vlan<nr>".  The exposed names are also printed to stdout.
        """

        @contextmanager
        def subtest(name: str) -> Iterator[None]:
            return self.subtest(name)

        general_symbols = dict(
            start_all=self.start_all,
            test_script=self.test_script,
            machines=self.machines,
            vlans=self.vlans,
            driver=self,
            log=rootlog,
            os=os,
            create_machine=self.create_machine,
            subtest=subtest,
            run_tests=self.run_tests,
            join_all=self.join_all,
            retry=retry,
            serial_stdout_off=self.serial_stdout_off,
            serial_stdout_on=self.serial_stdout_on,
            polling_condition=self.polling_condition,
            Machine=Machine,  # for typing
        )
        machine_symbols = {pythonize_name(m.name): m for m in self.machines}
        # If there's exactly one machine, make it available under the name
        # "machine", even if it's not called that.
        if len(self.machines) == 1:
            (machine_symbols["machine"],) = self.machines
        vlan_symbols = {f"vlan{v.nr}": v for v in self.vlans}
        print(
            "additionally exposed symbols:\n    "
            + ", ".join(map(lambda m: m.name, self.machines))
            + ",\n    "
            + ", ".join(map(lambda v: f"vlan{v.nr}", self.vlans))
            + ",\n    "
            + ", ".join(list(general_symbols.keys()))
        )
        return {**general_symbols, **machine_symbols, **vlan_symbols}

    def test_script(self) -> None:
        """Run the test script"""
        with rootlog.nested("run the VM test script"):
            symbols = self.test_symbols()  # call eagerly
            # NOTE: this executes `self.tests` as arbitrary Python code.
            exec(self.tests, symbols, None)

    def run_tests(self) -> None:
        """Run the test script (for non-interactive test runs)"""
        rootlog.info(
            f"Test will time out and terminate in {self.global_timeout} seconds"
        )
        self.race_timer.start()
        self.test_script()
        # TODO: Collect coverage data
        for machine in self.machines:
            if machine.is_up():
                machine.execute("sync")

    def start_all(self) -> None:
        """Start all machines"""
        with rootlog.nested("start all VMs"):
            for machine in self.machines:
                machine.start()

    def join_all(self) -> None:
        """Wait for all machines to shut down"""
        with rootlog.nested("wait for all VMs to finish"):
            for machine in self.machines:
                machine.wait_for_shutdown()
        self.race_timer.cancel()

    def terminate_test(self) -> None:
        """Release every machine and send SIGTERM to this process.

        Callback of `race_timer`, i.e. invoked once the global timeout
        elapses.
        """
        # This will be usually running in another thread than
        # the thread actually executing the test script.
        with rootlog.nested("timeout reached; test terminating..."):
            for machine in self.machines:
                machine.release()
            # As we cannot `sys.exit` from another thread
            # We can at least force the main thread to get SIGTERM'ed.
            # This will prevent any user who caught all the exceptions
            # to swallow them and prevent itself from terminating.
            os.kill(os.getpid(), signal.SIGTERM)

    def create_machine(self, args: Dict[str, Any]) -> Machine:
        """Create an additional Machine from a dictionary of options.

        With a non-empty "startCommand" the command is wrapped in a
        NixStartScript and the name defaults to the script's machine name;
        otherwise the start command is built by `Machine.create_startcommand`
        and the name defaults to "machine".  "name" and "keep_vm_state" are
        optional keys.
        """
        tmp_dir = get_tmp_dir()

        if args.get("startCommand"):
            start_command: str = args.get("startCommand", "")
            cmd = NixStartScript(start_command)
            name = args.get("name", cmd.machine_name)
        else:
            cmd = Machine.create_startcommand(args)  # type: ignore
            name = args.get("name", "machine")

        return Machine(
            tmp_dir=tmp_dir,
            out_dir=self.out_dir,
            start_command=cmd,
            name=name,
            keep_vm_state=args.get("keep_vm_state", False),
        )

    def serial_stdout_on(self) -> None:
        """Set the root logger's `_print_serial_logs` flag to True."""
        rootlog._print_serial_logs = True

    def serial_stdout_off(self) -> None:
        """Set the root logger's `_print_serial_logs` flag to False."""
        rootlog._print_serial_logs = False

    def check_polling_conditions(self) -> None:
        """Call `maybe_raise` on every currently registered polling condition."""
        for condition in self.polling_conditions:
            condition.maybe_raise()

    def polling_condition(
        self,
        fun_: Optional[Callable] = None,
        *,
        seconds_interval: float = 2.0,
        description: Optional[str] = None,
    ) -> Union[Callable[[Callable], ContextManager], ContextManager]:
        """Create a polling condition usable as a context manager.

        Works both as a bare decorator (`fun_` given: returns a `Poll`
        instance) and as a decorator factory (`fun_` omitted: returns the
        `Poll` class, to be called with the function).

        Args:
            fun_: the function to poll, or None when used with arguments.
            seconds_interval: forwarded to PollingCondition.
            description: forwarded to PollingCondition.
        """
        driver = self

        class Poll:
            def __init__(self, fun: Callable):
                self.condition = PollingCondition(
                    fun,
                    seconds_interval,
                    description,
                )

            def __enter__(self) -> None:
                driver.polling_conditions.append(self.condition)

            def __exit__(self, a, b, c) -> None:  # type: ignore
                # Conditions are expected to be exited in reverse order of entry.
                res = driver.polling_conditions.pop()
                assert res is self.condition

            def wait(self, timeout: int = 900) -> None:
                def condition(last: bool) -> bool:
                    if last:
                        rootlog.info(f"Last chance for {self.condition.description}")
                    ret = self.condition.check(force=True)
                    if not ret and not last:
                        rootlog.info(
                            f"({self.condition.description} failure not fatal yet)"
                        )
                    return ret

                with rootlog.nested(f"waiting for {self.condition.description}"):
                    retry(condition, timeout=timeout)

        if fun_ is None:
            return Poll
        else:
            return Poll(fun_)