1import atexit
2import codecs
3import os
4import sys
5import time
6import unicodedata
7from abc import ABC, abstractmethod
8from contextlib import ExitStack, contextmanager
9from pathlib import Path
10from queue import Empty, Queue
11from typing import Any, Dict, Iterator, List
12from xml.sax.saxutils import XMLGenerator
13from xml.sax.xmlreader import AttributesImpl
14
15from colorama import Fore, Style
16from junit_xml import TestCase, TestSuite
17
18
class AbstractLogger(ABC):
    """Interface implemented by every test-driver logger in this module.

    Concrete loggers provide plain message logging, two context managers
    that group output (``subtest`` and ``nested``), severity helpers
    (``info``/``warning``/``error``) and serial-console logging that can be
    switched on and off with ``print_serial_logs``.
    """

    @abstractmethod
    def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
        """Record ``message``, optionally tagged with ``attributes``."""

    @abstractmethod
    @contextmanager
    def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        """Context manager wrapping the body of the subtest called ``name``."""

    @abstractmethod
    @contextmanager
    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        """Context manager wrapping a section introduced by ``message``."""

    @abstractmethod
    def info(self, *args, **kwargs) -> None:  # type: ignore
        """Record an informational message."""

    @abstractmethod
    def warning(self, *args, **kwargs) -> None:  # type: ignore
        """Record a warning message."""

    @abstractmethod
    def error(self, *args, **kwargs) -> None:  # type: ignore
        """Record an error message."""

    @abstractmethod
    def log_serial(self, message: str, machine: str) -> None:
        """Record one line of serial-console output from ``machine``."""

    @abstractmethod
    def print_serial_logs(self, enable: bool) -> None:
        """Enable or disable recording of serial-console output."""
53
54
class JunitXMLLogger(AbstractLogger):
    """Logger that collects output per subtest and writes a JUnit XML report.

    Output is accumulated in memory, keyed by subtest name; anything logged
    outside a subtest goes to the "main" test case. The report is written to
    ``outfile`` by ``close``, which is registered with ``atexit`` on
    construction.
    """

    class TestCaseState:
        """Accumulated stdout/stderr text and failure flag for one test case."""

        def __init__(self) -> None:
            self.stdout = ""
            self.stderr = ""
            self.failure = False

    def __init__(self, outfile: Path) -> None:
        self.tests: dict[str, JunitXMLLogger.TestCaseState] = {
            "main": self.TestCaseState()
        }
        # Name of the test case that currently receives output.
        self.currentSubtest = "main"
        self.outfile: Path = outfile
        self._print_serial_logs = True
        atexit.register(self.close)

    def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
        """Append ``message`` to the current test case's stdout."""
        self.tests[self.currentSubtest].stdout += message + os.linesep

    @contextmanager
    def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        """Route output produced inside the block to the test case ``name``.

        The previous test case is restored only when the block exits
        normally; if it raises, subsequent output (e.g. an ``error`` call)
        stays attributed to the failing subtest.
        """
        old_test = self.currentSubtest
        self.tests.setdefault(name, self.TestCaseState())
        self.currentSubtest = name

        yield

        self.currentSubtest = old_test

    @contextmanager
    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        """Log ``message`` to the current test case; adds no extra structure."""
        self.log(message)
        yield

    def info(self, *args, **kwargs) -> None:  # type: ignore
        self.tests[self.currentSubtest].stdout += args[0] + os.linesep

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        self.tests[self.currentSubtest].stdout += args[0] + os.linesep

    def error(self, *args, **kwargs) -> None:  # type: ignore
        """Append to the current test case's stderr and mark it as failed."""
        self.tests[self.currentSubtest].stderr += args[0] + os.linesep
        self.tests[self.currentSubtest].failure = True

    def log_serial(self, message: str, machine: str) -> None:
        if not self._print_serial_logs:
            return

        self.log(f"{machine} # {message}")

    def print_serial_logs(self, enable: bool) -> None:
        self._print_serial_logs = enable

    def close(self) -> None:
        """Write all collected test cases to ``outfile`` as a JUnit XML report.

        Test cases that received an ``error`` call get failure info attached.
        """
        # Build the whole report before touching the file so that a failure
        # here does not leave a truncated/empty report behind.
        test_cases = []
        for name, test_case_state in self.tests.items():
            tc = TestCase(
                name,
                stdout=test_case_state.stdout,
                stderr=test_case_state.stderr,
            )
            if test_case_state.failure:
                tc.add_failure_info("test case failed")

            test_cases.append(tc)
        ts = TestSuite("NixOS integration test", test_cases)
        report = TestSuite.to_xml_string([ts])

        # XML parsers assume UTF-8 unless told otherwise, so write UTF-8
        # explicitly instead of relying on the locale's default encoding.
        with open(self.outfile, "w", encoding="utf-8") as f:
            f.write(report)
