1import os
2import re
3import signal
4import tempfile
5import threading
6from contextlib import contextmanager
7from pathlib import Path
8from typing import Any, Callable, ContextManager, Dict, Iterator, List, Optional, Union
9
10from colorama import Fore, Style
11
12from test_driver.logger import AbstractLogger
13from test_driver.machine import Machine, NixStartScript, retry
14from test_driver.polling_condition import PollingCondition
15from test_driver.vlan import VLan
16
# Unique marker used to tell "key absent" apart from any real value
# (including None) when popping from the legacy create_machine dict.
SENTINEL: object = object()
18
19
def get_tmp_dir() -> Path:
    """Return the temporary directory defined by TMPDIR, TEMP, TMP or CWD.

    The directory is created with mode 0o700 if it does not exist yet.
    See https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir

    Raises:
        NotADirectoryError: the retrieved path exists but is not a directory.
        PermissionError: the retrieved directory is not writeable.
    """
    tmp_dir = Path(tempfile.gettempdir())
    not_a_dir_msg = (
        f"The directory defined by TMPDIR, TEMP, TMP or CWD: {tmp_dir} is not a directory"
    )
    try:
        tmp_dir.mkdir(mode=0o700, exist_ok=True)
    except FileExistsError as err:
        # With exist_ok=True, mkdir only raises this when the path exists but
        # is not a directory; report it as the documented error instead.
        raise NotADirectoryError(not_a_dir_msg) from err
    if not tmp_dir.is_dir():
        raise NotADirectoryError(not_a_dir_msg)
    if not os.access(tmp_dir, os.W_OK):
        raise PermissionError(
            f"The directory defined by TMPDIR, TEMP, TMP, or CWD: {tmp_dir} is not writeable"
        )
    return tmp_dir
36
37
def pythonize_name(name: str) -> str:
    """Turn *name* into a valid Python identifier.

    The first character is replaced by ``_`` unless it is an ASCII letter or
    underscore; every other character is replaced by ``_`` unless it is an
    ASCII letter, digit or underscore.
    """
    # Spell out A-Za-z: the range A-z would also include the characters that
    # sit between 'Z' and 'a' in ASCII ([ \ ] ^ `), letting them through.
    return re.sub(r"^[^A-Za-z_]|[^A-Za-z0-9_]", "_", name)
