1import atexit
2import codecs
3import os
4import sys
5import time
6import unicodedata
7from abc import ABC, abstractmethod
8from collections.abc import Iterator
9from contextlib import ExitStack, contextmanager
10from pathlib import Path
11from queue import Empty, Queue
12from typing import Any
13from xml.sax.saxutils import XMLGenerator
14from xml.sax.xmlreader import AttributesImpl
15
16from colorama import Fore, Style
17from junit_xml import TestCase, TestSuite
18
19
class AbstractLogger(ABC):
    """Abstract interface that every logging backend must implement.

    The class cannot be instantiated directly; concrete subclasses provide all
    of the methods below.
    """

    @abstractmethod
    def log(self, message: str, attributes: dict[str, str] = {}) -> None:
        """Record one log line, optionally tagged with *attributes*."""

    @abstractmethod
    @contextmanager
    def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Context manager wrapping the logging done inside subtest *name*."""

    @abstractmethod
    @contextmanager
    def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Context manager wrapping a nested log section titled *message*."""

    @abstractmethod
    def info(self, *args, **kwargs) -> None:  # type: ignore
        """Log an informational message; arguments are backend-defined."""

    @abstractmethod
    def warning(self, *args, **kwargs) -> None:  # type: ignore
        """Log a warning message; arguments are backend-defined."""

    @abstractmethod
    def error(self, *args, **kwargs) -> None:  # type: ignore
        """Log an error message; arguments are backend-defined."""

    @abstractmethod
    def log_test_error(self, *args, **kwargs) -> None:  # type:ignore
        """Log a line describing a test error; arguments are backend-defined."""

    @abstractmethod
    def log_serial(self, message: str, machine: str) -> None:
        """Record one line of serial output produced by *machine*."""

    @abstractmethod
    def print_serial_logs(self, enable: bool) -> None:
        """Enable or disable recording of serial output."""
58
59
class JunitXMLLogger(AbstractLogger):
    """Logger that buffers output per (sub)test and writes a JUnit XML report.

    Output is appended to the buffers of the currently active test case, which
    starts as ``"main"`` and is switched by :meth:`subtest`.  The report is
    written by :meth:`close`, which is also registered with :mod:`atexit`.
    """

    class TestCaseState:
        """Accumulated stdout/stderr text and failure flag of one test case."""

        def __init__(self) -> None:
            self.stdout = ""
            self.stderr = ""
            self.failure = False

    def __init__(self, outfile: Path) -> None:
        """Start with a single ``"main"`` test case; *outfile* receives the report."""
        self.tests: dict[str, JunitXMLLogger.TestCaseState] = {
            "main": self.TestCaseState()
        }
        self.currentSubtest = "main"
        self.outfile: Path = outfile
        self._print_serial_logs = True
        # Write the report at interpreter exit even if close() is never
        # called explicitly.
        atexit.register(self.close)

    def _current(self) -> "JunitXMLLogger.TestCaseState":
        """Return the state of the test case currently receiving output."""
        return self.tests[self.currentSubtest]

    def log(self, message: str, attributes: dict[str, str] = {}) -> None:
        """Append *message* to the current test case's stdout.

        *attributes* is accepted for interface compatibility and ignored.
        """
        self._current().stdout += message + os.linesep

    @contextmanager
    def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Route output to test case *name* for the duration of the block.

        Re-using a name appends to that test case's existing buffers.
        """
        old_test = self.currentSubtest
        self.tests.setdefault(name, self.TestCaseState())
        self.currentSubtest = name

        # NOTE(review): there is no try/finally here, so if the body raises,
        # *name* stays the current test case and error lines logged while the
        # exception is being handled are attributed to (and fail) this
        # subtest. Confirm this is intended before adding a finally.
        yield

        self.currentSubtest = old_test

    @contextmanager
    def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Log *message* as a plain line; this logger records no nesting structure."""
        self.log(message)
        yield

    def info(self, *args, **kwargs) -> None:  # type: ignore
        """Append the first positional argument to the current stdout."""
        self._current().stdout += args[0] + os.linesep

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        """Append the first positional argument to the current stdout."""
        self._current().stdout += args[0] + os.linesep

    def error(self, *args, **kwargs) -> None:  # type: ignore
        """Append the first positional argument to stderr and mark the case failed."""
        current = self._current()
        current.stderr += args[0] + os.linesep
        current.failure = True

    def log_test_error(self, *args, **kwargs) -> None:  # type: ignore
        """Treated exactly like :meth:`error`."""
        self.error(*args, **kwargs)

    def log_serial(self, message: str, machine: str) -> None:
        """Log ``"<machine> # <message>"`` unless serial logging is disabled."""
        if not self._print_serial_logs:
            return

        self.log(f"{machine} # {message}")

    def print_serial_logs(self, enable: bool) -> None:
        """Enable or disable recording of serial output."""
        self._print_serial_logs = enable

    def close(self) -> None:
        """Write all collected test cases to ``self.outfile`` as JUnit XML.

        The XML is rendered before the file is opened, so a rendering failure
        does not leave behind a truncated, empty report. The file is encoded
        explicitly as UTF-8 (XML's default encoding) instead of using the
        locale-dependent default of ``open()``, which could fail on non-ASCII
        test output.
        """
        test_cases = []
        for name, state in self.tests.items():
            tc = TestCase(
                name,
                stdout=state.stdout,
                stderr=state.stderr,
            )
            if state.failure:
                tc.add_failure_info("test case failed")

            test_cases.append(tc)
        ts = TestSuite("NixOS integration test", test_cases)
        rendered = TestSuite.to_xml_string([ts])

        with open(self.outfile, "w", encoding="utf-8") as f:
            f.write(rendered)
