at 23.05-pre 7.4 kB view raw
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Iterator, List, Union, Optional, Callable, ContextManager
import os
import tempfile

from test_driver.logger import rootlog
from test_driver.machine import Machine, NixStartScript, retry
from test_driver.vlan import VLan
from test_driver.polling_condition import PollingCondition


def get_tmp_dir() -> Path:
    """Return the temporary directory defined by TMPDIR, TEMP, TMP or CWD.

    The location is resolved with ``tempfile.gettempdir()`` and created with
    mode 0o700 if it does not exist yet.
    See https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir

    Raises:
        NotADirectoryError: the resolved path exists but is not a directory.
        PermissionError: the resolved directory is not writeable.
    """
    tmp_dir = Path(tempfile.gettempdir())
    try:
        tmp_dir.mkdir(mode=0o700, exist_ok=True)
    except FileExistsError:
        # exist_ok=True only tolerates an existing *directory*; an existing
        # non-directory is reported by the explicit check below instead.
        pass
    if not tmp_dir.is_dir():
        raise NotADirectoryError(
            "The directory defined by TMPDIR, TEMP, TMP or CWD: {0} is not a directory".format(
                tmp_dir
            )
        )
    if not os.access(tmp_dir, os.W_OK):
        raise PermissionError(
            "The directory defined by TMPDIR, TEMP, TMP, or CWD: {0} is not writeable".format(
                tmp_dir
            )
        )
    return tmp_dir