40
41
class Driver:
    """A handle to the driver that sets up the environment
    and runs the tests"""

    tests: str
    vlans: List[VLan]
    machines: List[Machine]
    polling_conditions: List[PollingCondition]
    global_timeout: int
    race_timer: threading.Timer
    logger: AbstractLogger

    def __init__(
        self,
        start_scripts: List[str],
        vlans: List[int],
        tests: str,
        out_dir: Path,
        logger: AbstractLogger,
        keep_vm_state: bool = False,
        global_timeout: int = 24 * 60 * 60 * 7,
    ) -> None:
        """Create the VLans and machines for a test run.

        Args:
            start_scripts: one start script per machine; each is wrapped in
                ``NixStartScript`` and its ``machine_name`` names the machine.
            vlans: VLan numbers; duplicates are dropped before creation.
            tests: Python source of the test script run by ``test_script``.
            out_dir: forwarded to every ``Machine``.
            logger: shared by the driver, the VLans and the machines.
            keep_vm_state: forwarded to every ``Machine``.
            global_timeout: seconds after ``run_tests`` starts the race timer
                until ``terminate_test`` fires (default: one week).
        """
        self.tests = tests
        self.out_dir = out_dir
        self.global_timeout = global_timeout
        # Only started in run_tests(); cancelled in join_all() / __exit__().
        self.race_timer = threading.Timer(global_timeout, self.terminate_test)
        self.logger = logger

        tmp_dir = get_tmp_dir()

        with self.logger.nested("start all VLans"):
            vlans = list(set(vlans))
            self.vlans = [VLan(nr, tmp_dir, self.logger) for nr in vlans]

        # Conditions are pushed/popped by the context managers returned from
        # polling_condition() and checked by check_polling_conditions().
        self.polling_conditions = []

        self.machines = [
            Machine(
                start_command=start_script,
                keep_vm_state=keep_vm_state,
                name=start_script.machine_name,
                tmp_dir=tmp_dir,
                callbacks=[self.check_polling_conditions],
                out_dir=self.out_dir,
                logger=self.logger,
            )
            for start_script in map(NixStartScript, start_scripts)
        ]

    def __enter__(self) -> "Driver":
        return self

    def __exit__(self, *_: Any) -> None:
        """Cancel the global timeout and release every machine from __init__.

        Returns None, so any exception from the ``with`` body propagates.
        """
        with self.logger.nested("cleanup"):
            self.race_timer.cancel()
            for machine in self.machines:
                machine.release()

    def subtest(self, name: str) -> Iterator[None]:
        """Group logs under a given test name

        Generator meant to be wrapped with ``contextlib.contextmanager`` (the
        test script gets such a wrapper from ``test_symbols``). An exception
        raised inside the block is logged and then re-raised unchanged.
        """
        with self.logger.subtest(name):
            try:
                yield
            except Exception as e:
                self.logger.error(f'Test "{name}" failed with error: "{e}"')
                raise

    def test_symbols(self) -> Dict[str, Any]:
        """Build the global namespace the test script is executed in.

        Besides the general helpers, every machine is bound under its
        pythonized name (plus ``machine`` when there is exactly one machine)
        and every VLan as ``vlan<nr>``. The extra names are printed to stdout.
        """

        @contextmanager
        def subtest(name: str) -> Iterator[None]:
            yield from self.subtest(name)

        general_symbols = dict(
            start_all=self.start_all,
            test_script=self.test_script,
            machines=self.machines,
            vlans=self.vlans,
            driver=self,
            log=self.logger,
            os=os,
            create_machine=self.create_machine,
            subtest=subtest,
            run_tests=self.run_tests,
            join_all=self.join_all,
            retry=retry,
            serial_stdout_off=self.serial_stdout_off,
            serial_stdout_on=self.serial_stdout_on,
            polling_condition=self.polling_condition,
            Machine=Machine,  # for typing
        )
        machine_symbols = {pythonize_name(m.name): m for m in self.machines}
        # If there's exactly one machine, make it available under the name
        # "machine", even if it's not called that.
        if len(self.machines) == 1:
            (machine_symbols["machine"],) = self.machines
        vlan_symbols = {f"vlan{v.nr}": v for v in self.vlans}
        # Print the names actually bound in the namespace (pythonized, plus
        # the "machine" alias), not the raw machine names.
        print(
            "additionally exposed symbols:\n "
            + ", ".join(machine_symbols)
            + ",\n "
            + ", ".join(vlan_symbols)
            + ",\n "
            + ", ".join(general_symbols)
        )
        return {**general_symbols, **machine_symbols, **vlan_symbols}

    def test_script(self) -> None:
        """Run the test script"""
        with self.logger.nested("run the VM test script"):
            symbols = self.test_symbols()  # call eagerly
            # self.tests is executed as Python code with full privileges; it
            # is expected to be the test author's own script.
            exec(self.tests, symbols, None)

    def run_tests(self) -> None:
        """Run the test script (for non-interactive test runs)"""
        self.logger.info(
            f"Test will time out and terminate in {self.global_timeout} seconds"
        )
        self.race_timer.start()
        self.test_script()
        # TODO: Collect coverage data
        for machine in self.machines:
            if machine.is_up():
                machine.execute("sync")

    def start_all(self) -> None:
        """Start all machines"""
        with self.logger.nested("start all VMs"):
            for machine in self.machines:
                machine.start()

    def join_all(self) -> None:
        """Wait for all machines to shut down, then cancel the global timeout"""
        with self.logger.nested("wait for all VMs to finish"):
            for machine in self.machines:
                machine.wait_for_shutdown()
            self.race_timer.cancel()

    def terminate_test(self) -> None:
        """Release all machines and SIGTERM this process (race-timer callback)."""
        # This will be usually running in another thread than
        # the thread actually executing the test script.
        with self.logger.nested("timeout reached; test terminating..."):
            for machine in self.machines:
                machine.release()
            # As we cannot `sys.exit` from another thread
            # We can at least force the main thread to get SIGTERM'ed.
            # This will prevent any user who caught all the exceptions
            # to swallow them and prevent itself from terminating.
            os.kill(os.getpid(), signal.SIGTERM)

    def create_machine(
        self,
        start_command: str | dict,
        *,
        name: Optional[str] = None,
        keep_vm_state: bool = False,
    ) -> Machine:
        """Create an additional machine from a start script.

        Args:
            start_command: start script, wrapped in ``NixStartScript``.
                A dict is still accepted (deprecated legacy form): it must
                contain ``startCommand`` (a str) and may contain ``name`` and
                ``keep_vm_state``. The caller's dict is not modified.
            name: machine name; defaults to the start script's
                ``machine_name``.
            keep_vm_state: forwarded to ``Machine``.

        Raises:
            TypeError: a dict was combined with keyword arguments, lacks
                ``startCommand``, has a non-str ``startCommand``, or contains
                unsupported keys.

        The machine is not appended to ``self.machines``, so ``__exit__``
        does not release it.
        """
        # Legacy args handling
        # FIXME: remove after 24.05
        if isinstance(start_command, dict):
            if name is not None or keep_vm_state:
                raise TypeError(
                    "Dictionary passed to create_machine must be the only argument"
                )

            # Work on a copy so that popping keys does not mutate the
            # caller's dictionary.
            args = dict(start_command)
            start_command = args.pop("startCommand", SENTINEL)

            if start_command is SENTINEL:
                raise TypeError(
                    "Dictionary passed to create_machine must contain startCommand"
                )

            if not isinstance(start_command, str):
                raise TypeError(
                    f"startCommand must be a string, got: {repr(start_command)}"
                )

            name = args.pop("name", None)
            keep_vm_state = args.pop("keep_vm_state", False)

            if args:
                raise TypeError(
                    f"Unsupported arguments passed to create_machine: {args}"
                )

            self.logger.warning(
                Fore.YELLOW
                + Style.BRIGHT
                + "WARNING: Using create_machine with a single dictionary argument is deprecated and will be removed in NixOS 24.11"
                + Style.RESET_ALL
            )
        # End legacy args handling

        tmp_dir = get_tmp_dir()

        cmd = NixStartScript(start_command)
        name = name or cmd.machine_name

        # NOTE(review): unlike the machines built in __init__, no
        # callbacks=[self.check_polling_conditions] is passed here, so polling
        # conditions are presumably not checked for this machine — confirm.
        return Machine(
            tmp_dir=tmp_dir,
            out_dir=self.out_dir,
            start_command=cmd,
            name=name,
            keep_vm_state=keep_vm_state,
            logger=self.logger,
        )

    def serial_stdout_on(self) -> None:
        """Enable printing of serial logs via the logger."""
        self.logger.print_serial_logs(True)

    def serial_stdout_off(self) -> None:
        """Disable printing of serial logs via the logger."""
        self.logger.print_serial_logs(False)

    def check_polling_conditions(self) -> None:
        """Call ``maybe_raise()`` on every currently registered condition."""
        for condition in self.polling_conditions:
            condition.maybe_raise()

    def polling_condition(
        self,
        fun_: Optional[Callable] = None,
        *,
        seconds_interval: float = 2.0,
        description: Optional[str] = None,
    ) -> Union[Callable[[Callable], ContextManager], ContextManager]:
        """Wrap a check function in a ``PollingCondition`` context manager.

        Usable bare (``@polling_condition``) or with keyword arguments
        (``@polling_condition(seconds_interval=10)``); in the latter case the
        returned class acts as the decorator. Inside the ``with`` block the
        condition is registered in ``self.polling_conditions``, where
        ``check_polling_conditions`` sees it.
        """
        driver = self

        class Poll:
            def __init__(self, fun: Callable):
                self.condition = PollingCondition(
                    fun,
                    driver.logger,
                    seconds_interval,
                    description,
                )

            def __enter__(self) -> None:
                driver.polling_conditions.append(self.condition)

            def __exit__(self, a, b, c) -> None:  # type: ignore
                # Conditions form a stack: blocks must be properly nested.
                res = driver.polling_conditions.pop()
                assert res is self.condition

            def wait(self, timeout: int = 900) -> None:
                """Retry a forced check of the condition via ``retry``."""

                def condition(last: bool) -> bool:
                    if last:
                        driver.logger.info(
                            f"Last chance for {self.condition.description}"
                        )
                    ret = self.condition.check(force=True)
                    if not ret and not last:
                        driver.logger.info(
                            f"({self.condition.description} failure not fatal yet)"
                        )
                    return ret

                with driver.logger.nested(f"waiting for {self.condition.description}"):
                    retry(condition, timeout=timeout)

        if fun_ is None:
            return Poll
        else:
            return Poll(fun_)