131
132
class CompositeLogger(AbstractLogger):
    """Fan-out logger: every call is forwarded, in order, to each wrapped logger."""

    def __init__(self, logger_list: list[AbstractLogger]) -> None:
        # Stored as given (not copied), so add_logger() also appends to the
        # caller's list.
        self.logger_list = logger_list

    def _broadcast(self, method_name: str, /, *args: Any, **kwargs: Any) -> None:
        """Invoke ``method_name(*args, **kwargs)`` on every wrapped logger."""
        for target in self.logger_list:
            getattr(target, method_name)(*args, **kwargs)

    @contextmanager
    def _enter_all(self, method_name: str, /, *args: Any) -> Iterator[None]:
        """Enter the context manager returned by ``method_name(*args)`` on each logger.

        The managers are exited in reverse order when the block ends.
        """
        with ExitStack() as contexts:
            for target in self.logger_list:
                contexts.enter_context(getattr(target, method_name)(*args))
            yield

    def add_logger(self, logger: AbstractLogger) -> None:
        """Append *logger* to the set of forwarding targets."""
        self.logger_list.append(logger)

    def log(self, message: str, attributes: dict[str, str] = {}) -> None:
        self._broadcast("log", message, attributes)

    @contextmanager
    def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        with self._enter_all("subtest", name, attributes):
            yield

    @contextmanager
    def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        with self._enter_all("nested", message, attributes):
            yield

    def info(self, *args, **kwargs) -> None:  # type: ignore
        self._broadcast("info", *args, **kwargs)

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        self._broadcast("warning", *args, **kwargs)

    def log_test_error(self, *args, **kwargs) -> None:  # type: ignore
        self._broadcast("log_test_error", *args, **kwargs)

    def error(self, *args, **kwargs) -> None:  # type: ignore
        """Forward the error to every logger, then terminate with exit status 1."""
        self._broadcast("error", *args, **kwargs)
        sys.exit(1)

    def print_serial_logs(self, enable: bool) -> None:
        self._broadcast("print_serial_logs", enable)

    def log_serial(self, message: str, machine: str) -> None:
        self._broadcast("log_serial", message, machine)
182
183
class TerminalLogger(AbstractLogger):
    """Logger that prints human-readable, colorama-styled lines to stderr."""

    def __init__(self) -> None:
        # Serial output is echoed until print_serial_logs(False) is called.
        self._print_serial_logs = True

    def maybe_prefix(self, message: str, attributes: dict[str, str]) -> str:
        """Prefix *message* with ``"<machine>: "`` when *attributes* has a ``machine`` key."""
        if "machine" not in attributes:
            return message
        return f"{attributes['machine']}: {message}"

    @staticmethod
    def _eprint(*args: object, **kwargs: Any) -> None:
        """Like ``print``, but always writes to standard error."""
        print(*args, file=sys.stderr, **kwargs)

    def log(self, message: str, attributes: dict[str, str] = {}) -> None:
        """Print *message*, machine-prefixed when applicable."""
        line = self.maybe_prefix(message, attributes)
        self._eprint(line)

    @contextmanager
    def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Print a ``subtest:`` section header and footer around the block."""
        section = self.nested("subtest: " + name, attributes)
        with section:
            yield

    @contextmanager
    def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Print a bright green header, run the block, then print its duration.

        The footer is only printed when the block completes without raising.
        """
        highlighted = Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL
        self._eprint(self.maybe_prefix(highlighted, attributes))

        started = time.time()
        yield
        duration = time.time() - started
        self.log(f"(finished: {message}, in {duration:.2f} seconds)", attributes)

    def info(self, *args, **kwargs) -> None:  # type: ignore
        """Alias for :meth:`log`."""
        self.log(*args, **kwargs)

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        """Alias for :meth:`log`."""
        self.log(*args, **kwargs)

    def error(self, *args, **kwargs) -> None:  # type: ignore
        """Alias for :meth:`log`."""
        self.log(*args, **kwargs)

    def print_serial_logs(self, enable: bool) -> None:
        """Enable or disable echoing of serial output."""
        self._print_serial_logs = enable

    def log_serial(self, message: str, machine: str) -> None:
        """Print a dimmed ``"<machine> # <message>"`` line if serial echo is on."""
        if self._print_serial_logs:
            self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL)

    def log_test_error(self, *args, **kwargs) -> None:  # type: ignore
        """Print the first argument behind a red ``!!!`` marker."""
        marker = Fore.RED + "!!! " + Style.RESET_ALL
        first, remaining = args[0], args[1:]
        # NOTE: using `warning` instead of `error` to ensure it does not exit after printing the first log
        self.warning(f"{marker}{first}", *remaining, **kwargs)
