at 22.05-pre 44 kB view raw
1#! /somewhere/python3 2from contextlib import contextmanager, _GeneratorContextManager 3from queue import Queue, Empty 4from typing import Tuple, Any, Callable, Dict, Iterator, Optional, List, Iterable 5from xml.sax.saxutils import XMLGenerator 6from colorama import Style 7from pathlib import Path 8import queue 9import io 10import threading 11import argparse 12import base64 13import codecs 14import os 15import ptpython.repl 16import pty 17import re 18import shlex 19import shutil 20import socket 21import subprocess 22import sys 23import tempfile 24import time 25import unicodedata 26 27CHAR_TO_KEY = { 28 "A": "shift-a", 29 "N": "shift-n", 30 "-": "0x0C", 31 "_": "shift-0x0C", 32 "B": "shift-b", 33 "O": "shift-o", 34 "=": "0x0D", 35 "+": "shift-0x0D", 36 "C": "shift-c", 37 "P": "shift-p", 38 "[": "0x1A", 39 "{": "shift-0x1A", 40 "D": "shift-d", 41 "Q": "shift-q", 42 "]": "0x1B", 43 "}": "shift-0x1B", 44 "E": "shift-e", 45 "R": "shift-r", 46 ";": "0x27", 47 ":": "shift-0x27", 48 "F": "shift-f", 49 "S": "shift-s", 50 "'": "0x28", 51 '"': "shift-0x28", 52 "G": "shift-g", 53 "T": "shift-t", 54 "`": "0x29", 55 "~": "shift-0x29", 56 "H": "shift-h", 57 "U": "shift-u", 58 "\\": "0x2B", 59 "|": "shift-0x2B", 60 "I": "shift-i", 61 "V": "shift-v", 62 ",": "0x33", 63 "<": "shift-0x33", 64 "J": "shift-j", 65 "W": "shift-w", 66 ".": "0x34", 67 ">": "shift-0x34", 68 "K": "shift-k", 69 "X": "shift-x", 70 "/": "0x35", 71 "?": "shift-0x35", 72 "L": "shift-l", 73 "Y": "shift-y", 74 " ": "spc", 75 "M": "shift-m", 76 "Z": "shift-z", 77 "\n": "ret", 78 "!": "shift-0x02", 79 "@": "shift-0x03", 80 "#": "shift-0x04", 81 "$": "shift-0x05", 82 "%": "shift-0x06", 83 "^": "shift-0x07", 84 "&": "shift-0x08", 85 "*": "shift-0x09", 86 "(": "shift-0x0A", 87 ")": "shift-0x0B", 88} 89 90 91class Logger: 92 def __init__(self) -> None: 93 self.logfile = os.environ.get("LOGFILE", "/dev/null") 94 self.logfile_handle = codecs.open(self.logfile, "wb") 95 self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8") 96 self.queue: "Queue[Dict[str, str]]" = Queue() 97 98 self.xml.startDocument() 99 self.xml.startElement("logfile", attrs={}) 100 101 self._print_serial_logs = True 102 103 @staticmethod 104 def _eprint(*args: object, **kwargs: Any) -> None: 105 print(*args, file=sys.stderr, **kwargs) 106 107 def close(self) -> None: 108 self.xml.endElement("logfile") 109 self.xml.endDocument() 110 self.logfile_handle.close() 111 112 def sanitise(self, message: str) -> str: 113 return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C") 114 115 def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str: 116 if "machine" in attributes: 117 return "{}: {}".format(attributes["machine"], message) 118 return message 119 120 def log_line(self, message: str, attributes: Dict[str, str]) -> None: 121 self.xml.startElement("line", attributes) 122 self.xml.characters(message) 123 self.xml.endElement("line") 124 125 def info(self, *args, **kwargs) -> None: # type: ignore 126 self.log(*args, **kwargs) 127 128 def warning(self, *args, **kwargs) -> None: # type: ignore 129 self.log(*args, **kwargs) 130 131 def error(self, *args, **kwargs) -> None: # type: ignore 132 self.log(*args, **kwargs) 133 sys.exit(1) 134 135 def log(self, message: str, attributes: Dict[str, str] = {}) -> None: 136 self._eprint(self.maybe_prefix(message, attributes)) 137 self.drain_log_queue() 138 self.log_line(message, attributes) 139 140 def log_serial(self, message: str, machine: str) -> None: 141 self.enqueue({"msg": message, "machine": machine, "type": "serial"}) 142 if self._print_serial_logs: 143 self._eprint( 144 Style.DIM + "{} # {}".format(machine, message) + Style.RESET_ALL 145 ) 146 147 def enqueue(self, item: Dict[str, str]) -> None: 148 self.queue.put(item) 149 150 def drain_log_queue(self) -> None: 151 try: 152 while True: 153 item = self.queue.get_nowait() 154 msg = self.sanitise(item["msg"]) 155 del item["msg"] 156 self.log_line(msg, item) 157 except Empty: 158 pass 159 160 @contextmanager 161 def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: 162 self._eprint(self.maybe_prefix(message, attributes)) 163 164 self.xml.startElement("nest", attrs={}) 165 self.xml.startElement("head", attributes) 166 self.xml.characters(message) 167 self.xml.endElement("head") 168 169 tic = time.time() 170 self.drain_log_queue() 171 yield 172 self.drain_log_queue() 173 toc = time.time() 174 self.log("({:.2f} seconds)".format(toc - tic)) 175 176 self.xml.endElement("nest") 177 178 179rootlog = Logger() 180 181 182def make_command(args: list) -> str: 183 return " ".join(map(shlex.quote, (map(str, args)))) 184 185 186def retry(fn: Callable, timeout: int = 900) -> None: 187 """Call the given function repeatedly, with 1 second intervals, 188 until it returns True or a timeout is reached. 189 """ 190 191 for _ in range(timeout): 192 if fn(False): 193 return 194 time.sleep(1) 195 196 if not fn(True): 197 raise Exception(f"action timed out after {timeout} seconds") 198 199 200def _perform_ocr_on_screenshot( 201 screenshot_path: str, model_ids: Iterable[int] 202) -> List[str]: 203 if shutil.which("tesseract") is None: 204 raise Exception("OCR requested but enableOCR is false") 205 206 magick_args = ( 207 "-filter Catrom -density 72 -resample 300 " 208 + "-contrast -normalize -despeckle -type grayscale " 209 + "-sharpen 1 -posterize 3 -negate -gamma 100 " 210 + "-blur 1x65535" 211 ) 212 213 tess_args = f"-c debug_file=/dev/null --psm 11" 214 215 cmd = f"convert {magick_args} {screenshot_path} tiff:{screenshot_path}.tiff" 216 ret = subprocess.run(cmd, shell=True, capture_output=True) 217 if ret.returncode != 0: 218 raise Exception(f"TIFF conversion failed with exit code {ret.returncode}") 219 220 model_results = [] 221 for model_id in model_ids: 222 cmd = f"tesseract {screenshot_path}.tiff - {tess_args} --oem {model_id}" 223 ret = subprocess.run(cmd, shell=True, capture_output=True) 224 if ret.returncode != 0: 225 raise Exception(f"OCR failed with exit code {ret.returncode}") 226 model_results.append(ret.stdout.decode("utf-8")) 227 228 return model_results 229 230 231class StartCommand: 232 """The Base Start Command knows how to append the necesary 233 runtime qemu options as determined by a particular test driver 234 run. Any such start command is expected to happily receive and 235 append additional qemu args. 236 """ 237 238 _cmd: str 239 240 def cmd( 241 self, 242 monitor_socket_path: Path, 243 shell_socket_path: Path, 244 allow_reboot: bool = False, # TODO: unused, legacy? 245 ) -> str: 246 display_opts = "" 247 display_available = any(x in os.environ for x in ["DISPLAY", "WAYLAND_DISPLAY"]) 248 if not display_available: 249 display_opts += " -nographic" 250 251 # qemu options 252 qemu_opts = "" 253 qemu_opts += ( 254 "" 255 if allow_reboot 256 else " -no-reboot" 257 " -device virtio-serial" 258 " -device virtconsole,chardev=shell" 259 " -device virtio-rng-pci" 260 " -serial stdio" 261 ) 262 # TODO: qemu script already catpures this env variable, legacy? 263 qemu_opts += " " + os.environ.get("QEMU_OPTS", "") 264 265 return ( 266 f"{self._cmd}" 267 f" -monitor unix:{monitor_socket_path}" 268 f" -chardev socket,id=shell,path={shell_socket_path}" 269 f"{qemu_opts}" 270 f"{display_opts}" 271 ) 272 273 @staticmethod 274 def build_environment( 275 state_dir: Path, 276 shared_dir: Path, 277 ) -> dict: 278 # We make a copy to not update the current environment 279 env = dict(os.environ) 280 env.update( 281 { 282 "TMPDIR": str(state_dir), 283 "SHARED_DIR": str(shared_dir), 284 "USE_TMPDIR": "1", 285 } 286 ) 287 return env 288 289 def run( 290 self, 291 state_dir: Path, 292 shared_dir: Path, 293 monitor_socket_path: Path, 294 shell_socket_path: Path, 295 ) -> subprocess.Popen: 296 return subprocess.Popen( 297 self.cmd(monitor_socket_path, shell_socket_path), 298 stdin=subprocess.DEVNULL, 299 stdout=subprocess.PIPE, 300 stderr=subprocess.STDOUT, 301 shell=True, 302 cwd=state_dir, 303 env=self.build_environment(state_dir, shared_dir), 304 ) 305 306 307class NixStartScript(StartCommand): 308 """A start script from nixos/modules/virtualiation/qemu-vm.nix 309 that also satisfies the requirement of the BaseStartCommand. 310 These Nix commands have the particular charactersitic that the 311 machine name can be extracted out of them via a regex match. 312 (Admittedly a _very_ implicit contract, evtl. TODO fix) 313 """ 314 315 def __init__(self, script: str): 316 self._cmd = script 317 318 @property 319 def machine_name(self) -> str: 320 match = re.search("run-(.+)-vm$", self._cmd) 321 name = "machine" 322 if match: 323 name = match.group(1) 324 return name 325 326 327class LegacyStartCommand(StartCommand): 328 """Used in some places to create an ad-hoc machine instead of 329 using nix test instrumentation + module system for that purpose. 330 Legacy. 331 """ 332 333 def __init__( 334 self, 335 netBackendArgs: Optional[str] = None, 336 netFrontendArgs: Optional[str] = None, 337 hda: Optional[Tuple[Path, str]] = None, 338 cdrom: Optional[str] = None, 339 usb: Optional[str] = None, 340 bios: Optional[str] = None, 341 qemuFlags: Optional[str] = None, 342 ): 343 self._cmd = "qemu-kvm -m 384" 344 345 # networking 346 net_backend = "-netdev user,id=net0" 347 net_frontend = "-device virtio-net-pci,netdev=net0" 348 if netBackendArgs is not None: 349 net_backend += "," + netBackendArgs 350 if netFrontendArgs is not None: 351 net_frontend += "," + netFrontendArgs 352 self._cmd += f" {net_backend} {net_frontend}" 353 354 # hda 355 hda_cmd = "" 356 if hda is not None: 357 hda_path = hda[0].resolve() 358 hda_interface = hda[1] 359 if hda_interface == "scsi": 360 hda_cmd += ( 361 f" -drive id=hda,file={hda_path},werror=report,if=none" 362 " -device scsi-hd,drive=hda" 363 ) 364 else: 365 hda_cmd += f" -drive file={hda_path},if={hda_interface},werror=report" 366 self._cmd += hda_cmd 367 368 # cdrom 369 if cdrom is not None: 370 self._cmd += f" -cdrom {cdrom}" 371 372 # usb 373 usb_cmd = "" 374 if usb is not None: 375 # https://github.com/qemu/qemu/blob/master/docs/usb2.txt 376 usb_cmd += ( 377 " -device usb-ehci" 378 f" -drive id=usbdisk,file={usb},if=none,readonly" 379 " -device usb-storage,drive=usbdisk " 380 ) 381 self._cmd += usb_cmd 382 383 # bios 384 if bios is not None: 385 self._cmd += f" -bios {bios}" 386 387 # qemu flags 388 if qemuFlags is not None: 389 self._cmd += f" {qemuFlags}" 390 391 392class Machine: 393 """A handle to the machine with this name, that also knows how to manage 394 the machine lifecycle with the help of a start script / command.""" 395 396 name: str 397 tmp_dir: Path 398 shared_dir: Path 399 state_dir: Path 400 monitor_path: Path 401 shell_path: Path 402 403 start_command: StartCommand 404 keep_vm_state: bool 405 allow_reboot: bool 406 407 process: Optional[subprocess.Popen] 408 pid: Optional[int] 409 monitor: Optional[socket.socket] 410 shell: Optional[socket.socket] 411 serial_thread: Optional[threading.Thread] 412 413 booted: bool 414 connected: bool 415 # Store last serial console lines for use 416 # of wait_for_console_text 417 last_lines: Queue = Queue() 418 419 def __repr__(self) -> str: 420 return f"<Machine '{self.name}'>" 421 422 def __init__( 423 self, 424 tmp_dir: Path, 425 start_command: StartCommand, 426 name: str = "machine", 427 keep_vm_state: bool = False, 428 allow_reboot: bool = False, 429 ) -> None: 430 self.tmp_dir = tmp_dir 431 self.keep_vm_state = keep_vm_state 432 self.allow_reboot = allow_reboot 433 self.name = name 434 self.start_command = start_command 435 436 # set up directories 437 self.shared_dir = self.tmp_dir / "shared-xchg" 438 self.shared_dir.mkdir(mode=0o700, exist_ok=True) 439 440 self.state_dir = self.tmp_dir / f"vm-state-{self.name}" 441 self.monitor_path = self.state_dir / "monitor" 442 self.shell_path = self.state_dir / "shell" 443 if (not self.keep_vm_state) and self.state_dir.exists(): 444 self.cleanup_statedir() 445 self.state_dir.mkdir(mode=0o700, exist_ok=True) 446 447 self.process = None 448 self.pid = None 449 self.monitor = None 450 self.shell = None 451 self.serial_thread = None 452 453 self.booted = False 454 self.connected = False 455 456 @staticmethod 457 def create_startcommand(args: Dict[str, str]) -> StartCommand: 458 rootlog.warning( 459 "Using legacy create_startcommand()," 460 "please use proper nix test vm instrumentation, instead" 461 "to generate the appropriate nixos test vm qemu startup script" 462 ) 463 hda = None 464 if args.get("hda"): 465 hda_arg: str = args.get("hda", "") 466 hda_arg_path: Path = Path(hda_arg) 467 hda = (hda_arg_path, args.get("hdaInterface", "")) 468 return LegacyStartCommand( 469 netBackendArgs=args.get("netBackendArgs"), 470 netFrontendArgs=args.get("netFrontendArgs"), 471 hda=hda, 472 cdrom=args.get("cdrom"), 473 usb=args.get("usb"), 474 bios=args.get("bios"), 475 qemuFlags=args.get("qemuFlags"), 476 ) 477 478 def is_up(self) -> bool: 479 return self.booted and self.connected 480 481 def log(self, msg: str) -> None: 482 rootlog.log(msg, {"machine": self.name}) 483 484 def log_serial(self, msg: str) -> None: 485 rootlog.log_serial(msg, self.name) 486 487 def nested(self, msg: str, attrs: Dict[str, str] = {}) -> _GeneratorContextManager: 488 my_attrs = {"machine": self.name} 489 my_attrs.update(attrs) 490 return rootlog.nested(msg, my_attrs) 491 492 def wait_for_monitor_prompt(self) -> str: 493 assert self.monitor is not None 494 answer = "" 495 while True: 496 undecoded_answer = self.monitor.recv(1024) 497 if not undecoded_answer: 498 break 499 answer += undecoded_answer.decode() 500 if answer.endswith("(qemu) "): 501 break 502 return answer 503 504 def send_monitor_command(self, command: str) -> str: 505 message = ("{}\n".format(command)).encode() 506 self.log("sending monitor command: {}".format(command)) 507 assert self.monitor is not None 508 self.monitor.send(message) 509 return self.wait_for_monitor_prompt() 510 511 def wait_for_unit(self, unit: str, user: Optional[str] = None) -> None: 512 """Wait for a systemd unit to get into "active" state. 513 Throws exceptions on "failed" and "inactive" states as well as 514 after timing out. 515 """ 516 517 def check_active(_: Any) -> bool: 518 info = self.get_unit_info(unit, user) 519 state = info["ActiveState"] 520 if state == "failed": 521 raise Exception('unit "{}" reached state "{}"'.format(unit, state)) 522 523 if state == "inactive": 524 status, jobs = self.systemctl("list-jobs --full 2>&1", user) 525 if "No jobs" in jobs: 526 info = self.get_unit_info(unit, user) 527 if info["ActiveState"] == state: 528 raise Exception( 529 ( 530 'unit "{}" is inactive and there ' "are no pending jobs" 531 ).format(unit) 532 ) 533 534 return state == "active" 535 536 retry(check_active) 537 538 def get_unit_info(self, unit: str, user: Optional[str] = None) -> Dict[str, str]: 539 status, lines = self.systemctl('--no-pager show "{}"'.format(unit), user) 540 if status != 0: 541 raise Exception( 542 'retrieving systemctl info for unit "{}" {} failed with exit code {}'.format( 543 unit, "" if user is None else 'under user "{}"'.format(user), status 544 ) 545 ) 546 547 line_pattern = re.compile(r"^([^=]+)=(.*)$") 548 549 def tuple_from_line(line: str) -> Tuple[str, str]: 550 match = line_pattern.match(line) 551 assert match is not None 552 return match[1], match[2] 553 554 return dict( 555 tuple_from_line(line) 556 for line in lines.split("\n") 557 if line_pattern.match(line) 558 ) 559 560 def systemctl(self, q: str, user: Optional[str] = None) -> Tuple[int, str]: 561 if user is not None: 562 q = q.replace("'", "\\'") 563 return self.execute( 564 ( 565 "su -l {} --shell /bin/sh -c " 566 "$'XDG_RUNTIME_DIR=/run/user/`id -u` " 567 "systemctl --user {}'" 568 ).format(user, q) 569 ) 570 return self.execute("systemctl {}".format(q)) 571 572 def require_unit_state(self, unit: str, require_state: str = "active") -> None: 573 with self.nested( 574 "checking if unit ‘{}’ has reached state '{}'".format(unit, require_state) 575 ): 576 info = self.get_unit_info(unit) 577 state = info["ActiveState"] 578 if state != require_state: 579 raise Exception( 580 "Expected unit ‘{}’ to to be in state ".format(unit) 581 + "'{}' but it is in state ‘{}".format(require_state, state) 582 ) 583 584 def _next_newline_closed_block_from_shell(self) -> str: 585 assert self.shell 586 output_buffer = [] 587 while True: 588 # This receives up to 4096 bytes from the socket 589 chunk = self.shell.recv(4096) 590 if not chunk: 591 # Probably a broken pipe, return the output we have 592 break 593 594 decoded = chunk.decode() 595 output_buffer += [decoded] 596 if decoded[-1] == "\n": 597 break 598 return "".join(output_buffer) 599 600 def execute(self, command: str, check_return: bool = True) -> Tuple[int, str]: 601 self.connect() 602 603 out_command = f"( set -euo pipefail; {command} ) | (base64 --wrap 0; echo)\n" 604 assert self.shell 605 self.shell.send(out_command.encode()) 606 607 # Get the output 608 output = base64.b64decode(self._next_newline_closed_block_from_shell()) 609 610 if not check_return: 611 return (-1, output.decode()) 612 613 # Get the return code 614 self.shell.send("echo ${PIPESTATUS[0]}\n".encode()) 615 rc = int(self._next_newline_closed_block_from_shell().strip()) 616 617 return (rc, output.decode()) 618 619 def shell_interact(self) -> None: 620 """Allows you to interact with the guest shell 621 622 Should only be used during test development, not in the production test.""" 623 self.connect() 624 self.log("Terminal is ready (there is no prompt):") 625 626 assert self.shell 627 subprocess.run( 628 ["socat", "READLINE", f"FD:{self.shell.fileno()}"], 629 pass_fds=[self.shell.fileno()], 630 ) 631 632 def succeed(self, *commands: str) -> str: 633 """Execute each command and check that it succeeds.""" 634 output = "" 635 for command in commands: 636 with self.nested("must succeed: {}".format(command)): 637 (status, out) = self.execute(command) 638 if status != 0: 639 self.log("output: {}".format(out)) 640 raise Exception( 641 "command `{}` failed (exit code {})".format(command, status) 642 ) 643 output += out 644 return output 645 646 def fail(self, *commands: str) -> str: 647 """Execute each command and check that it fails.""" 648 output = "" 649 for command in commands: 650 with self.nested("must fail: {}".format(command)): 651 (status, out) = self.execute(command) 652 if status == 0: 653 raise Exception( 654 "command `{}` unexpectedly succeeded".format(command) 655 ) 656 output += out 657 return output 658 659 def wait_until_succeeds(self, command: str, timeout: int = 900) -> str: 660 """Wait until a command returns success and return its output. 661 Throws an exception on timeout. 662 """ 663 output = "" 664 665 def check_success(_: Any) -> bool: 666 nonlocal output 667 status, output = self.execute(command) 668 return status == 0 669 670 with self.nested("waiting for success: {}".format(command)): 671 retry(check_success, timeout) 672 return output 673 674 def wait_until_fails(self, command: str) -> str: 675 """Wait until a command returns failure. 676 Throws an exception on timeout. 677 """ 678 output = "" 679 680 def check_failure(_: Any) -> bool: 681 nonlocal output 682 status, output = self.execute(command) 683 return status != 0 684 685 with self.nested("waiting for failure: {}".format(command)): 686 retry(check_failure) 687 return output 688 689 def wait_for_shutdown(self) -> None: 690 if not self.booted: 691 return 692 693 with self.nested("waiting for the VM to power off"): 694 sys.stdout.flush() 695 assert self.process 696 self.process.wait() 697 698 self.pid = None 699 self.booted = False 700 self.connected = False 701 702 def get_tty_text(self, tty: str) -> str: 703 status, output = self.execute( 704 "fold -w$(stty -F /dev/tty{0} size | " 705 "awk '{{print $2}}') /dev/vcs{0}".format(tty) 706 ) 707 return output 708 709 def wait_until_tty_matches(self, tty: str, regexp: str) -> None: 710 """Wait until the visible output on the chosen TTY matches regular 711 expression. Throws an exception on timeout. 712 """ 713 matcher = re.compile(regexp) 714 715 def tty_matches(last: bool) -> bool: 716 text = self.get_tty_text(tty) 717 if last: 718 self.log( 719 f"Last chance to match /{regexp}/ on TTY{tty}, " 720 f"which currently contains: {text}" 721 ) 722 return len(matcher.findall(text)) > 0 723 724 with self.nested("waiting for {} to appear on tty {}".format(regexp, tty)): 725 retry(tty_matches) 726 727 def send_chars(self, chars: List[str]) -> None: 728 with self.nested("sending keys ‘{}".format(chars)): 729 for char in chars: 730 self.send_key(char) 731 732 def wait_for_file(self, filename: str) -> None: 733 """Waits until the file exists in machine's file system.""" 734 735 def check_file(_: Any) -> bool: 736 status, _ = self.execute("test -e {}".format(filename)) 737 return status == 0 738 739 with self.nested("waiting for file ‘{}".format(filename)): 740 retry(check_file) 741 742 def wait_for_open_port(self, port: int) -> None: 743 def port_is_open(_: Any) -> bool: 744 status, _ = self.execute("nc -z localhost {}".format(port)) 745 return status == 0 746 747 with self.nested("waiting for TCP port {}".format(port)): 748 retry(port_is_open) 749 750 def wait_for_closed_port(self, port: int) -> None: 751 def port_is_closed(_: Any) -> bool: 752 status, _ = self.execute("nc -z localhost {}".format(port)) 753 return status != 0 754 755 retry(port_is_closed) 756 757 def start_job(self, jobname: str, user: Optional[str] = None) -> Tuple[int, str]: 758 return self.systemctl("start {}".format(jobname), user) 759 760 def stop_job(self, jobname: str, user: Optional[str] = None) -> Tuple[int, str]: 761 return self.systemctl("stop {}".format(jobname), user) 762 763 def wait_for_job(self, jobname: str) -> None: 764 self.wait_for_unit(jobname) 765 766 def connect(self) -> None: 767 if self.connected: 768 return 769 770 with self.nested("waiting for the VM to finish booting"): 771 self.start() 772 773 assert self.shell 774 775 tic = time.time() 776 self.shell.recv(1024) 777 # TODO: Timeout 778 toc = time.time() 779 780 self.log("connected to guest root shell") 781 self.log("(connecting took {:.2f} seconds)".format(toc - tic)) 782 self.connected = True 783 784 def screenshot(self, filename: str) -> None: 785 out_dir = os.environ.get("out", os.getcwd()) 786 word_pattern = re.compile(r"^\w+$") 787 if word_pattern.match(filename): 788 filename = os.path.join(out_dir, "{}.png".format(filename)) 789 tmp = "{}.ppm".format(filename) 790 791 with self.nested( 792 "making screenshot {}".format(filename), 793 {"image": os.path.basename(filename)}, 794 ): 795 self.send_monitor_command("screendump {}".format(tmp)) 796 ret = subprocess.run("pnmtopng {} > {}".format(tmp, filename), shell=True) 797 os.unlink(tmp) 798 if ret.returncode != 0: 799 raise Exception("Cannot convert screenshot") 800 801 def copy_from_host_via_shell(self, source: str, target: str) -> None: 802 """Copy a file from the host into the guest by piping it over the 803 shell into the destination file. Works without host-guest shared folder. 804 Prefer copy_from_host for whenever possible. 805 """ 806 with open(source, "rb") as fh: 807 content_b64 = base64.b64encode(fh.read()).decode() 808 self.succeed( 809 f"mkdir -p $(dirname {target})", 810 f"echo -n {content_b64} | base64 -d > {target}", 811 ) 812 813 def copy_from_host(self, source: str, target: str) -> None: 814 """Copy a file from the host into the guest via the `shared_dir` shared 815 among all the VMs (using a temporary directory). 816 """ 817 host_src = Path(source) 818 vm_target = Path(target) 819 with tempfile.TemporaryDirectory(dir=self.shared_dir) as shared_td: 820 shared_temp = Path(shared_td) 821 host_intermediate = shared_temp / host_src.name 822 vm_shared_temp = Path("/tmp/shared") / shared_temp.name 823 vm_intermediate = vm_shared_temp / host_src.name 824 825 self.succeed(make_command(["mkdir", "-p", vm_shared_temp])) 826 if host_src.is_dir(): 827 shutil.copytree(host_src, host_intermediate) 828 else: 829 shutil.copy(host_src, host_intermediate) 830 self.succeed(make_command(["mkdir", "-p", vm_target.parent])) 831 self.succeed(make_command(["cp", "-r", vm_intermediate, vm_target])) 832 833 def copy_from_vm(self, source: str, target_dir: str = "") -> None: 834 """Copy a file from the VM (specified by an in-VM source path) to a path 835 relative to `$out`. The file is copied via the `shared_dir` shared among 836 all the VMs (using a temporary directory). 837 """ 838 # Compute the source, target, and intermediate shared file names 839 out_dir = Path(os.environ.get("out", os.getcwd())) 840 vm_src = Path(source) 841 with tempfile.TemporaryDirectory(dir=self.shared_dir) as shared_td: 842 shared_temp = Path(shared_td) 843 vm_shared_temp = Path("/tmp/shared") / shared_temp.name 844 vm_intermediate = vm_shared_temp / vm_src.name 845 intermediate = shared_temp / vm_src.name 846 # Copy the file to the shared directory inside VM 847 self.succeed(make_command(["mkdir", "-p", vm_shared_temp])) 848 self.succeed(make_command(["cp", "-r", vm_src, vm_intermediate])) 849 abs_target = out_dir / target_dir / vm_src.name 850 abs_target.parent.mkdir(exist_ok=True, parents=True) 851 # Copy the file from the shared directory outside VM 852 if intermediate.is_dir(): 853 shutil.copytree(intermediate, abs_target) 854 else: 855 shutil.copy(intermediate, abs_target) 856 857 def dump_tty_contents(self, tty: str) -> None: 858 """Debugging: Dump the contents of the TTY<n>""" 859 self.execute("fold -w 80 /dev/vcs{} | systemd-cat".format(tty)) 860 861 def _get_screen_text_variants(self, model_ids: Iterable[int]) -> List[str]: 862 with tempfile.TemporaryDirectory() as tmpdir: 863 screenshot_path = os.path.join(tmpdir, "ppm") 864 self.send_monitor_command(f"screendump {screenshot_path}") 865 return _perform_ocr_on_screenshot(screenshot_path, model_ids) 866 867 def get_screen_text_variants(self) -> List[str]: 868 return self._get_screen_text_variants([0, 1, 2]) 869 870 def get_screen_text(self) -> str: 871 return self._get_screen_text_variants([2])[0] 872 873 def wait_for_text(self, regex: str) -> None: 874 def screen_matches(last: bool) -> bool: 875 variants = self.get_screen_text_variants() 876 for text in variants: 877 if re.search(regex, text) is not None: 878 return True 879 880 if last: 881 self.log("Last OCR attempt failed. Text was: {}".format(variants)) 882 883 return False 884 885 with self.nested("waiting for {} to appear on screen".format(regex)): 886 retry(screen_matches) 887 888 def wait_for_console_text(self, regex: str) -> None: 889 self.log("waiting for {} to appear on console".format(regex)) 890 # Buffer the console output, this is needed 891 # to match multiline regexes. 892 console = io.StringIO() 893 while True: 894 try: 895 console.write(self.last_lines.get()) 896 except queue.Empty: 897 self.sleep(1) 898 continue 899 console.seek(0) 900 matches = re.search(regex, console.read()) 901 if matches is not None: 902 return 903 904 def send_key(self, key: str) -> None: 905 key = CHAR_TO_KEY.get(key, key) 906 self.send_monitor_command("sendkey {}".format(key)) 907 908 def start(self) -> None: 909 if self.booted: 910 return 911 912 self.log("starting vm") 913 914 def clear(path: Path) -> Path: 915 if path.exists(): 916 path.unlink() 917 return path 918 919 def create_socket(path: Path) -> socket.socket: 920 s = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM) 921 s.bind(str(path)) 922 s.listen(1) 923 return s 924 925 monitor_socket = create_socket(clear(self.monitor_path)) 926 shell_socket = create_socket(clear(self.shell_path)) 927 self.process = self.start_command.run( 928 self.state_dir, 929 self.shared_dir, 930 self.monitor_path, 931 self.shell_path, 932 ) 933 self.monitor, _ = monitor_socket.accept() 934 self.shell, _ = shell_socket.accept() 935 936 # Store last serial console lines for use 937 # of wait_for_console_text 938 self.last_lines: Queue = Queue() 939 940 def process_serial_output() -> None: 941 assert self.process 942 assert self.process.stdout 943 for _line in self.process.stdout: 944 # Ignore undecodable bytes that may occur in boot menus 945 line = _line.decode(errors="ignore").replace("\r", "").rstrip() 946 self.last_lines.put(line) 947 self.log_serial(line) 948 949 self.serial_thread = threading.Thread(target=process_serial_output) 950 self.serial_thread.start() 951 952 self.wait_for_monitor_prompt() 953 954 self.pid = self.process.pid 955 self.booted = True 956 957 self.log("QEMU running (pid {})".format(self.pid)) 958 959 def cleanup_statedir(self) -> None: 960 shutil.rmtree(self.state_dir) 961 rootlog.log(f"deleting VM state directory {self.state_dir}") 962 rootlog.log("if you want to keep the VM state, pass --keep-vm-state") 963 964 def shutdown(self) -> None: 965 if not self.booted: 966 return 967 968 assert self.shell 969 self.shell.send("poweroff\n".encode()) 970 self.wait_for_shutdown() 971 972 def crash(self) -> None: 973 if not self.booted: 974 return 975 976 self.log("forced crash") 977 self.send_monitor_command("quit") 978 self.wait_for_shutdown() 979 980 def wait_for_x(self) -> None: 981 """Wait until it is possible to connect to the X server. Note that 982 testing the existence of /tmp/.X11-unix/X0 is insufficient. 983 """ 984 985 def check_x(_: Any) -> bool: 986 cmd = ( 987 "journalctl -b SYSLOG_IDENTIFIER=systemd | " 988 + 'grep "Reached target Current graphical"' 989 ) 990 status, _ = self.execute(cmd) 991 if status != 0: 992 return False 993 status, _ = self.execute("[ -e /tmp/.X11-unix/X0 ]") 994 return status == 0 995 996 with self.nested("waiting for the X11 server"): 997 retry(check_x) 998 999 def get_window_names(self) -> List[str]: 1000 return self.succeed( 1001 r"xwininfo -root -tree | sed 's/.*0x[0-9a-f]* \"\([^\"]*\)\".*/\1/; t; d'" 1002 ).splitlines() 1003 1004 def wait_for_window(self, regexp: str) -> None: 1005 pattern = re.compile(regexp) 1006 1007 def window_is_visible(last_try: bool) -> bool: 1008 names = self.get_window_names() 1009 if last_try: 1010 self.log( 1011 "Last chance to match {} on the window list,".format(regexp) 1012 + " which currently contains: " 1013 + ", ".join(names) 1014 ) 1015 return any(pattern.search(name) for name in names) 1016 1017 with self.nested("Waiting for a window to appear"): 1018 retry(window_is_visible) 1019 1020 def sleep(self, secs: int) -> None: 1021 # We want to sleep in *guest* time, not *host* time. 1022 self.succeed(f"sleep {secs}") 1023 1024 def forward_port(self, host_port: int = 8080, guest_port: int = 80) -> None: 1025 """Forward a TCP port on the host to a TCP port on the guest. 1026 Useful during interactive testing. 1027 """ 1028 self.send_monitor_command( 1029 "hostfwd_add tcp::{}-:{}".format(host_port, guest_port) 1030 ) 1031 1032 def block(self) -> None: 1033 """Make the machine unreachable by shutting down eth1 (the multicast 1034 interface used to talk to the other VMs). We keep eth0 up so that 1035 the test driver can continue to talk to the machine. 1036 """ 1037 self.send_monitor_command("set_link virtio-net-pci.1 off") 1038 1039 def unblock(self) -> None: 1040 """Make the machine reachable.""" 1041 self.send_monitor_command("set_link virtio-net-pci.1 on") 1042 1043 def release(self) -> None: 1044 if self.pid is None: 1045 return 1046 rootlog.info(f"kill machine (pid {self.pid})") 1047 assert self.process 1048 assert self.shell 1049 assert self.monitor 1050 assert self.serial_thread 1051 1052 self.process.terminate() 1053 self.shell.close() 1054 self.monitor.close() 1055 self.serial_thread.join() 1056 1057 1058class VLan: 1059 """This class handles a VLAN that the run-vm scripts identify via its 1060 number handles. The network's lifetime equals the object's lifetime. 1061 """ 1062 1063 nr: int 1064 socket_dir: Path 1065 1066 process: subprocess.Popen 1067 pid: int 1068 fd: io.TextIOBase 1069 1070 def __repr__(self) -> str: 1071 return f"<Vlan Nr. {self.nr}>" 1072 1073 def __init__(self, nr: int, tmp_dir: Path): 1074 self.nr = nr 1075 self.socket_dir = tmp_dir / f"vde{self.nr}.ctl" 1076 1077 # TODO: don't side-effect environment here 1078 os.environ[f"QEMU_VDE_SOCKET_{self.nr}"] = str(self.socket_dir) 1079 1080 rootlog.info("start vlan") 1081 pty_master, pty_slave = pty.openpty() 1082 1083 self.process = subprocess.Popen( 1084 ["vde_switch", "-s", self.socket_dir, "--dirmode", "0700"], 1085 stdin=pty_slave, 1086 stdout=subprocess.PIPE, 1087 stderr=subprocess.PIPE, 1088 shell=False, 1089 ) 1090 self.pid = self.process.pid 1091 self.fd = os.fdopen(pty_master, "w") 1092 self.fd.write("version\n") 1093 1094 # TODO: perl version checks if this can be read from 1095 # an if not, dies. we could hang here forever. Fix it. 1096 assert self.process.stdout is not None 1097 self.process.stdout.readline() 1098 if not (self.socket_dir / "ctl").exists(): 1099 rootlog.error("cannot start vde_switch") 1100 1101 rootlog.info(f"running vlan (pid {self.pid})") 1102 1103 def __del__(self) -> None: 1104 rootlog.info(f"kill vlan (pid {self.pid})") 1105 self.fd.close() 1106 self.process.terminate() 1107 1108 1109class Driver: 1110 """A handle to the driver that sets up the environment 1111 and runs the tests""" 1112 1113 tests: str 1114 vlans: List[VLan] 1115 machines: List[Machine] 1116 1117 def __init__( 1118 self, 1119 start_scripts: List[str], 1120 vlans: List[int], 1121 tests: str, 1122 keep_vm_state: bool = False, 1123 ): 1124 self.tests = tests 1125 1126 tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir())) 1127 tmp_dir.mkdir(mode=0o700, exist_ok=True) 1128 1129 with rootlog.nested("start all VLans"): 1130 self.vlans = [VLan(nr, tmp_dir) for nr in vlans] 1131 1132 def cmd(scripts: List[str]) -> Iterator[NixStartScript]: 1133 for s in scripts: 1134 yield NixStartScript(s) 1135 1136 self.machines = [ 1137 Machine( 1138 start_command=cmd, 1139 keep_vm_state=keep_vm_state, 1140 name=cmd.machine_name, 1141 tmp_dir=tmp_dir, 1142 ) 1143 for cmd in cmd(start_scripts) 1144 ] 1145 1146 def __enter__(self) -> "Driver": 1147 return self 1148 1149 def __exit__(self, *_: Any) -> None: 1150 with rootlog.nested("cleanup"): 1151 for machine in self.machines: 1152 machine.release() 1153 1154 def subtest(self, name: str) -> Iterator[None]: 1155 """Group logs under a given test name""" 1156 with rootlog.nested(name): 1157 try: 1158 yield 1159 return True 1160 except Exception as e: 1161 rootlog.error(f'Test "{name}" failed with error: "{e}"') 1162 raise e 1163 1164 def test_symbols(self) -> Dict[str, Any]: 1165 @contextmanager 1166 def subtest(name: str) -> Iterator[None]: 1167 return self.subtest(name) 1168 1169 general_symbols = dict( 1170 start_all=self.start_all, 1171 test_script=self.test_script, 1172 machines=self.machines, 1173 vlans=self.vlans, 1174 driver=self, 1175 log=rootlog, 1176 os=os, 1177 create_machine=self.create_machine, 1178 subtest=subtest, 1179 run_tests=self.run_tests, 1180 join_all=self.join_all, 1181 retry=retry, 1182 serial_stdout_off=self.serial_stdout_off, 1183 serial_stdout_on=self.serial_stdout_on, 1184 Machine=Machine, # for typing 1185 ) 1186 machine_symbols = {m.name: m for m in self.machines} 1187 # If there's exactly one machine, make it available under the name 1188 # "machine", even if it's not called that. 1189 if len(self.machines) == 1: 1190 (machine_symbols["machine"],) = self.machines 1191 vlan_symbols = { 1192 f"vlan{v.nr}": self.vlans[idx] for idx, v in enumerate(self.vlans) 1193 } 1194 print( 1195 "additionally exposed symbols:\n " 1196 + ", ".join(map(lambda m: m.name, self.machines)) 1197 + ",\n " 1198 + ", ".join(map(lambda v: f"vlan{v.nr}", self.vlans)) 1199 + ",\n " 1200 + ", ".join(list(general_symbols.keys())) 1201 ) 1202 return {**general_symbols, **machine_symbols, **vlan_symbols} 1203 1204 def test_script(self) -> None: 1205 """Run the test script""" 1206 with rootlog.nested("run the VM test script"): 1207 symbols = self.test_symbols() # call eagerly 1208 exec(self.tests, symbols, None) 1209 1210 def run_tests(self) -> None: 1211 """Run the test script (for non-interactive test runs)""" 1212 self.test_script() 1213 # TODO: Collect coverage data 1214 for machine in self.machines: 1215 if machine.is_up(): 1216 machine.execute("sync") 1217 1218 def start_all(self) -> None: 1219 """Start all machines""" 1220 with rootlog.nested("start all VMs"): 1221 for machine in self.machines: 1222 machine.start() 1223 1224 def join_all(self) -> None: 1225 """Wait for all machines to shut down""" 1226 with rootlog.nested("wait for all VMs to finish"): 1227 for machine in self.machines: 1228 machine.wait_for_shutdown() 1229 1230 def create_machine(self, args: Dict[str, Any]) -> Machine: 1231 rootlog.warning( 1232 "Using legacy create_machine(), please instantiate the" 1233 "Machine class directly, instead" 1234 ) 1235 tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir())) 1236 tmp_dir.mkdir(mode=0o700, exist_ok=True) 1237 1238 if args.get("startCommand"): 1239 start_command: str = args.get("startCommand", "") 1240 cmd = NixStartScript(start_command) 1241 name = args.get("name", cmd.machine_name) 1242 else: 1243 cmd = Machine.create_startcommand(args) # type: ignore 1244 name = args.get("name", "machine") 1245 1246 return Machine( 1247 tmp_dir=tmp_dir, 1248 start_command=cmd, 1249 name=name, 1250 keep_vm_state=args.get("keep_vm_state", False), 1251 allow_reboot=args.get("allow_reboot", False), 1252 ) 1253 1254 def serial_stdout_on(self) -> None: 1255 rootlog._print_serial_logs = True 1256 1257 def serial_stdout_off(self) -> None: 1258 rootlog._print_serial_logs = False 1259 1260 1261class EnvDefault(argparse.Action): 1262 """An argpars Action that takes values from the specified 1263 environment variable as the flags default value. 1264 """ 1265 1266 def __init__(self, envvar, required=False, default=None, nargs=None, **kwargs): # type: ignore 1267 if not default and envvar: 1268 if envvar in os.environ: 1269 if nargs is not None and (nargs.isdigit() or nargs in ["*", "+"]): 1270 default = os.environ[envvar].split() 1271 else: 1272 default = os.environ[envvar] 1273 kwargs["help"] = ( 1274 kwargs["help"] + f" (default from environment: {default})" 1275 ) 1276 if required and default: 1277 required = False 1278 super(EnvDefault, self).__init__( 1279 default=default, required=required, nargs=nargs, **kwargs 1280 ) 1281 1282 def __call__(self, parser, namespace, values, option_string=None): # type: ignore 1283 setattr(namespace, self.dest, values) 1284 1285 1286if __name__ == "__main__": 1287 arg_parser = argparse.ArgumentParser(prog="nixos-test-driver") 1288 arg_parser.add_argument( 1289 "-K", 1290 "--keep-vm-state", 1291 help="re-use a VM state coming from a previous run", 1292 action="store_true", 1293 ) 1294 arg_parser.add_argument( 1295 "-I", 1296 "--interactive", 1297 help="drop into a python repl and run the tests interactively", 1298 action="store_true", 1299 ) 1300 arg_parser.add_argument( 1301 "--start-scripts", 1302 metavar="START-SCRIPT", 1303 action=EnvDefault, 1304 envvar="startScripts", 1305 nargs="*", 1306 help="start scripts for participating virtual machines", 1307 ) 1308 arg_parser.add_argument( 1309 "--vlans", 1310 metavar="VLAN", 1311 action=EnvDefault, 1312 envvar="vlans", 1313 nargs="*", 1314 help="vlans to span by the driver", 1315 ) 1316 arg_parser.add_argument( 1317 "testscript", 1318 action=EnvDefault, 1319 envvar="testScript", 1320 help="the test script to run", 1321 type=Path, 1322 ) 1323 1324 args = arg_parser.parse_args() 1325 1326 if not args.keep_vm_state: 1327 rootlog.info("Machine state will be reset. To keep it, pass --keep-vm-state") 1328 1329 with Driver( 1330 args.start_scripts, args.vlans, args.testscript.read_text(), args.keep_vm_state 1331 ) as driver: 1332 if args.interactive: 1333 ptpython.repl.embed(driver.test_symbols(), {}) 1334 else: 1335 tic = time.time() 1336 driver.run_tests() 1337 toc = time.time() 1338 rootlog.info(f"test script finished in {(toc-tic):.2f}s")