# (source-viewer residue) at 23.11-pre, 8.0 kB, view raw
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Iterator, List, Union, Optional, Callable, ContextManager
import os
import re
import tempfile

from test_driver.logger import rootlog
from test_driver.machine import Machine, NixStartScript, retry
from test_driver.vlan import VLan
from test_driver.polling_condition import PollingCondition


def get_tmp_dir() -> Path:
    """Return the temporary directory defined by TMPDIR, TEMP, TMP or CWD.

    The directory is created with mode 0o700 if it does not exist yet.
    See https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir

    Raises:
        NotADirectoryError: the selected path exists but is not a directory.
        PermissionError: the selected directory is not writeable.
    """
    tmp_dir = Path(tempfile.gettempdir())
    not_a_dir_msg = (
        f"The directory defined by TMPDIR, TEMP, TMP or CWD: {tmp_dir} is not a directory"
    )
    try:
        tmp_dir.mkdir(mode=0o700, exist_ok=True)
    except FileExistsError as err:
        # With exist_ok=True, mkdir() raises FileExistsError only when the
        # path exists and is NOT a directory. Report that as the documented
        # NotADirectoryError instead of leaking a confusing FileExistsError.
        raise NotADirectoryError(not_a_dir_msg) from err
    if not tmp_dir.is_dir():
        # Defensive: the path may have changed after mkdir() returned.
        raise NotADirectoryError(not_a_dir_msg)
    if not os.access(tmp_dir, os.W_OK):
        raise PermissionError(
            f"The directory defined by TMPDIR, TEMP, TMP, or CWD: {tmp_dir} is not writeable"
        )
    return tmp_dir


def pythonize_name(name: str) -> str:
    """Make *name* usable as a Python variable name.

    A first character that cannot start an ASCII identifier, and any other
    character that cannot appear in one, is replaced by ``_``.
    Keywords and the empty string are not handled.
    """
    # The letter ranges are spelled out as A-Za-z on purpose: the shorter
    # ``A-z`` also covers the ASCII punctuation between 'Z' and 'a'
    # ( [ \ ] ^ _ ` ), which would let e.g. "a[b]" through unchanged.
    return re.sub(r"^[^A-Za-z_]|[^A-Za-z0-9_]", "_", name)