124
125
class CompositeLogger(AbstractLogger):
    """Logger that forwards every call to each logger in ``logger_list``.

    Child loggers are invoked in list order. ``error`` additionally
    terminates the process with exit status 1 once every child has been
    told about the error.
    """

    def __init__(self, logger_list: List[AbstractLogger]) -> None:
        # Kept by reference (not copied): add_logger() appends to this list.
        self.logger_list = logger_list

    def add_logger(self, logger: AbstractLogger) -> None:
        """Start forwarding calls to ``logger`` as well."""
        self.logger_list.append(logger)

    def _broadcast(self, method: str, /, *args: Any, **kwargs: Any) -> None:
        """Invoke ``method`` with the given arguments on every child, in order."""
        # Positional-only parameters so forwarded kwargs can never collide
        # with ``self``/``method``.
        for child in self.logger_list:
            getattr(child, method)(*args, **kwargs)

    def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
        self._broadcast("log", message, attributes)

    @contextmanager
    def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        """Keep every child's ``subtest`` context open while the body runs."""
        with ExitStack() as contexts:
            for child in self.logger_list:
                contexts.enter_context(child.subtest(name, attributes))
            yield

    @contextmanager
    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        """Keep every child's ``nested`` context open while the body runs."""
        with ExitStack() as contexts:
            for child in self.logger_list:
                contexts.enter_context(child.nested(message, attributes))
            yield

    def info(self, *args, **kwargs) -> None:  # type: ignore
        self._broadcast("info", *args, **kwargs)

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        self._broadcast("warning", *args, **kwargs)

    def error(self, *args, **kwargs) -> None:  # type: ignore
        """Forward the error to every child, then exit with status 1."""
        self._broadcast("error", *args, **kwargs)
        sys.exit(1)

    def print_serial_logs(self, enable: bool) -> None:
        self._broadcast("print_serial_logs", enable)

    def log_serial(self, message: str, machine: str) -> None:
        self._broadcast("log_serial", message, machine)
