1from contextlib import contextmanager
2from pathlib import Path
3from typing import Any, Dict, Iterator, List, Union, Optional, Callable, ContextManager
4import os
5import tempfile
6
7from test_driver.logger import rootlog
8from test_driver.machine import Machine, NixStartScript, retry
9from test_driver.vlan import VLan
10from test_driver.polling_condition import PollingCondition
11
12
def get_tmp_dir() -> Path:
    """Return the temporary directory chosen by ``tempfile.gettempdir()``.

    The candidate comes from TMPDIR, TEMP or TMP, with the fallbacks used by
    ``tempfile.gettempdir()`` (ending with the current working directory).
    The directory is created with ``mode=0o700`` if it does not exist yet.

    Raises:
        NotADirectoryError: the selected path exists but is not a directory.
        PermissionError: the selected directory is not writeable.

    See https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir
    """
    tmp_dir = Path(tempfile.gettempdir())
    not_a_dir_msg = "The directory defined by TMPDIR, TEMP, TMP or CWD: {0} is not a directory".format(
        tmp_dir
    )
    try:
        tmp_dir.mkdir(mode=0o700, exist_ok=True)
    except FileExistsError as err:
        # exist_ok only tolerates an existing *directory*; for an existing
        # non-directory mkdir() raises here, before the is_dir() check below
        # could report it, so surface the intended error instead.
        raise NotADirectoryError(not_a_dir_msg) from err
    if not tmp_dir.is_dir():
        raise NotADirectoryError(not_a_dir_msg)
    if not os.access(tmp_dir, os.W_OK):
        raise PermissionError(
            "The directory defined by TMPDIR, TEMP, TMP, or CWD: {0} is not writeable".format(
                tmp_dir
            )
        )
    return tmp_dir
