"""Logging backends for the integration test driver.

Every backend implements :class:`AbstractLogger`:

* :class:`TerminalLogger` prints human-readable, colourised output to stderr.
* :class:`XMLLogger` writes a nested XML log file.
* :class:`JunitXMLLogger` collects output per subtest into a JUnit XML report.
* :class:`CompositeLogger` forwards every call to a list of other loggers.
"""

import atexit
import codecs
import os
import sys
import time
import unicodedata
from abc import ABC, abstractmethod
from collections.abc import Iterator
from contextlib import ExitStack, contextmanager
from pathlib import Path
from queue import Empty, Queue
from typing import Any
from xml.sax.saxutils import XMLGenerator
from xml.sax.xmlreader import AttributesImpl

from colorama import Fore, Style
from junit_xml import TestCase, TestSuite


class AbstractLogger(ABC):
    """Interface shared by every logging backend.

    ``attributes`` carries optional key/value metadata for a message; each
    backend decides how to render it.  The shared ``{}`` defaults are only
    read, never mutated, by the implementations in this module.
    """

    @abstractmethod
    def log(self, message: str, attributes: dict[str, str] = {}) -> None:
        """Record a single log line."""
        pass

    @abstractmethod
    @contextmanager
    def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Context manager grouping everything logged inside it as subtest ``name``."""
        pass

    @abstractmethod
    @contextmanager
    def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Context manager grouping everything logged inside it under ``message``."""
        pass

    @abstractmethod
    def info(self, *args, **kwargs) -> None:  # type: ignore
        """Log an informational message (``args[0]`` is the message)."""
        pass

    @abstractmethod
    def warning(self, *args, **kwargs) -> None:  # type: ignore
        """Log a warning (``args[0]`` is the message)."""
        pass

    @abstractmethod
    def error(self, *args, **kwargs) -> None:  # type: ignore
        """Log an error (``args[0]`` is the message)."""
        pass

    @abstractmethod
    def log_test_error(self, *args, **kwargs) -> None:  # type:ignore
        """Log one line of a test failure report (``args[0]`` is the message)."""
        pass

    @abstractmethod
    def log_serial(self, message: str, machine: str) -> None:
        """Log a line of serial-console output from ``machine``."""
        pass

    @abstractmethod
    def print_serial_logs(self, enable: bool) -> None:
        """Enable or disable recording of serial-console output."""
        pass