class Driver:
    """A handle to the driver that sets up the environment (VLANs and
    machines) and runs the tests."""

    tests: str  # source code of the test script, executed by test_script()
    vlans: List[VLan]
    machines: List[Machine]
    # Currently active polling conditions; Poll.__enter__ appends and
    # Poll.__exit__ pops, so the innermost active condition is last.
    polling_conditions: List[PollingCondition]

    def __init__(
        self,
        start_scripts: List[str],
        vlans: List[int],
        tests: str,
        out_dir: Path,
        keep_vm_state: bool = False,
    ):
        """Create the VLANs and one Machine per start script.

        Args:
            start_scripts: start-script strings; each is wrapped in a
                NixStartScript, whose ``machine_name`` names the machine.
            vlans: VLAN numbers to create; duplicates are ignored.
            tests: source code of the test script.
            out_dir: output directory handed to every machine.
            keep_vm_state: forwarded unchanged to every Machine.
        """
        self.tests = tests
        self.out_dir = out_dir

        tmp_dir = get_tmp_dir()

        with rootlog.nested("start all VLans"):
            # Deduplicate; the resulting order is set-iteration order,
            # not input order.
            vlans = list(set(vlans))
            self.vlans = [VLan(nr, tmp_dir) for nr in vlans]

        self.polling_conditions = []

        # Every machine receives check_polling_conditions as a callback so it
        # can re-check the active polling conditions (when it calls it is up
        # to Machine, which is not shown here).
        self.machines = [
            Machine(
                start_command=start_command,
                keep_vm_state=keep_vm_state,
                name=start_command.machine_name,
                tmp_dir=tmp_dir,
                callbacks=[self.check_polling_conditions],
                out_dir=self.out_dir,
            )
            for start_command in map(NixStartScript, start_scripts)
        ]

    def __enter__(self) -> "Driver":
        return self

    def __exit__(self, *_: Any) -> None:
        """Release every machine; exceptions from the with-body propagate."""
        with rootlog.nested("cleanup"):
            for machine in self.machines:
                machine.release()

    def subtest(self, name: str) -> Iterator[None]:
        """Group logs under a given test name.

        This is a plain generator. test_symbols() wraps it with
        ``contextlib.contextmanager`` so test scripts can write
        ``with subtest("name"): ...``. An exception raised inside that
        ``with`` body is logged and then re-raised unchanged.
        """
        with rootlog.nested("subtest: " + name):
            try:
                yield
            except Exception as e:
                rootlog.error(f'Test "{name}" failed with error: "{e}"')
                raise

    def test_symbols(self) -> Dict[str, Any]:
        """Return the global namespace the test script is executed in.

        The namespace has three parts:
        - general helpers (start_all, subtest, retry, ...);
        - every machine under its pythonized name, plus ``machine`` when
          there is exactly one;
        - every VLAN as ``vlan<nr>``.

        The exposed names are also printed to stdout.
        """

        @contextmanager
        def subtest(name: str) -> Iterator[None]:
            # Delegate to the generator-based Driver.subtest; `yield from`
            # forwards both normal completion and thrown exceptions.
            yield from self.subtest(name)

        general_symbols = dict(
            start_all=self.start_all,
            test_script=self.test_script,
            machines=self.machines,
            vlans=self.vlans,
            driver=self,
            log=rootlog,
            os=os,
            create_machine=self.create_machine,
            subtest=subtest,
            run_tests=self.run_tests,
            join_all=self.join_all,
            retry=retry,
            serial_stdout_off=self.serial_stdout_off,
            serial_stdout_on=self.serial_stdout_on,
            polling_condition=self.polling_condition,
            Machine=Machine,  # for typing
        )
        machine_symbols = {pythonize_name(m.name): m for m in self.machines}
        # If there's exactly one machine, make it available under the name
        # "machine", even if it's not called that.
        if len(self.machines) == 1:
            machine_symbols["machine"] = self.machines[0]
        vlan_symbols = {f"vlan{v.nr}": v for v in self.vlans}

        machine_names = ", ".join(m.name for m in self.machines)
        vlan_names = ", ".join(f"vlan{v.nr}" for v in self.vlans)
        print(
            "additionally exposed symbols:\n "
            + machine_names
            + ",\n "
            + vlan_names
            + ",\n "
            + ", ".join(general_symbols)
        )
        # Later dicts win on key clashes, so a machine or VLAN name shadows a
        # general symbol of the same name.
        return {**general_symbols, **machine_symbols, **vlan_symbols}

    def test_script(self) -> None:
        """Run the test script"""
        with rootlog.nested("run the VM test script"):
            symbols = self.test_symbols()  # call eagerly
            # The test source comes from the driver's own configuration and
            # is executed with full privileges in the symbols namespace.
            exec(self.tests, symbols, None)

    def run_tests(self) -> None:
        """Run the test script (for non-interactive test runs)"""
        self.test_script()
        # TODO: Collect coverage data
        for machine in self.machines:
            if machine.is_up():
                machine.execute("sync")

    def start_all(self) -> None:
        """Start all machines"""
        with rootlog.nested("start all VMs"):
            for machine in self.machines:
                machine.start()

    def join_all(self) -> None:
        """Wait for all machines to shut down"""
        with rootlog.nested("wait for all VMs to finish"):
            for machine in self.machines:
                machine.wait_for_shutdown()

    def create_machine(self, args: Dict[str, Any]) -> Machine:
        """Legacy factory: build a Machine from a dict of options.

        Recognised keys:
        - ``startCommand``: when truthy, it is wrapped in a NixStartScript.
          Otherwise ``Machine.create_startcommand(args)`` builds the
          command.
        - ``name``: defaults to the script's machine_name, or to
          "machine" when no startCommand is given.
        - ``keep_vm_state``: defaults to False.
        """
        rootlog.warning(
            "Using legacy create_machine(), please instantiate the "
            "Machine class directly, instead"
        )

        tmp_dir = get_tmp_dir()

        start_command = args.get("startCommand")
        if start_command:
            cmd = NixStartScript(start_command)
            name = args.get("name", cmd.machine_name)
        else:
            cmd = Machine.create_startcommand(args)  # type: ignore
            name = args.get("name", "machine")

        # NOTE(review): unlike the machines built in __init__, no
        # callbacks=[self.check_polling_conditions] is passed here. Confirm
        # whether polling conditions should apply to these machines too.
        return Machine(
            tmp_dir=tmp_dir,
            out_dir=self.out_dir,
            start_command=cmd,
            name=name,
            keep_vm_state=args.get("keep_vm_state", False),
        )

    def serial_stdout_on(self) -> None:
        """Set the root logger's ``_print_serial_logs`` flag."""
        rootlog._print_serial_logs = True

    def serial_stdout_off(self) -> None:
        """Clear the root logger's ``_print_serial_logs`` flag."""
        rootlog._print_serial_logs = False

    def check_polling_conditions(self) -> None:
        """Call ``maybe_raise()`` on every currently active polling condition."""
        for condition in self.polling_conditions:
            condition.maybe_raise()

    def polling_condition(
        self,
        fun_: Optional[Callable] = None,
        *,
        seconds_interval: float = 2.0,
        description: Optional[str] = None,
    ) -> Union[Callable[[Callable], ContextManager], ContextManager]:
        """Wrap a check function in a PollingCondition.

        Usable as ``polling_condition(fun)`` / ``@polling_condition``, which
        return a Poll directly. It can also be used as a decorator factory,
        ``@polling_condition(seconds_interval=..., description=...)``, which
        returns the Poll class to be applied to the function.

        A Poll is a context manager. While its ``with`` block is active, the
        condition is registered in ``self.polling_conditions`` and so is
        checked by check_polling_conditions(). It also offers ``wait()``.
        """
        driver = self

        class Poll:
            def __init__(self, fun: Callable):
                self.condition = PollingCondition(
                    fun,
                    seconds_interval,
                    description,
                )

            def __enter__(self) -> None:
                driver.polling_conditions.append(self.condition)

            def __exit__(self, a, b, c) -> None:  # type: ignore
                # Conditions must be exited in LIFO order; the assert catches
                # interleaved (non-nested) use.
                res = driver.polling_conditions.pop()
                assert res is self.condition

            def wait(self, timeout: int = 900) -> None:
                """Block until the condition holds, using retry().

                Presumably retry() calls ``condition(last)`` until it returns
                True or ``timeout`` runs out, and passes last=True on the
                final attempt. TODO: confirm against test_driver.machine.retry.
                """

                def condition(last: bool) -> bool:
                    if last:
                        rootlog.info(f"Last chance for {self.condition.description}")
                    ret = self.condition.check(force=True)
                    if not ret and not last:
                        rootlog.info(
                            f"({self.condition.description} failure not fatal yet)"
                        )
                    return ret

                with rootlog.nested(f"waiting for {self.condition.description}"):
                    retry(condition, timeout=timeout)

        if fun_ is None:
            return Poll
        else:
            return Poll(fun_)