33
34
class Driver:
    """A handle to the driver that sets up the environment
    and runs the tests"""

    tests: str
    vlans: List[VLan]
    machines: List[Machine]
    polling_conditions: List[PollingCondition]

    def __init__(
        self,
        start_scripts: List[str],
        vlans: List[int],
        tests: str,
        out_dir: Path,
        keep_vm_state: bool = False,
    ) -> None:
        """Create the VLans and one Machine per start script.

        Args:
            start_scripts: each entry is wrapped in a NixStartScript; the
                machine's name is taken from that script's ``machine_name``.
            vlans: VLAN numbers; duplicates are removed before creating VLans.
            tests: source of the test script, later run by ``test_script()``.
            out_dir: output directory handed to every Machine.
            keep_vm_state: forwarded to every Machine.
        """
        self.tests = tests
        self.out_dir = out_dir

        tmp_dir = get_tmp_dir()

        with rootlog.nested("start all VLans"):
            unique_vlans = list(set(vlans))
            self.vlans = [VLan(nr, tmp_dir) for nr in unique_vlans]

        # Conditions registered through polling_condition(); every machine
        # below gets check_polling_conditions as a callback.
        self.polling_conditions = []

        self.machines = [
            Machine(
                start_command=start_script,
                keep_vm_state=keep_vm_state,
                name=start_script.machine_name,
                tmp_dir=tmp_dir,
                callbacks=[self.check_polling_conditions],
                out_dir=self.out_dir,
            )
            for start_script in map(NixStartScript, start_scripts)
        ]

    def __enter__(self) -> "Driver":
        return self

    def __exit__(self, *_: Any) -> None:
        """Release every machine created in ``__init__``."""
        with rootlog.nested("cleanup"):
            for machine in self.machines:
                machine.release()

    def subtest(self, name: str) -> Iterator[None]:
        """Group logs under a given test name.

        This is a plain generator; ``test_symbols()`` wraps it with
        ``contextlib.contextmanager`` so test scripts can write
        ``with subtest("name"):``. Any exception raised inside the body is
        logged and then re-raised unchanged.
        """
        with rootlog.nested("subtest: " + name):
            try:
                yield
            except Exception as e:
                rootlog.error(f'Test "{name}" failed with error: "{e}"')
                raise

    def test_symbols(self) -> Dict[str, Any]:
        """Build the global namespace the test script is executed in.

        Contains the general helpers, every machine under its name (plus
        ``machine`` when there is exactly one), and every VLan as
        ``vlan<nr>``. Later entries win on name clashes. The list of exposed
        names is printed to stdout.
        """

        @contextmanager
        def subtest(name: str) -> Iterator[None]:
            # contextmanager only needs a callable that returns a generator,
            # so delegating to the generator method is sufficient.
            return self.subtest(name)

        general_symbols = dict(
            start_all=self.start_all,
            test_script=self.test_script,
            machines=self.machines,
            vlans=self.vlans,
            driver=self,
            log=rootlog,
            os=os,
            create_machine=self.create_machine,
            subtest=subtest,
            run_tests=self.run_tests,
            join_all=self.join_all,
            retry=retry,
            serial_stdout_off=self.serial_stdout_off,
            serial_stdout_on=self.serial_stdout_on,
            polling_condition=self.polling_condition,
            Machine=Machine,  # for typing
        )
        machine_symbols = {m.name: m for m in self.machines}
        # If there's exactly one machine, make it available under the name
        # "machine", even if it's not called that.
        if len(self.machines) == 1:
            (machine_symbols["machine"],) = self.machines
        vlan_symbols = {f"vlan{v.nr}": v for v in self.vlans}
        print(
            "additionally exposed symbols:\n "
            + ", ".join(m.name for m in self.machines)
            + ",\n "
            + ", ".join(f"vlan{v.nr}" for v in self.vlans)
            + ",\n "
            + ", ".join(general_symbols.keys())
        )
        return {**general_symbols, **machine_symbols, **vlan_symbols}

    def test_script(self) -> None:
        """Run the test script"""
        with rootlog.nested("run the VM test script"):
            symbols = self.test_symbols()  # call eagerly
            # NOTE: self.tests is executed as trusted Python code.
            exec(self.tests, symbols, None)

    def run_tests(self) -> None:
        """Run the test script (for non-interactive test runs)"""
        self.test_script()
        # TODO: Collect coverage data
        for machine in self.machines:
            if machine.is_up():
                machine.execute("sync")

    def start_all(self) -> None:
        """Start all machines"""
        with rootlog.nested("start all VMs"):
            for machine in self.machines:
                machine.start()

    def join_all(self) -> None:
        """Wait for all machines to shut down"""
        with rootlog.nested("wait for all VMs to finish"):
            for machine in self.machines:
                machine.wait_for_shutdown()

    def create_machine(self, args: Dict[str, Any]) -> Machine:
        """Legacy factory: build a Machine from a dict of options.

        If ``args["startCommand"]`` is truthy it is wrapped in a
        NixStartScript and the name defaults to that script's
        ``machine_name``. Otherwise the start command comes from
        ``Machine.create_startcommand(args)`` and the name defaults to
        ``"machine"``. The machine is neither started nor added to
        ``self.machines``.
        """
        rootlog.warning(
            "Using legacy create_machine(), please instantiate the "
            "Machine class directly, instead"
        )

        tmp_dir = get_tmp_dir()

        if args.get("startCommand"):
            start_command: str = args.get("startCommand", "")
            cmd = NixStartScript(start_command)
            name = args.get("name", cmd.machine_name)
        else:
            cmd = Machine.create_startcommand(args)  # type: ignore
            name = args.get("name", "machine")

        return Machine(
            tmp_dir=tmp_dir,
            out_dir=self.out_dir,
            start_command=cmd,
            name=name,
            keep_vm_state=args.get("keep_vm_state", False),
            allow_reboot=args.get("allow_reboot", False),
            # Same as machines created in __init__, so polling conditions
            # are also checked for legacy-created machines.
            callbacks=[self.check_polling_conditions],
        )

    def serial_stdout_on(self) -> None:
        """Enable printing of serial logs (sets rootlog._print_serial_logs)."""
        rootlog._print_serial_logs = True

    def serial_stdout_off(self) -> None:
        """Disable printing of serial logs (sets rootlog._print_serial_logs)."""
        rootlog._print_serial_logs = False

    def check_polling_conditions(self) -> None:
        """Call ``maybe_raise()`` on every currently registered condition."""
        for condition in self.polling_conditions:
            condition.maybe_raise()

    def polling_condition(
        self,
        fun_: Optional[Callable] = None,
        *,
        seconds_interval: float = 2.0,
        description: Optional[str] = None,
    ) -> Union[Callable[[Callable], ContextManager], ContextManager]:
        """Wrap a function in a context manager that registers a PollingCondition.

        With ``fun_`` given (``polling_condition(fn)`` or bare
        ``@polling_condition``) a context manager for it is returned. Without
        it (``@polling_condition(seconds_interval=...)``) a factory taking the
        function is returned. While the context is active, the condition is
        in ``self.polling_conditions``.
        """
        driver = self

        class Poll:
            def __init__(self, fun: Callable):
                self.condition = PollingCondition(
                    fun,
                    seconds_interval,
                    description,
                )

            def __enter__(self) -> None:
                driver.polling_conditions.append(self.condition)

            def __exit__(self, *_: Any) -> None:
                # Conditions form a stack: contexts must be exited in the
                # reverse order they were entered.
                res = driver.polling_conditions.pop()
                assert res is self.condition

        if fun_ is None:
            return Poll
        else:
            return Poll(fun_)