1from contextlib import contextmanager
2from pathlib import Path
3from typing import Any, Dict, Iterator, List, Union, Optional, Callable, ContextManager
4import os
5import re
6import tempfile
7
8from test_driver.logger import rootlog
9from test_driver.machine import Machine, NixStartScript, retry
10from test_driver.vlan import VLan
11from test_driver.polling_condition import PollingCondition
12
13
def get_tmp_dir() -> Path:
    """Return the temporary directory chosen by :func:`tempfile.gettempdir`.

    ``tempfile.gettempdir`` consults ``TMPDIR``, ``TEMP`` and ``TMP`` (plus
    platform defaults) and falls back to the current working directory; see
    https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir.
    The directory is created with mode ``0o700`` if it does not exist yet.

    Raises:
        NotADirectoryError: the selected path is not a directory.
        PermissionError: the selected directory is not writeable.

    NOTE: ``Path.mkdir(exist_ok=True)`` itself raises ``FileExistsError`` when
    the path exists but is not a directory, so the NotADirectoryError branch
    acts as a fallback check.
    """
    candidate = Path(tempfile.gettempdir())
    candidate.mkdir(mode=0o700, exist_ok=True)

    is_directory = candidate.is_dir()
    if not is_directory:
        raise NotADirectoryError(
            f"The directory defined by TMPDIR, TEMP, TMP or CWD: {candidate} is not a directory"
        )

    is_writeable = os.access(candidate, os.W_OK)
    if not is_writeable:
        raise PermissionError(
            f"The directory defined by TMPDIR, TEMP, TMP, or CWD: {candidate} is not writeable"
        )

    return candidate
30
31
def pythonize_name(name: str) -> str:
    """Turn *name* into a string usable as a Python identifier.

    The first character is replaced by ``_`` unless it is an ASCII letter or
    an underscore; any other character is replaced by ``_`` unless it is an
    ASCII letter, digit or underscore.

    >>> pythonize_name("my-machine")
    'my_machine'
    >>> pythonize_name("1st")
    '_st'
    """
    # Spell out ``A-Za-z``: the range ``A-z`` would also admit the ASCII
    # characters between ``Z`` and ``a`` ([, \, ], ^ and the backtick),
    # letting them leak into the generated identifier.
    return re.sub(r"^[^A-Za-z_]|[^A-Za-z0-9_]", "_", name)