171
172
class TerminalLogger(AbstractLogger):
    """Logger that prints human-readable, colourised output to stderr."""

    def __init__(self) -> None:
        # Toggled by print_serial_logs(); gates log_serial() output.
        self._print_serial_logs = True

    def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str:
        """Prefix ``message`` with ``"<machine>: "`` when a machine attribute is set."""
        if "machine" in attributes:
            return f"{attributes['machine']}: {message}"
        return message

    @staticmethod
    def _eprint(*args: object, **kwargs: Any) -> None:
        """``print`` to stderr."""
        print(*args, file=sys.stderr, **kwargs)

    def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
        self._eprint(self.maybe_prefix(message, attributes))

    @contextmanager
    def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        """Print the block as a nested section titled ``"subtest: <name>"``."""
        with self.nested("subtest: " + name, attributes):
            yield

    @contextmanager
    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        """Print ``message`` as a bright green header, run the block, then
        print how long the block took (only if it completed normally)."""
        self._eprint(
            self.maybe_prefix(
                Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL, attributes
            )
        )

        # Monotonic clock: elapsed time is unaffected by wall-clock
        # adjustments (NTP, manual changes) made while the block runs.
        tic = time.monotonic()
        yield
        toc = time.monotonic()
        self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)")

    def info(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def error(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def print_serial_logs(self, enable: bool) -> None:
        self._print_serial_logs = enable

    def log_serial(self, message: str, machine: str) -> None:
        """Print a dimmed ``"<machine> # <message>"`` line, if enabled."""
        if not self._print_serial_logs:
            return

        self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL)
224
225
class XMLLogger(AbstractLogger):
    """Logger that writes a structured XML log file.

    The document has a ``<logfile>`` root containing ``<line>`` elements for
    individual messages and ``<nest>`` elements (each starting with a
    ``<head>``) for nested sections. Serial-console output is buffered in a
    queue by ``log_serial``. It is written into the document the next time a
    regular message is logged, when a nested section starts or ends, and on
    ``close``.
    """

    def __init__(self, outfile: str) -> None:
        # Binary handle: XMLGenerator does the UTF-8 encoding itself.
        # (codecs.open without an encoding is exactly the builtin open.)
        self.logfile_handle = open(outfile, "wb")
        self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8")
        # Pending serial-log entries: each dict holds "msg" plus the
        # attributes to put on its <line> element.
        self.queue: "Queue[Dict[str, str]]" = Queue()

        self._print_serial_logs = True

        self.xml.startDocument()
        self.xml.startElement("logfile", attrs=AttributesImpl({}))

    def close(self) -> None:
        """Flush pending serial output, finish the document and close the file."""
        # Without this, serial lines queued after the last regular message
        # would be silently dropped.
        self.drain_log_queue()
        self.xml.endElement("logfile")
        self.xml.endDocument()
        self.logfile_handle.close()

    def sanitise(self, message: str) -> str:
        """Drop all characters in Unicode "Other" categories (C*), e.g. control chars."""
        return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C")

    def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str:
        if "machine" in attributes:
            return f"{attributes['machine']}: {message}"
        return message

    def log_line(self, message: str, attributes: Dict[str, str]) -> None:
        """Write ``message`` as a ``<line>`` element carrying ``attributes``."""
        self.xml.startElement("line", attrs=AttributesImpl(attributes))
        self.xml.characters(message)
        self.xml.endElement("line")

    def info(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def error(self, *args, **kwargs) -> None:  # type: ignore
        self.log(*args, **kwargs)

    def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
        # Emit queued serial output first so lines keep their relative order.
        self.drain_log_queue()
        self.log_line(message, attributes)

    def print_serial_logs(self, enable: bool) -> None:
        self._print_serial_logs = enable

    def log_serial(self, message: str, machine: str) -> None:
        """Queue a serial-console line; it is written on the next drain."""
        if not self._print_serial_logs:
            return

        self.enqueue({"msg": message, "machine": machine, "type": "serial"})

    def enqueue(self, item: Dict[str, str]) -> None:
        self.queue.put(item)

    def drain_log_queue(self) -> None:
        """Write every queued entry as a sanitised ``<line>`` element."""
        while True:
            try:
                item = self.queue.get_nowait()
            except Empty:
                return
            # The remaining keys become the <line> element's attributes.
            msg = self.sanitise(item.pop("msg"))
            self.log_line(msg, item)

    @contextmanager
    def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        """Log the block as a nested section titled ``"subtest: <name>"``."""
        with self.nested("subtest: " + name, attributes):
            yield

    @contextmanager
    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
        """Wrap output produced inside the block in a ``<nest>`` element.

        The ``<nest>`` element is closed even if the block raises, so a
        failing subtest no longer leaves the document malformed.
        """
        self.xml.startElement("nest", attrs=AttributesImpl({}))
        self.xml.startElement("head", attrs=AttributesImpl(attributes))
        self.xml.characters(message)
        self.xml.endElement("head")

        # Monotonic clock: elapsed time is unaffected by wall-clock changes.
        tic = time.monotonic()
        self.drain_log_queue()
        try:
            yield
        finally:
            # Keep the serial output produced in this section inside it, and
            # always balance the <nest> start tag.
            self.drain_log_queue()
            toc = time.monotonic()
            self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)")
            self.xml.endElement("nest")