class JunitXMLLogger(AbstractLogger):
    """Collect output per subtest and write it as a JUnit XML report.

    Output is buffered in memory and written to ``outfile`` only by
    :meth:`close`, which is registered with :mod:`atexit`.  Anything logged
    outside a subtest goes to a test case named ``"main"``.
    """

    class TestCaseState:
        """Accumulated stdout/stderr text and failure flag for one test case."""

        def __init__(self) -> None:
            self.stdout = ""
            self.stderr = ""
            self.failure = False

    def __init__(self, outfile: Path) -> None:
        self.tests: dict[str, JunitXMLLogger.TestCaseState] = {
            "main": self.TestCaseState()
        }
        # Name of the test case that currently receives output.
        self.currentSubtest = "main"
        self.outfile: Path = outfile
        self._print_serial_logs = True
        atexit.register(self.close)

    def log(self, message: str, attributes: dict[str, str] = {}) -> None:
        """Append ``message`` to the current test case's stdout (attributes ignored)."""
        self.tests[self.currentSubtest].stdout += message + os.linesep

    @contextmanager
    def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Send output logged inside the block to test case ``name``.

        Entering a subtest whose name already exists keeps appending to it.
        """
        old_test = self.currentSubtest
        self.tests.setdefault(name, self.TestCaseState())
        self.currentSubtest = name

        # NOTE(review): there is no try/finally here. If the body raises,
        # currentSubtest keeps pointing at this subtest, so errors logged
        # afterwards (e.g. the failure report) are attributed to it and mark
        # it failed. This looks intentional; confirm before changing.
        yield

        self.currentSubtest = old_test

    @contextmanager
    def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Log ``message``; JUnit reports have no nesting, so nothing else is recorded."""
        self.log(message)
        yield

    def info(self, *args, **kwargs) -> None:  # type: ignore
        """Append ``args[0]`` to the current test case's stdout."""
        self.tests[self.currentSubtest].stdout += args[0] + os.linesep

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        """Append ``args[0]`` to the current test case's stdout."""
        self.tests[self.currentSubtest].stdout += args[0] + os.linesep

    def error(self, *args, **kwargs) -> None:  # type: ignore
        """Append ``args[0]`` to the current test case's stderr and mark it failed."""
        self.tests[self.currentSubtest].stderr += args[0] + os.linesep
        self.tests[self.currentSubtest].failure = True

    def log_test_error(self, *args, **kwargs) -> None:  # type: ignore
        """Same as :meth:`error`: record on stderr and fail the current test case."""
        self.error(*args, **kwargs)

    def log_serial(self, message: str, machine: str) -> None:
        """Record serial output as ``"<machine> # <message>"`` unless disabled."""
        if not self._print_serial_logs:
            return

        self.log(f"{machine} # {message}")

    def print_serial_logs(self, enable: bool) -> None:
        """Enable or disable recording of serial output."""
        self._print_serial_logs = enable

    def close(self) -> None:
        """Write all collected test cases as a JUnit XML report to ``outfile``.

        Test cases flagged by :meth:`error` get a failure entry.
        """
        test_cases = []
        for name, test_case_state in self.tests.items():
            tc = TestCase(
                name,
                stdout=test_case_state.stdout,
                stderr=test_case_state.stderr,
            )
            if test_case_state.failure:
                tc.add_failure_info("test case failed")

            test_cases.append(tc)
        ts = TestSuite("NixOS integration test", test_cases)
        # Render before opening the file, so a rendering error cannot leave
        # behind a truncated (empty) report.
        report = TestSuite.to_xml_string([ts])
        # to_xml_string() is called without an encoding. Write UTF-8 (XML's
        # default) instead of whatever the locale happens to prefer.
        with open(self.outfile, "w", encoding="utf-8") as f:
            f.write(report)


class CompositeLogger(AbstractLogger):
    """Forward every logging call to each logger in ``logger_list``, in order."""

    def __init__(self, logger_list: list[AbstractLogger]) -> None:
        self.logger_list = logger_list

    def add_logger(self, logger: AbstractLogger) -> None:
        """Append ``logger`` to the list of receivers."""
        self.logger_list.append(logger)

    def log(self, message: str, attributes: dict[str, str] = {}) -> None:
        """Forward to every logger's :meth:`log`."""
        for logger in self.logger_list:
            logger.log(message, attributes)

    @contextmanager
    def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Enter every logger's subtest context (exited in reverse order)."""
        with ExitStack() as stack:
            for logger in self.logger_list:
                stack.enter_context(logger.subtest(name, attributes))
            yield

    @contextmanager
    def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Enter every logger's nested context (exited in reverse order)."""
        with ExitStack() as stack:
            for logger in self.logger_list:
                stack.enter_context(logger.nested(message, attributes))
            yield

    def info(self, *args, **kwargs) -> None:  # type: ignore
        """Forward to every logger's :meth:`info`."""
        for logger in self.logger_list:
            logger.info(*args, **kwargs)

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        """Forward to every logger's :meth:`warning`."""
        for logger in self.logger_list:
            logger.warning(*args, **kwargs)

    def log_test_error(self, *args, **kwargs) -> None:  # type: ignore
        """Forward to every logger's :meth:`log_test_error`; does not exit."""
        for logger in self.logger_list:
            logger.log_test_error(*args, **kwargs)

    def error(self, *args, **kwargs) -> None:  # type: ignore
        """Forward to every logger's :meth:`error`, then exit with status 1.

        Raises:
            SystemExit: always, after all loggers have been called.
        """
        for logger in self.logger_list:
            logger.error(*args, **kwargs)
        sys.exit(1)

    def print_serial_logs(self, enable: bool) -> None:
        """Forward to every logger's :meth:`print_serial_logs`."""
        for logger in self.logger_list:
            logger.print_serial_logs(enable)

    def log_serial(self, message: str, machine: str) -> None:
        """Forward to every logger's :meth:`log_serial`."""
        for logger in self.logger_list:
            logger.log_serial(message, machine)


