1import os
2import re
3import signal
4import tempfile
5import threading
6from contextlib import contextmanager
7from pathlib import Path
8from typing import Any, Callable, ContextManager, Dict, Iterator, List, Optional, Union
9
10from test_driver.logger import rootlog
11from test_driver.machine import Machine, NixStartScript, retry
12from test_driver.polling_condition import PollingCondition
13from test_driver.vlan import VLan
14
15
def get_tmp_dir() -> Path:
    """Return the temporary directory reported by ``tempfile.gettempdir()``.

    That location is derived from TMPDIR, TEMP, TMP or the current working
    directory, see
    https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir

    The directory is created with mode 0o700 if it does not exist yet.

    Raises:
        NotADirectoryError: the resolved path is not a directory.
        PermissionError: the resolved directory is not writable.
    """
    tmp_dir = Path(tempfile.gettempdir())
    tmp_dir.mkdir(mode=0o700, exist_ok=True)

    if tmp_dir.is_dir():
        if os.access(tmp_dir, os.W_OK):
            return tmp_dir
        raise PermissionError(
            f"The directory defined by TMPDIR, TEMP, TMP, or CWD: {tmp_dir} is not writeable"
        )

    raise NotADirectoryError(
        f"The directory defined by TMPDIR, TEMP, TMP or CWD: {tmp_dir} is not a directory"
    )
32
33
def pythonize_name(name: str) -> str:
    """Turn *name* into a string usable as a Python identifier.

    A leading character that is not an ASCII letter or underscore, and any
    character anywhere that is not an ASCII letter, digit or underscore, is
    replaced by ``_``. The length of the string is preserved.

    Note: the result is not checked against Python keywords.
    """
    # The ranges are spelled out as A-Za-z on purpose: the range "A-z" also
    # covers the ASCII characters between "Z" and "a" ("[", "\", "]", "^"
    # and "`"), which would then slip through unreplaced.
    return re.sub(r"^[^A-Za-z_]|[^A-Za-z0-9_]", "_", name)