34
35
class Driver:
    """A handle to the driver that sets up the environment
    and runs the tests"""

    tests: str
    vlans: List[VLan]
    machines: List[Machine]
    polling_conditions: List[PollingCondition]

    def __init__(
        self,
        start_scripts: List[str],
        vlans: List[int],
        tests: str,
        out_dir: Path,
        keep_vm_state: bool = False,
    ):
        """Create the VLANs and machines for a test run.

        Args:
            start_scripts: one entry per machine; each is wrapped in a
                NixStartScript, whose machine_name becomes the machine's name.
            vlans: VLAN numbers to create; duplicates are ignored.
            tests: Python source of the test script run by test_script().
            out_dir: passed to every Machine as its out_dir.
            keep_vm_state: forwarded to every Machine.
        """
        self.tests = tests
        self.out_dir = out_dir

        tmp_dir = get_tmp_dir()

        with rootlog.nested("start all VLans"):
            # De-duplicate, and sort so VLANs are created in a stable order.
            self.vlans = [VLan(nr, tmp_dir) for nr in sorted(set(vlans))]

        # Shared by all machines through the check_polling_conditions callback.
        self.polling_conditions = []

        self.machines = [
            Machine(
                start_command=start_script,
                keep_vm_state=keep_vm_state,
                name=start_script.machine_name,
                tmp_dir=tmp_dir,
                callbacks=[self.check_polling_conditions],
                out_dir=self.out_dir,
            )
            for start_script in map(NixStartScript, start_scripts)
        ]

    def __enter__(self) -> "Driver":
        return self

    def __exit__(self, *_: Any) -> None:
        """Call release() on every machine in self.machines."""
        with rootlog.nested("cleanup"):
            for machine in self.machines:
                machine.release()

    @contextmanager
    def subtest(self, name: str) -> Iterator[None]:
        """Group logs under a given test name.

        Usable as ``with driver.subtest("name"):``. An exception raised in the
        ``with`` body is logged as a failure of this subtest and re-raised.
        """
        with rootlog.nested("subtest: " + name):
            try:
                yield
            except Exception as e:
                rootlog.error(f'Test "{name}" failed with error: "{e}"')
                raise

    def test_symbols(self) -> Dict[str, Any]:
        """Build the global namespace in which the test script is executed.

        Besides the general helpers below, every machine is exposed under its
        pythonized name, a lone machine is additionally exposed as
        ``machine``, and every VLAN as ``vlan<nr>``.
        """
        general_symbols = dict(
            start_all=self.start_all,
            test_script=self.test_script,
            machines=self.machines,
            vlans=self.vlans,
            driver=self,
            log=rootlog,
            os=os,
            create_machine=self.create_machine,
            subtest=self.subtest,
            run_tests=self.run_tests,
            join_all=self.join_all,
            retry=retry,
            serial_stdout_off=self.serial_stdout_off,
            serial_stdout_on=self.serial_stdout_on,
            polling_condition=self.polling_condition,
            Machine=Machine,  # for typing
        )
        machine_symbols = {pythonize_name(m.name): m for m in self.machines}
        # If there's exactly one machine, make it available under the name
        # "machine", even if it's not called that.
        if len(self.machines) == 1:
            (machine_symbols["machine"],) = self.machines
        vlan_symbols = {f"vlan{v.nr}": v for v in self.vlans}
        # Print the names the script can actually use (the pythonized machine
        # names, including the "machine" alias), not the raw machine names.
        print(
            "additionally exposed symbols:\n "
            + ", ".join(machine_symbols)
            + ",\n "
            + ", ".join(vlan_symbols)
            + ",\n "
            + ", ".join(general_symbols)
        )
        return {**general_symbols, **machine_symbols, **vlan_symbols}

    def test_script(self) -> None:
        """Run the test script"""
        with rootlog.nested("run the VM test script"):
            symbols = self.test_symbols()  # call eagerly
            # self.tests is the test author's own script, executed as-is.
            exec(self.tests, symbols, None)

    def run_tests(self) -> None:
        """Run the test script (for non-interactive test runs)"""
        self.test_script()
        # TODO: Collect coverage data
        for machine in self.machines:
            if machine.is_up():
                machine.execute("sync")

    def start_all(self) -> None:
        """Start all machines"""
        with rootlog.nested("start all VMs"):
            for machine in self.machines:
                machine.start()

    def join_all(self) -> None:
        """Wait for all machines to shut down"""
        with rootlog.nested("wait for all VMs to finish"):
            for machine in self.machines:
                machine.wait_for_shutdown()

    def create_machine(self, args: Dict[str, Any]) -> Machine:
        """Legacy helper that builds a Machine from a dict of arguments.

        If ``args["startCommand"]`` is set, it is wrapped in a NixStartScript
        and the name defaults to that script's machine_name. Otherwise the
        start command comes from Machine.create_startcommand(args) and the
        name defaults to "machine". ``args["keep_vm_state"]`` defaults to
        False.

        The returned machine is not added to self.machines, so start_all(),
        join_all() and __exit__ do not act on it.
        """
        rootlog.warning(
            "Using legacy create_machine(), please instantiate the "
            "Machine class directly, instead"
        )

        tmp_dir = get_tmp_dir()

        start_command: str = args.get("startCommand", "")
        if start_command:
            cmd = NixStartScript(start_command)
            name = args.get("name", cmd.machine_name)
        else:
            cmd = Machine.create_startcommand(args)  # type: ignore
            name = args.get("name", "machine")

        return Machine(
            tmp_dir=tmp_dir,
            out_dir=self.out_dir,
            start_command=cmd,
            name=name,
            keep_vm_state=args.get("keep_vm_state", False),
            # Same callback as machines built in __init__, so active polling
            # conditions are checked for this machine too.
            callbacks=[self.check_polling_conditions],
        )

    def serial_stdout_on(self) -> None:
        """Enable printing of serial console logs via rootlog."""
        rootlog._print_serial_logs = True

    def serial_stdout_off(self) -> None:
        """Disable printing of serial console logs via rootlog."""
        rootlog._print_serial_logs = False

    def check_polling_conditions(self) -> None:
        """Call maybe_raise() on every currently registered polling condition.

        NOTE(review): maybe_raise() presumably raises when its condition has
        failed; confirm in PollingCondition.
        """
        for condition in self.polling_conditions:
            condition.maybe_raise()

    def polling_condition(
        self,
        fun_: Optional[Callable] = None,
        *,
        seconds_interval: float = 2.0,
        description: Optional[str] = None,
    ) -> Union[Callable[[Callable], ContextManager], ContextManager]:
        """Wrap a check function in a context manager that registers it.

        Called with a function (or used as a bare decorator), this returns the
        context manager directly. Called with keyword arguments only, it
        returns a factory (usable as a decorator) that builds the context
        manager from a function. While the ``with`` block is active, the
        condition sits in self.polling_conditions, where
        check_polling_conditions() sees it.
        """
        driver = self

        class Poll:
            def __init__(self, fun: Callable):
                self.condition = PollingCondition(
                    fun,
                    seconds_interval,
                    description,
                )

            def __enter__(self) -> None:
                driver.polling_conditions.append(self.condition)

            def __exit__(self, a, b, c) -> None:  # type: ignore
                # Conditions are registered and removed in LIFO order.
                res = driver.polling_conditions.pop()
                assert res is self.condition

            def wait(self, timeout: int = 900) -> None:
                """Retry the (forced) condition check via retry() until it passes.

                NOTE(review): assumes retry() calls condition(last=True) on its
                final attempt and gives up after ``timeout``; confirm in retry.
                """

                def condition(last: bool) -> bool:
                    if last:
                        rootlog.info(f"Last chance for {self.condition.description}")
                    ret = self.condition.check(force=True)
                    if not ret and not last:
                        rootlog.info(
                            f"({self.condition.description} failure not fatal yet)"
                        )
                    return ret

                with rootlog.nested(f"waiting for {self.condition.description}"):
                    retry(condition, timeout=timeout)

        if fun_ is None:
            return Poll
        else:
            return Poll(fun_)