class Driver:
    """A handle to the driver that sets up the environment
    and runs the tests"""

    # Source text of the test script, executed by test_script().
    tests: str
    # One VLan per distinct VLAN number passed to __init__.
    vlans: List[VLan]
    # One Machine per start script passed to __init__.
    machines: List[Machine]
    # Stack of active polling conditions; see polling_condition().
    polling_conditions: List[PollingCondition]

    def __init__(
        self,
        start_scripts: List[str],
        vlans: List[int],
        tests: str,
        out_dir: Path,
        keep_vm_state: bool = False,
    ):
        self.tests = tests
        self.out_dir = out_dir

        tmp_dir = get_tmp_dir()

        with rootlog.nested("start all VLans"):
            # Deduplicate so each VLAN number is set up only once.
            vlans = list(set(vlans))
            self.vlans = [VLan(nr, tmp_dir) for nr in vlans]

        self.polling_conditions = []

        # Generator keeps script construction interleaved with Machine
        # construction, as before.
        start_commands = (NixStartScript(s) for s in start_scripts)
        self.machines = [
            Machine(
                start_command=start_command,
                keep_vm_state=keep_vm_state,
                name=start_command.machine_name,
                tmp_dir=tmp_dir,
                callbacks=[self.check_polling_conditions],
                out_dir=self.out_dir,
            )
            for start_command in start_commands
        ]

    def __enter__(self) -> "Driver":
        return self

    def __exit__(self, *_: Any) -> None:
        """Release every machine when leaving the ``with`` block."""
        with rootlog.nested("cleanup"):
            for machine in self.machines:
                machine.release()

    def subtest(self, name: str) -> Iterator[None]:
        """Group logs under a given test name.

        This is a generator; test_symbols() wraps it with
        ``contextlib.contextmanager``. An exception raised in the subtest body
        is logged and then re-raised with its original traceback.
        """
        with rootlog.nested("subtest: " + name):
            try:
                yield
            except Exception as e:
                rootlog.error(f'Test "{name}" failed with error: "{e}"')
                raise

    def test_symbols(self) -> Dict[str, Any]:
        """Build the global namespace in which the test script is executed.

        Contains general helpers, one entry per machine (by name), one
        ``vlan<nr>`` entry per VLan, and ``machine`` as an alias when there
        is exactly one machine. Prints the exposed names as a side effect.
        """

        @contextmanager
        def subtest(name: str) -> Iterator[None]:
            return self.subtest(name)

        general_symbols = dict(
            start_all=self.start_all,
            test_script=self.test_script,
            machines=self.machines,
            vlans=self.vlans,
            driver=self,
            log=rootlog,
            os=os,
            create_machine=self.create_machine,
            subtest=subtest,
            run_tests=self.run_tests,
            join_all=self.join_all,
            retry=retry,
            serial_stdout_off=self.serial_stdout_off,
            serial_stdout_on=self.serial_stdout_on,
            polling_condition=self.polling_condition,
            Machine=Machine,  # for typing
        )
        machine_symbols = {m.name: m for m in self.machines}
        # If there's exactly one machine, make it available under the name
        # "machine", even if it's not called that.
        if len(self.machines) == 1:
            (machine_symbols["machine"],) = self.machines
        vlan_symbols = {f"vlan{v.nr}": v for v in self.vlans}
        print(
            "additionally exposed symbols:\n    "
            + ", ".join(m.name for m in self.machines)
            + ",\n    "
            + ", ".join(f"vlan{v.nr}" for v in self.vlans)
            + ",\n    "
            + ", ".join(general_symbols)
        )
        return {**general_symbols, **machine_symbols, **vlan_symbols}

    def test_script(self) -> None:
        """Run the test script"""
        with rootlog.nested("run the VM test script"):
            symbols = self.test_symbols()  # call eagerly
            # NOTE: self.tests is executed as Python code with full access to
            # the symbols above; it must come from a trusted test definition.
            exec(self.tests, symbols, None)

    def run_tests(self) -> None:
        """Run the test script (for non-interactive test runs)"""
        self.test_script()
        # TODO: Collect coverage data
        for machine in self.machines:
            if machine.is_up():
                machine.execute("sync")

    def start_all(self) -> None:
        """Start all machines"""
        with rootlog.nested("start all VMs"):
            for machine in self.machines:
                machine.start()

    def join_all(self) -> None:
        """Wait for all machines to shut down"""
        with rootlog.nested("wait for all VMs to finish"):
            for machine in self.machines:
                machine.wait_for_shutdown()

    def create_machine(self, args: Dict[str, Any]) -> Machine:
        """Legacy factory: build a Machine from a dict of options.

        Recognized keys: ``startCommand``, ``name``, ``keep_vm_state``,
        ``allow_reboot``. Without ``startCommand`` the start command is built
        by ``Machine.create_startcommand(args)``.
        """
        rootlog.warning(
            "Using legacy create_machine(), please instantiate the "
            "Machine class directly, instead"
        )

        tmp_dir = get_tmp_dir()

        start_command: str = args.get("startCommand", "")
        if start_command:
            cmd = NixStartScript(start_command)
            name = args.get("name", cmd.machine_name)
        else:
            cmd = Machine.create_startcommand(args)  # type: ignore
            name = args.get("name", "machine")

        return Machine(
            tmp_dir=tmp_dir,
            out_dir=self.out_dir,
            start_command=cmd,
            name=name,
            keep_vm_state=args.get("keep_vm_state", False),
            allow_reboot=args.get("allow_reboot", False),
            # Same as machines created in __init__, so active polling
            # conditions are also checked for legacy-created machines.
            callbacks=[self.check_polling_conditions],
        )

    def serial_stdout_on(self) -> None:
        """Enable printing of the machines' serial console logs."""
        rootlog._print_serial_logs = True

    def serial_stdout_off(self) -> None:
        """Disable printing of the machines' serial console logs."""
        rootlog._print_serial_logs = False

    def check_polling_conditions(self) -> None:
        """Give every active polling condition a chance to raise."""
        for condition in self.polling_conditions:
            condition.maybe_raise()

    def polling_condition(
        self,
        fun_: Optional[Callable] = None,
        *,
        seconds_interval: float = 2.0,
        description: Optional[str] = None,
    ) -> Union[Callable[[Callable], ContextManager], ContextManager]:
        """Create a context manager that keeps a PollingCondition active.

        Usable both as ``polling_condition(fun)`` and as a decorator factory
        ``@polling_condition(seconds_interval=..., description=...)``. While the
        ``with`` block is active the condition is on ``self.polling_conditions``.
        """
        driver = self

        class Poll:
            def __init__(self, fun: Callable):
                self.condition = PollingCondition(
                    fun,
                    seconds_interval,
                    description,
                )

            def __enter__(self) -> None:
                driver.polling_conditions.append(self.condition)

            def __exit__(self, a, b, c) -> None:  # type: ignore
                # Conditions are used as a stack: the innermost one must be
                # the one being removed.
                res = driver.polling_conditions.pop()
                assert res is self.condition

        if fun_ is None:
            return Poll
        else:
            return Poll(fun_)