class TerminalLogger(AbstractLogger):
    """Print human-readable log output to stderr, styled with colorama."""

    def __init__(self) -> None:
        self._print_serial_logs = True

    def maybe_prefix(self, message: str, attributes: dict[str, str]) -> str:
        """Prefix ``message`` with ``"<machine>: "`` when a machine attribute is given."""
        if "machine" in attributes:
            return f"{attributes['machine']}: {message}"
        return message

    @staticmethod
    def _eprint(*args: object, **kwargs: Any) -> None:
        """``print`` to stderr."""
        print(*args, file=sys.stderr, **kwargs)

    def log(self, message: str, attributes: dict[str, str] = {}) -> None:
        """Print ``message`` (machine-prefixed if applicable) to stderr."""
        self._eprint(self.maybe_prefix(message, attributes))

    @contextmanager
    def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Same as :meth:`nested` with the message ``"subtest: <name>"``."""
        with self.nested("subtest: " + name, attributes):
            yield

    @contextmanager
    def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Print ``message`` as a bright green header, then the block's elapsed time.

        The "finished" line is printed only if the block exits normally.
        """
        self._eprint(
            self.maybe_prefix(
                Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL, attributes
            )
        )

        tic = time.time()
        yield
        toc = time.time()
        self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)", attributes)

    def info(self, *args, **kwargs) -> None:  # type: ignore
        """Print like :meth:`log`."""
        self.log(*args, **kwargs)

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        """Print like :meth:`log`."""
        self.log(*args, **kwargs)

    def error(self, *args, **kwargs) -> None:  # type: ignore
        """Print like :meth:`log` (this backend does not exit on errors)."""
        self.log(*args, **kwargs)

    def print_serial_logs(self, enable: bool) -> None:
        """Enable or disable printing of serial output."""
        self._print_serial_logs = enable

    def log_serial(self, message: str, machine: str) -> None:
        """Print serial output dimmed as ``"<machine> # <message>"`` unless disabled."""
        if not self._print_serial_logs:
            return

        self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL)

    def log_test_error(self, *args, **kwargs) -> None:  # type: ignore
        """Print ``args[0]`` prefixed with a red ``"!!! "`` marker."""
        prefix = Fore.RED + "!!! " + Style.RESET_ALL
        # NOTE: using `warning` instead of `error` to ensure it does not exit after printing the first log
        self.warning(f"{prefix}{args[0]}", *args[1:], **kwargs)