36
37
class Driver:
    """A handle to the driver that sets up the environment
    and runs the tests.

    The driver owns the VLans and machines of a test run, exposes them (and
    a set of helper functions) as global symbols to the test script, and
    arms a watchdog timer that terminates the process once the global
    timeout has elapsed.

    It can be used as a context manager; leaving the context cancels the
    watchdog timer and calls ``release()`` on every machine.
    """

    # Source code of the test script, run through ``exec`` by test_script().
    tests: str
    # One VLan object per distinct VLAN number passed to the constructor.
    vlans: List[VLan]
    # One Machine per start script passed to the constructor.
    machines: List[Machine]
    # Stack of currently active polling conditions (see polling_condition()).
    polling_conditions: List[PollingCondition]
    # Watchdog delay in seconds.
    global_timeout: int
    # Watchdog timer; started by run_tests(), fires terminate_test().
    race_timer: threading.Timer
    # Output directory that is handed to every Machine.
    out_dir: Path

    def __init__(
        self,
        start_scripts: List[str],
        vlans: List[int],
        tests: str,
        out_dir: Path,
        keep_vm_state: bool = False,
        global_timeout: int = 24 * 60 * 60 * 7,
    ):
        """Create the VLans and machines of a test run.

        Args:
            start_scripts: start scripts, one per machine; each one is
                wrapped in a NixStartScript.
            vlans: VLAN numbers; duplicates are dropped.
            tests: source code of the test script.
            out_dir: output directory passed on to the machines.
            keep_vm_state: passed through to every Machine.
            global_timeout: seconds after which terminate_test() is fired
                once run_tests() has started the timer. Defaults to one week.
        """
        self.tests = tests
        self.out_dir = out_dir
        self.global_timeout = global_timeout
        # Only constructed here; run_tests() is what starts the countdown.
        self.race_timer = threading.Timer(global_timeout, self.terminate_test)

        tmp_dir = get_tmp_dir()

        with rootlog.nested("start all VLans"):
            vlans = list(set(vlans))
            self.vlans = [VLan(nr, tmp_dir) for nr in vlans]

        self.polling_conditions = []

        # map() is lazy, so each NixStartScript is built right before the
        # Machine that uses it.
        self.machines = [
            Machine(
                start_command=start_command,
                keep_vm_state=keep_vm_state,
                name=start_command.machine_name,
                tmp_dir=tmp_dir,
                callbacks=[self.check_polling_conditions],
                out_dir=self.out_dir,
            )
            for start_command in map(NixStartScript, start_scripts)
        ]

    def __enter__(self) -> "Driver":
        return self

    def __exit__(self, *_: Any) -> None:
        """Cancel the watchdog timer and release every machine."""
        with rootlog.nested("cleanup"):
            self.race_timer.cancel()
            for machine in self.machines:
                machine.release()

    def subtest(self, name: str) -> Iterator[None]:
        """Group logs under a given test name.

        This is a plain generator, not a context manager: it yields exactly
        once and has to be wrapped (e.g. with ``contextlib.contextmanager``,
        as test_symbols() does) before it can be used in a ``with``
        statement. An exception raised inside the wrapped block is logged
        and then re-raised unchanged.
        """
        with rootlog.nested("subtest: " + name):
            try:
                yield
            except Exception as e:
                rootlog.error(f'Test "{name}" failed with error: "{e}"')
                raise

    def test_symbols(self) -> Dict[str, Any]:
        """Build the global namespace that the test script is executed in.

        Besides the general helpers, every machine is exposed under its
        pythonized name and every VLan as ``vlan<nr>``. The list of exposed
        names is printed to stdout.
        """

        # contextmanager() only requires the decorated callable to return a
        # generator, so handing back the one created by self.subtest() is
        # enough to turn it into a context manager.
        @contextmanager
        def subtest(name: str) -> Iterator[None]:
            return self.subtest(name)

        general_symbols = dict(
            start_all=self.start_all,
            test_script=self.test_script,
            machines=self.machines,
            vlans=self.vlans,
            driver=self,
            log=rootlog,
            os=os,
            create_machine=self.create_machine,
            subtest=subtest,
            run_tests=self.run_tests,
            join_all=self.join_all,
            retry=retry,
            serial_stdout_off=self.serial_stdout_off,
            serial_stdout_on=self.serial_stdout_on,
            polling_condition=self.polling_condition,
            Machine=Machine,  # for typing
        )
        machine_symbols = {pythonize_name(m.name): m for m in self.machines}
        # If there's exactly one machine, make it available under the name
        # "machine", even if it's not called that.
        if len(self.machines) == 1:
            (machine_symbols["machine"],) = self.machines
        vlan_symbols = {f"vlan{v.nr}": v for v in self.vlans}
        # Print the names that are really exposed (pythonized machine names,
        # including the "machine" alias), not the raw machine names, which
        # may not be valid identifiers.
        print(
            "additionally exposed symbols:\n "
            + ", ".join(machine_symbols)
            + ",\n "
            + ", ".join(vlan_symbols)
            + ",\n "
            + ", ".join(general_symbols)
        )
        return {**general_symbols, **machine_symbols, **vlan_symbols}

    def test_script(self) -> None:
        """Run the test script"""
        with rootlog.nested("run the VM test script"):
            symbols = self.test_symbols()  # call eagerly
            # NOTE: this executes self.tests as arbitrary Python code, so the
            # test script has to come from a trusted source.
            exec(self.tests, symbols, None)

    def run_tests(self) -> None:
        """Run the test script (for non-interactive test runs)

        Starts the watchdog timer, runs the test script and afterwards
        executes ``sync`` on every machine that is still up.
        """
        rootlog.info(
            f"Test will time out and terminate in {self.global_timeout} seconds"
        )
        self.race_timer.start()
        self.test_script()
        # TODO: Collect coverage data
        for machine in self.machines:
            if machine.is_up():
                machine.execute("sync")

    def start_all(self) -> None:
        """Start all machines"""
        with rootlog.nested("start all VMs"):
            for machine in self.machines:
                machine.start()

    def join_all(self) -> None:
        """Wait for all machines to shut down, then cancel the watchdog timer"""
        with rootlog.nested("wait for all VMs to finish"):
            for machine in self.machines:
                machine.wait_for_shutdown()
            self.race_timer.cancel()

    def terminate_test(self) -> None:
        """Watchdog callback: release all machines and SIGTERM this process."""
        # This will be usually running in another thread than
        # the thread actually executing the test script.
        with rootlog.nested("timeout reached; test terminating..."):
            for machine in self.machines:
                machine.release()
            # As we cannot `sys.exit` from another thread
            # We can at least force the main thread to get SIGTERM'ed.
            # This will prevent any user who caught all the exceptions
            # to swallow them and prevent itself from terminating.
            os.kill(os.getpid(), signal.SIGTERM)

    def create_machine(self, args: Dict[str, Any]) -> Machine:
        """Create an additional Machine from a dictionary of arguments.

        Keys read here:
            startCommand: if present and truthy, wrapped in a
                NixStartScript; otherwise the start command is derived
                from ``args`` via ``Machine.create_startcommand``.
            name: machine name; defaults to the start script's machine name,
                or to "machine" when no start command was given.
            keep_vm_state: passed through to the Machine, default False.

        The new machine is returned to the caller only; it is not appended
        to ``self.machines``.
        """
        tmp_dir = get_tmp_dir()

        start_command = args.get("startCommand")
        if start_command:
            cmd = NixStartScript(start_command)
            name = args.get("name", cmd.machine_name)
        else:
            cmd = Machine.create_startcommand(args)  # type: ignore
            name = args.get("name", "machine")

        return Machine(
            tmp_dir=tmp_dir,
            out_dir=self.out_dir,
            start_command=cmd,
            name=name,
            keep_vm_state=args.get("keep_vm_state", False),
        )

    def serial_stdout_on(self) -> None:
        """Set the root logger's serial-log printing flag."""
        rootlog._print_serial_logs = True

    def serial_stdout_off(self) -> None:
        """Clear the root logger's serial-log printing flag."""
        rootlog._print_serial_logs = False

    def check_polling_conditions(self) -> None:
        """Call ``maybe_raise()`` on every active polling condition.

        Registered as a callback on every machine created in ``__init__``.
        """
        for condition in self.polling_conditions:
            condition.maybe_raise()

    def polling_condition(
        self,
        fun_: Optional[Callable] = None,
        *,
        seconds_interval: float = 2.0,
        description: Optional[str] = None,
    ) -> Union[Callable[[Callable], ContextManager], ContextManager]:
        """Wrap a function in a polling condition usable as a context manager.

        Works both as a bare decorator (``fun_`` is the decorated function
        and a ready ``Poll`` object is returned) and as a decorator factory
        (``fun_`` is None and the ``Poll`` class itself is returned, to be
        called with the function).

        Args:
            fun_: the function to poll, or None when called with keyword
                arguments only.
            seconds_interval: passed to PollingCondition.
            description: passed to PollingCondition.
        """
        driver = self

        class Poll:
            def __init__(self, fun: Callable):
                self.condition = PollingCondition(
                    fun,
                    seconds_interval,
                    description,
                )

            def __enter__(self) -> None:
                driver.polling_conditions.append(self.condition)

            def __exit__(self, a, b, c) -> None:  # type: ignore
                # Conditions form a stack: the one leaving must be the one
                # that was entered last.
                res = driver.polling_conditions.pop()
                assert res is self.condition

            def wait(self, timeout: int = 900) -> None:
                """Re-check the condition through ``retry`` with ``timeout``."""

                def condition(last: bool) -> bool:
                    if last:
                        rootlog.info(f"Last chance for {self.condition.description}")
                    ret = self.condition.check(force=True)
                    if not ret and not last:
                        rootlog.info(
                            f"({self.condition.description} failure not fatal yet)"
                        )
                    return ret

                with rootlog.nested(f"waiting for {self.condition.description}"):
                    retry(condition, timeout=timeout)

        if fun_ is None:
            return Poll
        else:
            return Poll(fun_)