240
241
class XMLLogger(AbstractLogger):
    """Logger that streams a nested XML log document to a file.

    The document has a ``<logfile>`` root containing ``<line>`` elements (with
    XML attributes taken from the ``attributes`` dict) and ``<nest>`` sections
    that each start with a ``<head>`` element. Serial output is queued by
    :meth:`log_serial` and written only when the queue is drained: by
    :meth:`log`, at ``nested`` boundaries, and by :meth:`close`.

    NOTE(review): the thread-safe ``Queue`` suggests ``log_serial`` may be
    called from threads other than the one writing the XML. Confirm this.
    """

    def __init__(self, outfile: str) -> None:
        # codecs.open() without an encoding simply returned the built-in
        # open(), and codecs.open() is deprecated as of Python 3.14, so call
        # open() directly. XMLGenerator does the UTF-8 encoding itself.
        self.logfile_handle = open(outfile, "wb")
        self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8")
        self.queue: Queue[dict[str, str]] = Queue()

        self._print_serial_logs = True

        self.xml.startDocument()
        self.xml.startElement("logfile", attrs=AttributesImpl({}))

    def close(self) -> None:
        """Flush queued serial lines, end the document and close the file.

        Calling close() again once the file is closed is a no-op.
        """
        if self.logfile_handle.closed:
            return
        # Serial lines queued since the last drain would otherwise be lost.
        self.drain_log_queue()
        self.xml.endElement("logfile")
        self.xml.endDocument()
        self.logfile_handle.close()

    def sanitise(self, message: str) -> str:
        """Drop every character whose Unicode category is ``C*`` (Cc, Cf, Cs, Co, Cn).

        This removes control characters, most of which are not allowed in
        XML 1.0. Note that it also removes tabs and newlines.
        """
        return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C")

    def maybe_prefix(self, message: str, attributes: dict[str, str]) -> str:
        """Prefix *message* with ``"<machine>: "`` when *attributes* has a ``machine`` key."""
        if "machine" in attributes:
            return f"{attributes['machine']}: {message}"
        return message

    def log_line(self, message: str, attributes: dict[str, str]) -> None:
        """Write one ``<line>`` element carrying *attributes* as XML attributes."""
        self.xml.startElement("line", attrs=AttributesImpl(attributes))
        self.xml.characters(message)
        self.xml.endElement("line")

    # All severities are written identically; the XML log does not
    # distinguish them.
    def info(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def error(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def log_test_error(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def log(self, message: str, attributes: dict[str, str] = {}) -> None:
        """Write pending serial lines, then *message* as a ``<line>`` element."""
        self.drain_log_queue()
        self.log_line(message, attributes)

    def print_serial_logs(self, enable: bool) -> None:
        """Enable or disable recording of serial output."""
        self._print_serial_logs = enable

    def log_serial(self, message: str, machine: str) -> None:
        """Queue a serial line; it is written at the next drain."""
        if not self._print_serial_logs:
            return

        self.enqueue({"msg": message, "machine": machine, "type": "serial"})

    def enqueue(self, item: dict[str, str]) -> None:
        """Put *item* (must contain a ``"msg"`` key) on the pending-line queue."""
        self.queue.put(item)

    def drain_log_queue(self) -> None:
        """Write every queued item as a sanitised ``<line>`` element.

        The ``"msg"`` key becomes the element text. The remaining keys become
        XML attributes.
        """
        while True:
            try:
                item = self.queue.get_nowait()
            except Empty:
                return
            msg = self.sanitise(item.pop("msg"))
            self.log_line(msg, item)

    @contextmanager
    def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Wrap the block in a ``nested`` section titled ``subtest: <name>``."""
        with self.nested("subtest: " + name, attributes):
            yield

    @contextmanager
    def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Wrap the block in a ``<nest>`` element headed by *message*.

        On normal completion, a ``(finished: ...)`` timing line is appended.
        If the block raises, the section is still closed before the exception
        propagates, so the document stays well-formed.
        """
        self.xml.startElement("nest", attrs=AttributesImpl({}))
        self.xml.startElement("head", attrs=AttributesImpl(attributes))
        self.xml.characters(message)
        self.xml.endElement("head")

        tic = time.time()
        self.drain_log_queue()
        try:
            yield
        except BaseException:
            # BaseException also covers SystemExit raised inside the block.
            # Flush serial output produced in this section into it, close the
            # element, and re-raise unchanged.
            self.drain_log_queue()
            self.xml.endElement("nest")
            raise
        self.drain_log_queue()
        toc = time.time()
        self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)")

        self.xml.endElement("nest")