class XMLLogger(AbstractLogger):
    """Write a nested XML log (``<logfile>``, ``<nest>``, ``<head>``, ``<line>``).

    Serial-console lines are buffered in a :class:`queue.Queue` and written
    out only when :meth:`drain_log_queue` runs (on every :meth:`log` and at
    the start and normal end of every :meth:`nested` block).  Presumably this
    is so that :meth:`log_serial` can be called from another thread; confirm
    against the callers.
    """

    def __init__(self, outfile: str) -> None:
        # NOTE(review): with mode "wb" and no encoding, codecs.open() just
        # returns the builtin open(outfile, "wb"). codecs.open() is
        # deprecated as of Python 3.14, so consider switching to open().
        self.logfile_handle = codecs.open(outfile, "wb")
        self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8")
        self.queue: Queue[dict[str, str]] = Queue()

        self._print_serial_logs = True

        self.xml.startDocument()
        self.xml.startElement("logfile", attrs=AttributesImpl({}))

    def close(self) -> None:
        """Close the root ``<logfile>`` element, end the document, and close the file."""
        self.xml.endElement("logfile")
        self.xml.endDocument()
        self.logfile_handle.close()

    def sanitise(self, message: str) -> str:
        """Drop every character in a Unicode "C*" category (control, format, ...).

        Note that this also removes newlines and tabs.
        """
        return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C")

    def maybe_prefix(self, message: str, attributes: dict[str, str]) -> str:
        """Prefix ``message`` with ``"<machine>: "`` when a machine attribute is given."""
        if "machine" in attributes:
            return f"{attributes['machine']}: {message}"
        return message

    def log_line(self, message: str, attributes: dict[str, str]) -> None:
        """Emit ``<line attributes...>message</line>`` (message is not sanitised)."""
        self.xml.startElement("line", attrs=AttributesImpl(attributes))
        self.xml.characters(message)
        self.xml.endElement("line")

    def info(self, *args, **kwargs) -> None:  # type: ignore
        """Write like :meth:`log`."""
        self.log(*args, **kwargs)

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        """Write like :meth:`log`."""
        self.log(*args, **kwargs)

    def error(self, *args, **kwargs) -> None:  # type: ignore
        """Write like :meth:`log`."""
        self.log(*args, **kwargs)

    def log_test_error(self, *args, **kwargs) -> None:  # type: ignore
        """Write like :meth:`log`."""
        self.log(*args, **kwargs)

    def log(self, message: str, attributes: dict[str, str] = {}) -> None:
        """Flush queued serial lines, then write ``message`` as a ``<line>``."""
        self.drain_log_queue()
        self.log_line(message, attributes)

    def print_serial_logs(self, enable: bool) -> None:
        """Enable or disable recording of serial output."""
        self._print_serial_logs = enable

    def log_serial(self, message: str, machine: str) -> None:
        """Queue a serial line (written later by :meth:`drain_log_queue`) unless disabled."""
        if not self._print_serial_logs:
            return

        self.enqueue({"msg": message, "machine": machine, "type": "serial"})

    def enqueue(self, item: dict[str, str]) -> None:
        """Add ``item`` (must contain a ``"msg"`` key) to the pending queue."""
        self.queue.put(item)

    def drain_log_queue(self) -> None:
        """Write every queued item as a ``<line>`` without blocking.

        The sanitised ``"msg"`` value becomes the text; the remaining keys
        become the element's attributes.
        """
        try:
            while True:
                item = self.queue.get_nowait()
                msg = self.sanitise(item["msg"])
                del item["msg"]
                self.log_line(msg, item)
        except Empty:
            pass

    @contextmanager
    def subtest(self, name: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Same as :meth:`nested` with the message ``"subtest: <name>"``."""
        with self.nested("subtest: " + name, attributes):
            yield

    @contextmanager
    def nested(self, message: str, attributes: dict[str, str] = {}) -> Iterator[None]:
        """Wrap everything logged inside the block in a ``<nest>`` element.

        The element starts with a ``<head>`` holding ``message`` and
        ``attributes``.  On normal exit, queued serial lines are flushed and a
        "finished" line with the elapsed time is written.  ``</nest>`` is
        always emitted, even if the block raises, so the document stays
        well-formed.
        """
        self.xml.startElement("nest", attrs=AttributesImpl({}))
        self.xml.startElement("head", attrs=AttributesImpl(attributes))
        self.xml.characters(message)
        self.xml.endElement("head")

        tic = time.time()
        self.drain_log_queue()
        try:
            yield
            self.drain_log_queue()
            toc = time.time()
            self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)")
        finally:
            # Without this, an exception in the body would leave <nest> open
            # and everything written afterwards (including </logfile>)
            # would produce malformed XML.
            self.xml.endElement("nest")