1#! /somewhere/python3
2from contextlib import contextmanager, _GeneratorContextManager
3from queue import Queue, Empty
4from typing import Tuple, Any, Callable, Dict, Iterator, Optional, List, Iterable
5from xml.sax.saxutils import XMLGenerator
6from colorama import Style
7from pathlib import Path
8import queue
9import io
10import threading
11import argparse
12import base64
13import codecs
14import os
15import ptpython.repl
16import pty
17import re
18import shlex
19import shutil
20import socket
21import subprocess
22import sys
23import tempfile
24import time
25import unicodedata
26
# Translation table from characters to the key names passed to QEMU's
# ``sendkey`` monitor command. ``Machine.send_key`` looks characters up here
# and sends any character that is not listed unchanged.
CHAR_TO_KEY = {
    # Upper-case letters: shift + the lower-case key name.
    **{
        letter: f"shift-{letter.lower()}"
        for letter in map(chr, range(ord("A"), ord("Z") + 1))
    },
    # Whitespace.
    " ": "spc",
    "\n": "ret",
    # Symbols typed as shift + a raw hex key code.
    "!": "shift-0x02",
    "@": "shift-0x03",
    "#": "shift-0x04",
    "$": "shift-0x05",
    "%": "shift-0x06",
    "^": "shift-0x07",
    "&": "shift-0x08",
    "*": "shift-0x09",
    "(": "shift-0x0A",
    ")": "shift-0x0B",
    # Punctuation pairs: unshifted key code, then its shifted variant.
    "-": "0x0C",
    "_": "shift-0x0C",
    "=": "0x0D",
    "+": "shift-0x0D",
    "[": "0x1A",
    "{": "shift-0x1A",
    "]": "0x1B",
    "}": "shift-0x1B",
    ";": "0x27",
    ":": "shift-0x27",
    "'": "0x28",
    '"': "shift-0x28",
    "`": "0x29",
    "~": "shift-0x29",
    "\\": "0x2B",
    "|": "shift-0x2B",
    ",": "0x33",
    "<": "shift-0x33",
    ".": "0x34",
    ">": "shift-0x34",
    "/": "0x35",
    "?": "shift-0x35",
}
89
90
class Logger:
    """Logger that writes every message to stderr and to an XML log file.

    The XML file is named by the ``LOGFILE`` environment variable and
    defaults to ``/dev/null``. Serial-console lines passed to
    :meth:`log_serial` are only queued. They are written to the XML file by
    :meth:`drain_log_queue`, which the other logging entry points and
    :meth:`close` call.
    """

    def __init__(self) -> None:
        self.logfile = os.environ.get("LOGFILE", "/dev/null")
        self.logfile_handle = codecs.open(self.logfile, "wb")
        self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8")
        # Serial lines waiting to be written to the XML log.
        self.queue: "Queue[Dict[str, str]]" = Queue()

        self.xml.startDocument()
        self.xml.startElement("logfile", attrs={})

        # When False, serial output is only recorded in the XML log.
        self._print_serial_logs = True

    @staticmethod
    def _eprint(*args: object, **kwargs: Any) -> None:
        """``print`` to stderr."""
        print(*args, file=sys.stderr, **kwargs)

    def close(self) -> None:
        """Write pending serial lines, finish the XML document, close the file."""
        # Serial output queued after the last log call would otherwise be lost.
        self.drain_log_queue()
        self.xml.endElement("logfile")
        self.xml.endDocument()
        self.logfile_handle.close()

    def sanitise(self, message: str) -> str:
        """Remove every character in Unicode category C (control, format, ...)."""
        return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C")

    def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str:
        """Prefix ``message`` with ``"<machine>: "`` when a machine attribute is set."""
        if "machine" in attributes:
            return "{}: {}".format(attributes["machine"], message)
        return message

    def log_line(self, message: str, attributes: Dict[str, str]) -> None:
        """Write one ``<line>`` element with the given attributes to the XML log."""
        self.xml.startElement("line", attributes)
        self.xml.characters(message)
        self.xml.endElement("line")

    def info(self, *args, **kwargs) -> None:  # type: ignore
        """Alias of :meth:`log`."""
        self.log(*args, **kwargs)

    def warning(self, *args, **kwargs) -> None:  # type: ignore
        """Alias of :meth:`log`."""
        self.log(*args, **kwargs)

    def error(self, *args, **kwargs) -> None:  # type: ignore
        """Log the message, then terminate the process with exit status 1."""
        self.log(*args, **kwargs)
        sys.exit(1)

    def log(self, message: str, attributes: Optional[Dict[str, str]] = None) -> None:
        """Print ``message`` to stderr and append it to the XML log.

        Args:
            message: text to log.
            attributes: XML attributes of the ``<line>`` element. A
                ``machine`` key also prefixes the stderr output.
        """
        attributes = {} if attributes is None else attributes
        self._eprint(self.maybe_prefix(message, attributes))
        # Flush queued serial lines first so the XML log keeps their order.
        self.drain_log_queue()
        self.log_line(message, attributes)

    def log_serial(self, message: str, machine: str) -> None:
        """Queue one serial-console line of ``machine`` for the XML log.

        The line is also echoed (dimmed) to stderr unless
        ``_print_serial_logs`` is False.
        """
        self.enqueue({"msg": message, "machine": machine, "type": "serial"})
        if self._print_serial_logs:
            self._eprint(
                Style.DIM + "{} # {}".format(machine, message) + Style.RESET_ALL
            )

    def enqueue(self, item: Dict[str, str]) -> None:
        """Queue ``item`` (a ``msg`` key plus XML attributes) for the XML log."""
        self.queue.put(item)

    def drain_log_queue(self) -> None:
        """Write every queued item to the XML log as a sanitised ``<line>``."""
        try:
            while True:
                item = self.queue.get_nowait()
                msg = self.sanitise(item["msg"])
                del item["msg"]
                self.log_line(msg, item)
        except Empty:
            pass

    @contextmanager
    def nested(
        self, message: str, attributes: Optional[Dict[str, str]] = None
    ) -> Iterator[None]:
        """Context manager that groups log output in a ``<nest>`` element.

        ``message`` becomes the section's ``<head>``. When the section ends,
        its duration is logged and the ``<nest>`` element is closed. Both
        happen even if the body raises, and the exception then propagates.
        """
        attributes = {} if attributes is None else attributes
        self._eprint(self.maybe_prefix(message, attributes))

        self.xml.startElement("nest", attrs={})
        self.xml.startElement("head", attributes)
        self.xml.characters(message)
        self.xml.endElement("head")

        tic = time.monotonic()
        self.drain_log_queue()
        try:
            yield
        finally:
            # Always close <nest>. An exception escaping the body would
            # otherwise leave the XML document unbalanced (malformed).
            self.drain_log_queue()
            toc = time.monotonic()
            self.log("({:.2f} seconds)".format(toc - tic))
            self.xml.endElement("nest")
177
178
# Process-wide logger used by Machine, VLan and Driver. Creating it at import
# time opens the XML log file named by $LOGFILE (or /dev/null).
rootlog = Logger()
180
181
def make_command(args: list) -> str:
    """Join ``args`` into a single shell command line.

    Each element is converted with ``str()`` and quoted with ``shlex.quote``,
    so the shell sees it as exactly one word (paths with spaces included).
    """
    quoted_words = [shlex.quote(str(word)) for word in args]
    return " ".join(quoted_words)
184
185
def retry(fn: Callable, timeout: int = 900) -> None:
    """Poll ``fn`` roughly once per second until it reports success.

    ``fn`` receives a single boolean: ``False`` for each of up to ``timeout``
    regular attempts (each failed attempt is followed by a one-second sleep),
    then ``True`` for one last attempt, so it can log diagnostics.

    Raises:
        Exception: when the final attempt also returns a falsy value.
    """
    for _attempt in range(timeout):
        if fn(False):
            return
        time.sleep(1)

    if fn(True):
        return
    raise Exception(f"action timed out after {timeout} seconds")
198
199
200def _perform_ocr_on_screenshot(
201 screenshot_path: str, model_ids: Iterable[int]
202) -> List[str]:
203 if shutil.which("tesseract") is None:
204 raise Exception("OCR requested but enableOCR is false")
205
206 magick_args = (
207 "-filter Catrom -density 72 -resample 300 "
208 + "-contrast -normalize -despeckle -type grayscale "
209 + "-sharpen 1 -posterize 3 -negate -gamma 100 "
210 + "-blur 1x65535"
211 )
212
213 tess_args = f"-c debug_file=/dev/null --psm 11"
214
215 cmd = f"convert {magick_args} {screenshot_path} tiff:{screenshot_path}.tiff"
216 ret = subprocess.run(cmd, shell=True, capture_output=True)
217 if ret.returncode != 0:
218 raise Exception(f"TIFF conversion failed with exit code {ret.returncode}")
219
220 model_results = []
221 for model_id in model_ids:
222 cmd = f"tesseract {screenshot_path}.tiff - {tess_args} --oem {model_id}"
223 ret = subprocess.run(cmd, shell=True, capture_output=True)
224 if ret.returncode != 0:
225 raise Exception(f"OCR failed with exit code {ret.returncode}")
226 model_results.append(ret.stdout.decode("utf-8"))
227
228 return model_results
229
230
class StartCommand:
    """Base class for commands that start a QEMU VM for the test driver.

    Subclasses provide the base command line in ``_cmd``. :meth:`cmd` appends
    the runtime options the driver needs (monitor and shell sockets, serial
    console, ...) and the contents of the ``QEMU_OPTS`` environment variable.
    Any such start command is expected to accept the appended qemu arguments.
    """

    _cmd: str

    def cmd(
        self,
        monitor_socket_path: Path,
        shell_socket_path: Path,
        allow_reboot: bool = False,
    ) -> str:
        """Return the complete shell command line that starts the VM.

        Args:
            monitor_socket_path: path passed to ``-monitor unix:``.
            shell_socket_path: path of the socket behind the ``shell`` chardev.
            allow_reboot: when False, ``-no-reboot`` is added.
        """
        display_available = any(x in os.environ for x in ["DISPLAY", "WAYLAND_DISPLAY"])
        display_opts = "" if display_available else " -nographic"

        # Only -no-reboot depends on allow_reboot. The device options must stay
        # outside the conditional expression. Adjacent string literals are
        # concatenated first, so inside it they would become part of the
        # "else" branch and be dropped whenever reboots are allowed.
        qemu_opts = "" if allow_reboot else " -no-reboot"
        qemu_opts += (
            " -device virtio-serial"
            " -device virtconsole,chardev=shell"
            " -device virtio-rng-pci"
            " -serial stdio"
        )
        # TODO: the qemu start script already captures this env variable, legacy?
        qemu_opts += " " + os.environ.get("QEMU_OPTS", "")

        return (
            f"{self._cmd}"
            f" -monitor unix:{monitor_socket_path}"
            f" -chardev socket,id=shell,path={shell_socket_path}"
            f"{qemu_opts}"
            f"{display_opts}"
        )

    @staticmethod
    def build_environment(
        state_dir: Path,
        shared_dir: Path,
    ) -> dict:
        """Return a copy of ``os.environ`` with the VM-specific variables set.

        The current process environment is not modified.
        """
        env = dict(os.environ)
        env.update(
            {
                "TMPDIR": str(state_dir),
                "SHARED_DIR": str(shared_dir),
                "USE_TMPDIR": "1",
            }
        )
        return env

    def run(
        self,
        state_dir: Path,
        shared_dir: Path,
        monitor_socket_path: Path,
        shell_socket_path: Path,
        allow_reboot: bool = False,
    ) -> subprocess.Popen:
        """Start the VM via a shell and return the process handle.

        stdout and stderr are merged into one pipe. The working directory is
        ``state_dir``, and the environment comes from
        :meth:`build_environment`.
        """
        return subprocess.Popen(
            self.cmd(monitor_socket_path, shell_socket_path, allow_reboot),
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            shell=True,
            cwd=state_dir,
            env=self.build_environment(state_dir, shared_dir),
        )
305
306
class NixStartScript(StartCommand):
    """A start script generated by nixos/modules/virtualisation/qemu-vm.nix.

    Besides satisfying the :class:`StartCommand` contract, the script's file
    name encodes the machine name as ``run-<name>-vm``. This is a very
    implicit contract (TODO: make it explicit).
    """

    def __init__(self, script: str):
        self._cmd = script

    @property
    def machine_name(self) -> str:
        """Machine name parsed from ``run-<name>-vm``, else ``"machine"``."""
        # Match only within the last path component. With ``.+`` a "run-" in
        # a parent directory would pull path segments into the machine name.
        match = re.search(r"run-([^/]+)-vm$", self._cmd)
        return match.group(1) if match else "machine"
325
326
class LegacyStartCommand(StartCommand):
    """Ad-hoc ``qemu-kvm`` command line for machines that are not defined
    through the nix test instrumentation and module system. Legacy.
    """

    def __init__(
        self,
        netBackendArgs: Optional[str] = None,
        netFrontendArgs: Optional[str] = None,
        hda: Optional[Tuple[Path, str]] = None,
        cdrom: Optional[str] = None,
        usb: Optional[str] = None,
        bios: Optional[str] = None,
        qemuFlags: Optional[str] = None,
    ):
        """Assemble the command line in ``self._cmd``.

        Args:
            netBackendArgs: extra options appended to ``-netdev user,id=net0``.
            netFrontendArgs: extra options appended to the virtio-net device.
            hda: ``(image path, interface)``. Interface ``"scsi"`` attaches
                the image as a SCSI disk; any other value is used as ``if=``.
            cdrom: image passed to ``-cdrom``.
            usb: image attached read-only as a USB storage device.
            bios: firmware passed to ``-bios``.
            qemuFlags: raw flags appended at the end.
        """
        self._cmd = "qemu-kvm -m 384"

        # networking
        net_backend = "-netdev user,id=net0"
        net_frontend = "-device virtio-net-pci,netdev=net0"
        if netBackendArgs is not None:
            net_backend += "," + netBackendArgs
        if netFrontendArgs is not None:
            net_frontend += "," + netFrontendArgs
        self._cmd += f" {net_backend} {net_frontend}"

        # hda
        if hda is not None:
            hda_path = hda[0].resolve()
            hda_interface = hda[1]
            if hda_interface == "scsi":
                self._cmd += (
                    f" -drive id=hda,file={hda_path},werror=report,if=none"
                    " -device scsi-hd,drive=hda"
                )
            else:
                self._cmd += f" -drive file={hda_path},if={hda_interface},werror=report"

        # cdrom
        if cdrom is not None:
            self._cmd += f" -cdrom {cdrom}"

        # usb
        if usb is not None:
            # https://github.com/qemu/qemu/blob/master/docs/usb2.txt
            # Use the explicit "readonly=on". The bare "readonly" flag is a
            # deprecated short-form boolean option in QEMU.
            self._cmd += (
                " -device usb-ehci"
                f" -drive id=usbdisk,file={usb},if=none,readonly=on"
                " -device usb-storage,drive=usbdisk"
            )

        # bios
        if bios is not None:
            self._cmd += f" -bios {bios}"

        # qemu flags
        if qemuFlags is not None:
            self._cmd += f" {qemuFlags}"
390
391
class Machine:
    """A handle to the machine with this name, that also knows how to manage
    the machine lifecycle with the help of a start script / command.

    Besides start/connect/shutdown, it provides the helpers test scripts use:
    running shell commands in the guest, waiting for units, ports, files and
    text, sending keys, taking screenshots, and copying files through the
    shared directory.
    """

    name: str
    tmp_dir: Path
    shared_dir: Path
    state_dir: Path
    monitor_path: Path
    shell_path: Path

    start_command: StartCommand
    keep_vm_state: bool
    allow_reboot: bool

    process: Optional[subprocess.Popen]
    pid: Optional[int]
    monitor: Optional[socket.socket]
    shell: Optional[socket.socket]
    serial_thread: Optional[threading.Thread]

    booted: bool
    connected: bool
    # Serial console lines for wait_for_console_text. Every instance gets its
    # own queue (created in __init__ and replaced in start) rather than
    # sharing one class-level queue.
    last_lines: Queue

    def __repr__(self) -> str:
        return f"<Machine '{self.name}'>"

    def __init__(
        self,
        tmp_dir: Path,
        start_command: StartCommand,
        name: str = "machine",
        keep_vm_state: bool = False,
        allow_reboot: bool = False,
    ) -> None:
        """Create the handle and its directories. The VM is not started.

        ``<tmp_dir>/shared-xchg`` becomes ``shared_dir``.
        ``<tmp_dir>/vm-state-<name>`` becomes ``state_dir``; an existing one
        is deleted first unless ``keep_vm_state`` is set.
        """
        self.tmp_dir = tmp_dir
        self.keep_vm_state = keep_vm_state
        # NOTE(review): allow_reboot is stored but start() does not pass it to
        # the start command.
        self.allow_reboot = allow_reboot
        self.name = name
        self.start_command = start_command

        # set up directories
        self.shared_dir = self.tmp_dir / "shared-xchg"
        self.shared_dir.mkdir(mode=0o700, exist_ok=True)

        self.state_dir = self.tmp_dir / f"vm-state-{self.name}"
        self.monitor_path = self.state_dir / "monitor"
        self.shell_path = self.state_dir / "shell"
        if (not self.keep_vm_state) and self.state_dir.exists():
            self.cleanup_statedir()
        self.state_dir.mkdir(mode=0o700, exist_ok=True)

        self.process = None
        self.pid = None
        self.monitor = None
        self.shell = None
        self.serial_thread = None
        self.last_lines = Queue()

        self.booted = False
        self.connected = False

    @staticmethod
    def create_startcommand(args: Dict[str, str]) -> StartCommand:
        """Build a LegacyStartCommand from a dict of legacy option names.

        Deprecated (logs a warning). Keys read: ``hda``, ``hdaInterface``,
        ``netBackendArgs``, ``netFrontendArgs``, ``cdrom``, ``usb``, ``bios``
        and ``qemuFlags``.
        """
        rootlog.warning(
            "Using legacy create_startcommand(), "
            "please use proper nix test vm instrumentation, instead "
            "to generate the appropriate nixos test vm qemu startup script"
        )
        hda = None
        if args.get("hda"):
            hda = (Path(args["hda"]), args.get("hdaInterface", ""))
        return LegacyStartCommand(
            netBackendArgs=args.get("netBackendArgs"),
            netFrontendArgs=args.get("netFrontendArgs"),
            hda=hda,
            cdrom=args.get("cdrom"),
            usb=args.get("usb"),
            bios=args.get("bios"),
            qemuFlags=args.get("qemuFlags"),
        )

    def is_up(self) -> bool:
        """True once the VM is booted and its root shell is connected."""
        return self.booted and self.connected

    def log(self, msg: str) -> None:
        """Log ``msg`` through the root logger, tagged with this machine."""
        rootlog.log(msg, {"machine": self.name})

    def log_serial(self, msg: str) -> None:
        """Record one serial-console line of this machine."""
        rootlog.log_serial(msg, self.name)

    def nested(
        self, msg: str, attrs: Optional[Dict[str, str]] = None
    ) -> _GeneratorContextManager:
        """Open a nested log section tagged with this machine's name."""
        my_attrs = {"machine": self.name}
        if attrs:
            my_attrs.update(attrs)
        return rootlog.nested(msg, my_attrs)

    def wait_for_monitor_prompt(self) -> str:
        """Read from the QEMU monitor until the ``(qemu) `` prompt or EOF.

        Returns everything read, decoded as UTF-8.
        """
        assert self.monitor is not None
        # Accumulate raw bytes and decode once, so a multi-byte character
        # split across two recv() calls cannot cause a decode error.
        answer = bytearray()
        while True:
            chunk = self.monitor.recv(1024)
            if not chunk:
                break
            answer += chunk
            if answer.endswith(b"(qemu) "):
                break
        return answer.decode()

    def send_monitor_command(self, command: str) -> str:
        """Send ``command`` to the QEMU monitor and return its reply."""
        message = ("{}\n".format(command)).encode()
        self.log("sending monitor command: {}".format(command))
        assert self.monitor is not None
        self.monitor.sendall(message)
        return self.wait_for_monitor_prompt()

    def wait_for_unit(self, unit: str, user: Optional[str] = None) -> None:
        """Wait for a systemd unit to get into "active" state.
        Throws exceptions on "failed" and "inactive" states as well as
        after timing out.
        """

        def check_active(_: Any) -> bool:
            info = self.get_unit_info(unit, user)
            state = info["ActiveState"]
            if state == "failed":
                raise Exception('unit "{}" reached state "{}"'.format(unit, state))

            if state == "inactive":
                # Inactive with no queued jobs: fail now instead of waiting
                # for the timeout.
                status, jobs = self.systemctl("list-jobs --full 2>&1", user)
                if "No jobs" in jobs:
                    info = self.get_unit_info(unit, user)
                    if info["ActiveState"] == state:
                        raise Exception(
                            (
                                'unit "{}" is inactive and there ' "are no pending jobs"
                            ).format(unit)
                        )

            return state == "active"

        retry(check_active)

    def get_unit_info(self, unit: str, user: Optional[str] = None) -> Dict[str, str]:
        """Return the ``systemctl show`` properties of ``unit`` as a dict.

        Raises:
            Exception: if ``systemctl show`` exits with a non-zero status.
        """
        status, lines = self.systemctl('--no-pager show "{}"'.format(unit), user)
        if status != 0:
            raise Exception(
                'retrieving systemctl info for unit "{}" {} failed with exit code {}'.format(
                    unit, "" if user is None else 'under user "{}"'.format(user), status
                )
            )

        line_pattern = re.compile(r"^([^=]+)=(.*)$")
        matches = (line_pattern.match(line) for line in lines.split("\n"))
        return {match[1]: match[2] for match in matches if match}

    def systemctl(self, q: str, user: Optional[str] = None) -> Tuple[int, str]:
        """Run ``systemctl <q>`` in the guest.

        With ``user`` set, the command runs as ``systemctl --user`` in a login
        shell of that user, with ``XDG_RUNTIME_DIR`` pointing at their
        runtime directory.
        """
        if user is not None:
            q = q.replace("'", "\\'")
            return self.execute(
                (
                    "su -l {} --shell /bin/sh -c "
                    "$'XDG_RUNTIME_DIR=/run/user/`id -u` "
                    "systemctl --user {}'"
                ).format(user, q)
            )
        return self.execute("systemctl {}".format(q))

    def require_unit_state(self, unit: str, require_state: str = "active") -> None:
        """Raise unless ``unit``'s ActiveState is currently ``require_state``."""
        with self.nested(
            "checking if unit ‘{}’ has reached state '{}'".format(unit, require_state)
        ):
            info = self.get_unit_info(unit)
            state = info["ActiveState"]
            if state != require_state:
                raise Exception(
                    "Expected unit ‘{}’ to be in state ".format(unit)
                    + "'{}' but it is in state ‘{}’".format(require_state, state)
                )

    def _next_newline_closed_block_from_shell(self) -> str:
        """Read from the shell until a chunk ends with a newline, or EOF."""
        assert self.shell
        output_buffer = []
        while True:
            # This receives up to 4096 bytes from the socket
            chunk = self.shell.recv(4096)
            if not chunk:
                # Probably a broken pipe, return the output we have
                break

            decoded = chunk.decode()
            output_buffer.append(decoded)
            if decoded[-1] == "\n":
                break
        return "".join(output_buffer)

    def execute(self, command: str, check_return: bool = True) -> Tuple[int, str]:
        """Run ``command`` in the guest root shell and return ``(status, output)``.

        Connects first if needed. The command runs in a subshell with
        ``set -euo pipefail``, and its stdout is base64-encoded in the guest
        for transport. With ``check_return=False`` the exit status is not
        queried and ``-1`` is returned in its place.
        """
        self.connect()

        out_command = f"( set -euo pipefail; {command} ) | (base64 --wrap 0; echo)\n"
        assert self.shell
        # sendall: a plain send() may transmit only part of a long command
        # (e.g. the base64 payload of copy_from_host_via_shell).
        self.shell.sendall(out_command.encode())

        # Get the output
        output = base64.b64decode(self._next_newline_closed_block_from_shell())

        if not check_return:
            return (-1, output.decode())

        # Get the return code
        self.shell.sendall("echo ${PIPESTATUS[0]}\n".encode())
        rc = int(self._next_newline_closed_block_from_shell().strip())

        return (rc, output.decode())

    def shell_interact(self) -> None:
        """Allows you to interact with the guest shell

        Should only be used during test development, not in the production test."""
        self.connect()
        self.log("Terminal is ready (there is no prompt):")

        assert self.shell
        subprocess.run(
            ["socat", "READLINE", f"FD:{self.shell.fileno()}"],
            pass_fds=[self.shell.fileno()],
        )

    def succeed(self, *commands: str) -> str:
        """Execute each command and check that it succeeds.

        Returns the concatenated output. Raises on the first non-zero status.
        """
        output = ""
        for command in commands:
            with self.nested("must succeed: {}".format(command)):
                (status, out) = self.execute(command)
                if status != 0:
                    self.log("output: {}".format(out))
                    raise Exception(
                        "command `{}` failed (exit code {})".format(command, status)
                    )
                output += out
        return output

    def fail(self, *commands: str) -> str:
        """Execute each command and check that it fails.

        Returns the concatenated output. Raises on the first zero status.
        """
        output = ""
        for command in commands:
            with self.nested("must fail: {}".format(command)):
                (status, out) = self.execute(command)
                if status == 0:
                    raise Exception(
                        "command `{}` unexpectedly succeeded".format(command)
                    )
                output += out
        return output

    def wait_until_succeeds(self, command: str, timeout: int = 900) -> str:
        """Wait until a command returns success and return its output.
        Throws an exception on timeout.
        """
        output = ""

        def check_success(_: Any) -> bool:
            nonlocal output
            status, output = self.execute(command)
            return status == 0

        with self.nested("waiting for success: {}".format(command)):
            retry(check_success, timeout)
            return output

    def wait_until_fails(self, command: str, timeout: int = 900) -> str:
        """Wait until a command returns failure and return its output.
        Throws an exception on timeout (``timeout`` attempts, about 1s apart).
        """
        output = ""

        def check_failure(_: Any) -> bool:
            nonlocal output
            status, output = self.execute(command)
            return status != 0

        with self.nested("waiting for failure: {}".format(command)):
            retry(check_failure, timeout)
            return output

    def wait_for_shutdown(self) -> None:
        """Block until the QEMU process exits, then mark the VM as down."""
        if not self.booted:
            return

        with self.nested("waiting for the VM to power off"):
            sys.stdout.flush()
            assert self.process
            self.process.wait()

            self.pid = None
            self.booted = False
            self.connected = False

    def get_tty_text(self, tty: str) -> str:
        """Return the text shown on virtual console ``tty``, wrapped at its width."""
        status, output = self.execute(
            "fold -w$(stty -F /dev/tty{0} size | "
            "awk '{{print $2}}') /dev/vcs{0}".format(tty)
        )
        return output

    def wait_until_tty_matches(self, tty: str, regexp: str) -> None:
        """Wait until the visible output on the chosen TTY matches regular
        expression. Throws an exception on timeout.
        """
        matcher = re.compile(regexp)

        def tty_matches(last: bool) -> bool:
            text = self.get_tty_text(tty)
            if last:
                self.log(
                    f"Last chance to match /{regexp}/ on TTY{tty}, "
                    f"which currently contains: {text}"
                )
            return len(matcher.findall(text)) > 0

        with self.nested("waiting for {} to appear on tty {}".format(regexp, tty)):
            retry(tty_matches)

    def send_chars(self, chars: List[str]) -> None:
        """Type each character of ``chars`` with :meth:`send_key`."""
        with self.nested("sending keys ‘{}’".format(chars)):
            for char in chars:
                self.send_key(char)

    def wait_for_file(self, filename: str) -> None:
        """Waits until the file exists in machine's file system."""

        def check_file(_: Any) -> bool:
            status, _ = self.execute("test -e {}".format(filename))
            return status == 0

        with self.nested("waiting for file ‘{}’".format(filename)):
            retry(check_file)

    def wait_for_open_port(self, port: int) -> None:
        """Wait until something in the guest listens on TCP ``port`` (localhost)."""

        def port_is_open(_: Any) -> bool:
            status, _ = self.execute("nc -z localhost {}".format(port))
            return status == 0

        with self.nested("waiting for TCP port {}".format(port)):
            retry(port_is_open)

    def wait_for_closed_port(self, port: int) -> None:
        """Wait until nothing in the guest listens on TCP ``port`` (localhost)."""

        def port_is_closed(_: Any) -> bool:
            status, _ = self.execute("nc -z localhost {}".format(port))
            return status != 0

        with self.nested("waiting for TCP port {} to be closed".format(port)):
            retry(port_is_closed)

    def start_job(self, jobname: str, user: Optional[str] = None) -> Tuple[int, str]:
        """``systemctl start <jobname>``; returns ``(status, output)``."""
        return self.systemctl("start {}".format(jobname), user)

    def stop_job(self, jobname: str, user: Optional[str] = None) -> Tuple[int, str]:
        """``systemctl stop <jobname>``; returns ``(status, output)``."""
        return self.systemctl("stop {}".format(jobname), user)

    def wait_for_job(self, jobname: str) -> None:
        """Alias of :meth:`wait_for_unit`."""
        self.wait_for_unit(jobname)

    def connect(self) -> None:
        """Start the VM if needed and wait for the guest root shell."""
        if self.connected:
            return

        with self.nested("waiting for the VM to finish booting"):
            self.start()

            assert self.shell

            tic = time.time()
            # Block until the guest shell sends its first data.
            self.shell.recv(1024)
            # TODO: Timeout
            toc = time.time()

            self.log("connected to guest root shell")
            self.log("(connecting took {:.2f} seconds)".format(toc - tic))
            self.connected = True

    def screenshot(self, filename: str) -> None:
        """Save a PNG screenshot of the VM display.

        A bare word ``filename`` is saved as ``$out/<filename>.png``, falling
        back to the current directory when ``$out`` is unset.
        """
        out_dir = os.environ.get("out", os.getcwd())
        word_pattern = re.compile(r"^\w+$")
        if word_pattern.match(filename):
            filename = os.path.join(out_dir, "{}.png".format(filename))
        tmp = "{}.ppm".format(filename)

        with self.nested(
            "making screenshot {}".format(filename),
            {"image": os.path.basename(filename)},
        ):
            self.send_monitor_command("screendump {}".format(tmp))
            ret = subprocess.run("pnmtopng {} > {}".format(tmp, filename), shell=True)
            os.unlink(tmp)
            if ret.returncode != 0:
                raise Exception("Cannot convert screenshot")

    def copy_from_host_via_shell(self, source: str, target: str) -> None:
        """Copy a file from the host into the guest by piping it over the
        shell into the destination file. Works without host-guest shared folder.
        Prefer copy_from_host for whenever possible.
        """
        with open(source, "rb") as fh:
            content_b64 = base64.b64encode(fh.read()).decode()
            self.succeed(
                f"mkdir -p $(dirname {target})",
                f"echo -n {content_b64} | base64 -d > {target}",
            )

    def copy_from_host(self, source: str, target: str) -> None:
        """Copy a file from the host into the guest via the `shared_dir` shared
        among all the VMs (using a temporary directory).
        """
        host_src = Path(source)
        vm_target = Path(target)
        with tempfile.TemporaryDirectory(dir=self.shared_dir) as shared_td:
            shared_temp = Path(shared_td)
            host_intermediate = shared_temp / host_src.name
            vm_shared_temp = Path("/tmp/shared") / shared_temp.name
            vm_intermediate = vm_shared_temp / host_src.name

            self.succeed(make_command(["mkdir", "-p", vm_shared_temp]))
            if host_src.is_dir():
                shutil.copytree(host_src, host_intermediate)
            else:
                shutil.copy(host_src, host_intermediate)
            self.succeed(make_command(["mkdir", "-p", vm_target.parent]))
            self.succeed(make_command(["cp", "-r", vm_intermediate, vm_target]))

    def copy_from_vm(self, source: str, target_dir: str = "") -> None:
        """Copy a file from the VM (specified by an in-VM source path) to a path
        relative to `$out`. The file is copied via the `shared_dir` shared among
        all the VMs (using a temporary directory).
        """
        # Compute the source, target, and intermediate shared file names
        out_dir = Path(os.environ.get("out", os.getcwd()))
        vm_src = Path(source)
        with tempfile.TemporaryDirectory(dir=self.shared_dir) as shared_td:
            shared_temp = Path(shared_td)
            vm_shared_temp = Path("/tmp/shared") / shared_temp.name
            vm_intermediate = vm_shared_temp / vm_src.name
            intermediate = shared_temp / vm_src.name
            # Copy the file to the shared directory inside VM
            self.succeed(make_command(["mkdir", "-p", vm_shared_temp]))
            self.succeed(make_command(["cp", "-r", vm_src, vm_intermediate]))
            abs_target = out_dir / target_dir / vm_src.name
            abs_target.parent.mkdir(exist_ok=True, parents=True)
            # Copy the file from the shared directory outside VM
            if intermediate.is_dir():
                shutil.copytree(intermediate, abs_target)
            else:
                shutil.copy(intermediate, abs_target)

    def dump_tty_contents(self, tty: str) -> None:
        """Debugging: Dump the contents of the TTY<n>"""
        self.execute("fold -w 80 /dev/vcs{} | systemd-cat".format(tty))

    def _get_screen_text_variants(self, model_ids: Iterable[int]) -> List[str]:
        """Screendump into a temp dir and OCR it once per model id."""
        with tempfile.TemporaryDirectory() as tmpdir:
            screenshot_path = os.path.join(tmpdir, "ppm")
            self.send_monitor_command(f"screendump {screenshot_path}")
            return _perform_ocr_on_screenshot(screenshot_path, model_ids)

    def get_screen_text_variants(self) -> List[str]:
        """OCR the screen with models 0, 1 and 2; one text per model."""
        return self._get_screen_text_variants([0, 1, 2])

    def get_screen_text(self) -> str:
        """OCR the screen with model 2 and return the text."""
        return self._get_screen_text_variants([2])[0]

    def wait_for_text(self, regex: str) -> None:
        """Wait until any OCR variant of the screen matches ``regex``."""

        def screen_matches(last: bool) -> bool:
            variants = self.get_screen_text_variants()
            for text in variants:
                if re.search(regex, text) is not None:
                    return True

            if last:
                self.log("Last OCR attempt failed. Text was: {}".format(variants))

            return False

        with self.nested("waiting for {} to appear on screen".format(regex)):
            retry(screen_matches)

    def wait_for_console_text(self, regex: str, timeout: Optional[float] = None) -> None:
        """Wait until the serial console output matches ``regex``.

        Lines received since the VM was started and not yet consumed by an
        earlier call are accumulated, newline-separated, so that multi-line
        patterns can match.

        Args:
            regex: pattern searched (``re.search``) in the accumulated output.
            timeout: seconds to wait. ``None`` (the default) waits forever.

        Raises:
            Exception: if ``timeout`` elapses without a match.
        """
        self.log("waiting for {} to appear on console".format(regex))
        deadline = None if timeout is None else time.monotonic() + timeout
        console = io.StringIO()
        while True:
            remaining = (
                None if deadline is None else max(0.0, deadline - time.monotonic())
            )
            try:
                line = self.last_lines.get(timeout=remaining)
            except queue.Empty:
                raise Exception(
                    f"console text /{regex}/ did not appear within {timeout} seconds"
                ) from None
            # The serial reader strips line endings. Restore them so that
            # patterns spanning several lines can match.
            console.write(line + "\n")
            if re.search(regex, console.getvalue()) is not None:
                return

    def send_key(self, key: str) -> None:
        """Send one key via the monitor, translating it through CHAR_TO_KEY."""
        key = CHAR_TO_KEY.get(key, key)
        self.send_monitor_command("sendkey {}".format(key))

    def start(self) -> None:
        """Launch QEMU and accept its monitor and shell connections.

        Also starts a thread that forwards QEMU's stdout (the serial console)
        to the log and to ``last_lines``. Does nothing if already booted.
        """
        if self.booted:
            return

        self.log("starting vm")

        def clear(path: Path) -> Path:
            if path.exists():
                path.unlink()
            return path

        def create_socket(path: Path) -> socket.socket:
            s = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
            s.bind(str(path))
            s.listen(1)
            return s

        monitor_socket = create_socket(clear(self.monitor_path))
        shell_socket = create_socket(clear(self.shell_path))
        try:
            self.process = self.start_command.run(
                self.state_dir,
                self.shared_dir,
                self.monitor_path,
                self.shell_path,
            )
            self.monitor, _ = monitor_socket.accept()
            self.shell, _ = shell_socket.accept()
        finally:
            # Only the accepted connections are used from here on. Closing the
            # listening sockets does not affect them and avoids leaking fds.
            monitor_socket.close()
            shell_socket.close()

        # Store last serial console lines for use
        # of wait_for_console_text
        self.last_lines = Queue()

        def process_serial_output() -> None:
            assert self.process
            assert self.process.stdout
            for _line in self.process.stdout:
                # Ignore undecodable bytes that may occur in boot menus
                line = _line.decode(errors="ignore").replace("\r", "").rstrip()
                self.last_lines.put(line)
                self.log_serial(line)

        self.serial_thread = threading.Thread(target=process_serial_output)
        self.serial_thread.start()

        self.wait_for_monitor_prompt()

        self.pid = self.process.pid
        self.booted = True

        self.log("QEMU running (pid {})".format(self.pid))

    def cleanup_statedir(self) -> None:
        """Delete the VM state directory."""
        shutil.rmtree(self.state_dir)
        rootlog.log(f"deleting VM state directory {self.state_dir}")
        rootlog.log("if you want to keep the VM state, pass --keep-vm-state")

    def shutdown(self) -> None:
        """Run ``poweroff`` in the guest and wait for QEMU to exit."""
        if not self.booted:
            return

        assert self.shell
        self.shell.sendall("poweroff\n".encode())
        self.wait_for_shutdown()

    def crash(self) -> None:
        """Kill the VM abruptly via the monitor ``quit`` command."""
        if not self.booted:
            return

        self.log("forced crash")
        self.send_monitor_command("quit")
        self.wait_for_shutdown()

    def wait_for_x(self) -> None:
        """Wait until it is possible to connect to the X server. Note that
        testing the existence of /tmp/.X11-unix/X0 is insufficient.
        """

        def check_x(_: Any) -> bool:
            cmd = (
                "journalctl -b SYSLOG_IDENTIFIER=systemd | "
                + 'grep "Reached target Current graphical"'
            )
            status, _ = self.execute(cmd)
            if status != 0:
                return False
            status, _ = self.execute("[ -e /tmp/.X11-unix/X0 ]")
            return status == 0

        with self.nested("waiting for the X11 server"):
            retry(check_x)

    def get_window_names(self) -> List[str]:
        """Return the names of all X windows listed by ``xwininfo``."""
        return self.succeed(
            r"xwininfo -root -tree | sed 's/.*0x[0-9a-f]* \"\([^\"]*\)\".*/\1/; t; d'"
        ).splitlines()

    def wait_for_window(self, regexp: str) -> None:
        """Wait until an X window whose name matches ``regexp`` exists."""
        pattern = re.compile(regexp)

        def window_is_visible(last_try: bool) -> bool:
            names = self.get_window_names()
            if last_try:
                self.log(
                    "Last chance to match {} on the window list,".format(regexp)
                    + " which currently contains: "
                    + ", ".join(names)
                )
            return any(pattern.search(name) for name in names)

        with self.nested("Waiting for a window to appear"):
            retry(window_is_visible)

    def sleep(self, secs: int) -> None:
        """Sleep ``secs`` seconds of guest time."""
        # We want to sleep in *guest* time, not *host* time.
        self.succeed(f"sleep {secs}")

    def forward_port(self, host_port: int = 8080, guest_port: int = 80) -> None:
        """Forward a TCP port on the host to a TCP port on the guest.
        Useful during interactive testing.
        """
        self.send_monitor_command(
            "hostfwd_add tcp::{}-:{}".format(host_port, guest_port)
        )

    def block(self) -> None:
        """Make the machine unreachable by shutting down eth1 (the multicast
        interface used to talk to the other VMs). We keep eth0 up so that
        the test driver can continue to talk to the machine.
        """
        self.send_monitor_command("set_link virtio-net-pci.1 off")

    def unblock(self) -> None:
        """Make the machine reachable."""
        self.send_monitor_command("set_link virtio-net-pci.1 on")

    def release(self) -> None:
        """Terminate QEMU, close the monitor and shell connections, and join
        the serial reader thread. Does nothing if the VM is not running.
        """
        if self.pid is None:
            return
        rootlog.info(f"kill machine (pid {self.pid})")
        assert self.process
        assert self.shell
        assert self.monitor
        assert self.serial_thread

        self.process.terminate()
        self.shell.close()
        self.monitor.close()
        self.serial_thread.join()
1056
1057
class VLan:
    """A VLAN backed by a ``vde_switch`` process, identified by its number.

    The switch's control directory is exported as ``QEMU_VDE_SOCKET_<nr>``
    for the run-vm scripts. The network's lifetime equals the object's
    lifetime.
    """

    nr: int
    socket_dir: Path

    process: subprocess.Popen
    pid: int
    fd: io.TextIOBase

    def __repr__(self) -> str:
        return f"<Vlan Nr. {self.nr}>"

    def __init__(self, nr: int, tmp_dir: Path):
        """Start ``vde_switch`` with its control directory under ``tmp_dir``.

        Exits the process via ``rootlog.error`` if the switch's ``ctl``
        socket does not appear.
        """
        self.nr = nr
        self.socket_dir = tmp_dir / f"vde{self.nr}.ctl"

        # TODO: don't side-effect environment here
        os.environ[f"QEMU_VDE_SOCKET_{self.nr}"] = str(self.socket_dir)

        rootlog.info("start vlan")
        pty_master, pty_slave = pty.openpty()

        try:
            self.process = subprocess.Popen(
                ["vde_switch", "-s", self.socket_dir, "--dirmode", "0700"],
                stdin=pty_slave,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                shell=False,
            )
        except BaseException:
            os.close(pty_master)
            raise
        finally:
            # The child holds its own copy of the pty slave; the parent's copy
            # is never used and would only leak a file descriptor.
            os.close(pty_slave)
        self.pid = self.process.pid
        self.fd = os.fdopen(pty_master, "w")
        self.fd.write("version\n")
        # Flush explicitly so the command reaches vde_switch regardless of the
        # text stream's buffering mode.
        self.fd.flush()

        # TODO: perl version checks if this can be read from
        # an if not, dies. we could hang here forever. Fix it.
        assert self.process.stdout is not None
        self.process.stdout.readline()
        if not (self.socket_dir / "ctl").exists():
            rootlog.error("cannot start vde_switch")

        rootlog.info(f"running vlan (pid {self.pid})")

    def __del__(self) -> None:
        # __init__ may have failed before these attributes were assigned.
        process = getattr(self, "process", None)
        if process is None:
            return
        rootlog.info(f"kill vlan (pid {self.pid})")
        fd = getattr(self, "fd", None)
        if fd is not None:
            fd.close()
        process.terminate()
1107
1108
class Driver:
    """A handle to the driver that sets up the environment
    and runs the tests"""

    tests: str
    vlans: List[VLan]
    machines: List[Machine]

    def __init__(
        self,
        start_scripts: List[str],
        vlans: List[int],
        tests: str,
        keep_vm_state: bool = False,
    ):
        """Start one VLan per entry of ``vlans`` and build one Machine per
        start script.

        Args:
            start_scripts: start scripts of the participating machines; each
                machine is named after its script's ``machine_name``.
            vlans: numbers of the VLANs to start.
            tests: source code of the test script, executed by test_script().
            keep_vm_state: passed through to every Machine.
        """
        self.tests = tests

        tmp_dir = self._tmp_dir()

        with rootlog.nested("start all VLans"):
            self.vlans = [VLan(nr, tmp_dir) for nr in vlans]

        # Each NixStartScript is constructed right before its Machine, in
        # script order.
        start_commands = (NixStartScript(script) for script in start_scripts)
        self.machines = [
            Machine(
                start_command=start_command,
                keep_vm_state=keep_vm_state,
                name=start_command.machine_name,
                tmp_dir=tmp_dir,
            )
            for start_command in start_commands
        ]

    @staticmethod
    def _tmp_dir() -> Path:
        """Return $TMPDIR (or the system temp dir), creating it with mode
        0700 if it does not exist yet."""
        tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
        tmp_dir.mkdir(mode=0o700, exist_ok=True)
        return tmp_dir

    def __enter__(self) -> "Driver":
        return self

    def __exit__(self, *_: Any) -> None:
        """Release every machine; exceptions from the with-body still propagate."""
        with rootlog.nested("cleanup"):
            for machine in self.machines:
                machine.release()

    @contextmanager
    def subtest(self, name: str) -> Iterator[None]:
        """Context manager grouping the logs of its body under ``name``.

        An exception escaping the body is logged as a failure of the named
        test and then re-raised unchanged.
        """
        with rootlog.nested(name):
            try:
                yield
            except Exception as e:
                rootlog.error(f'Test "{name}" failed with error: "{e}"')
                raise

    def test_symbols(self) -> Dict[str, Any]:
        """Build the global namespace the test script is executed in.

        Besides the general helpers below, every machine is exposed under its
        name (and also as ``machine`` when it is the only one) and every VLan
        as ``vlan<nr>``. The exposed names are printed to stdout.
        """
        general_symbols = dict(
            start_all=self.start_all,
            test_script=self.test_script,
            machines=self.machines,
            vlans=self.vlans,
            driver=self,
            log=rootlog,
            os=os,
            create_machine=self.create_machine,
            subtest=self.subtest,  # already a context-manager factory
            run_tests=self.run_tests,
            join_all=self.join_all,
            retry=retry,
            serial_stdout_off=self.serial_stdout_off,
            serial_stdout_on=self.serial_stdout_on,
            Machine=Machine,  # for typing
        )
        machine_symbols = {m.name: m for m in self.machines}
        # If there's exactly one machine, make it available under the name
        # "machine", even if it's not called that.
        if len(self.machines) == 1:
            (machine_symbols["machine"],) = self.machines
        vlan_symbols = {f"vlan{v.nr}": v for v in self.vlans}
        print(
            "additionally exposed symbols:\n    "
            + ", ".join(map(lambda m: m.name, self.machines))
            + ",\n    "
            + ", ".join(map(lambda v: f"vlan{v.nr}", self.vlans))
            + ",\n    "
            + ", ".join(general_symbols)
        )
        return {**general_symbols, **machine_symbols, **vlan_symbols}

    def test_script(self) -> None:
        """Run the test script"""
        with rootlog.nested("run the VM test script"):
            symbols = self.test_symbols()  # call eagerly
            exec(self.tests, symbols, None)

    def run_tests(self) -> None:
        """Run the test script (for non-interactive test runs)"""
        self.test_script()
        # TODO: Collect coverage data
        for machine in self.machines:
            if machine.is_up():
                machine.execute("sync")

    def start_all(self) -> None:
        """Start all machines"""
        with rootlog.nested("start all VMs"):
            for machine in self.machines:
                machine.start()

    def join_all(self) -> None:
        """Wait for all machines to shut down"""
        with rootlog.nested("wait for all VMs to finish"):
            for machine in self.machines:
                machine.wait_for_shutdown()

    def create_machine(self, args: Dict[str, Any]) -> Machine:
        """Legacy factory: build a Machine from a dict of options.

        With a truthy ``args["startCommand"]`` the command is wrapped in a
        NixStartScript and the name defaults to the script's machine name;
        otherwise the command comes from Machine.create_startcommand(args)
        and the name defaults to "machine". ``keep_vm_state`` and
        ``allow_reboot`` default to False.
        """
        rootlog.warning(
            "Using legacy create_machine(), please instantiate the "
            "Machine class directly, instead"
        )
        tmp_dir = self._tmp_dir()

        if args.get("startCommand"):
            start_command: str = args.get("startCommand", "")
            cmd = NixStartScript(start_command)
            name = args.get("name", cmd.machine_name)
        else:
            cmd = Machine.create_startcommand(args)  # type: ignore
            name = args.get("name", "machine")

        return Machine(
            tmp_dir=tmp_dir,
            start_command=cmd,
            name=name,
            keep_vm_state=args.get("keep_vm_state", False),
            allow_reboot=args.get("allow_reboot", False),
        )

    def serial_stdout_on(self) -> None:
        """Echo serial console output to stdout."""
        rootlog._print_serial_logs = True

    def serial_stdout_off(self) -> None:
        """Stop echoing serial console output to stdout."""
        rootlog._print_serial_logs = False
1259
1260
class EnvDefault(argparse.Action):
    """An argparse Action that takes the flag's default value from the
    specified environment variable.

    An explicit non-empty ``default`` wins over the environment. When the
    variable is set and the action takes several values (``nargs`` of
    ``*``, ``+`` or a number) its value is split on whitespace into a list.
    A required argument whose default could be resolved becomes optional.
    """

    def __init__(self, envvar, required=False, default=None, nargs=None, **kwargs):  # type: ignore
        if not default and envvar:
            if envvar in os.environ:
                # str() so both integer nargs (the usual argparse form) and
                # digit strings are recognised.
                if nargs is not None and (str(nargs).isdigit() or nargs in ["*", "+"]):
                    default = os.environ[envvar].split()
                else:
                    default = os.environ[envvar]
                # argparse %-formats help strings, so a literal "%" coming
                # from the environment must be doubled.
                shown = str(default).replace("%", "%%")
                note = f"(default from environment: {shown})"
                help_text = kwargs.get("help")
                kwargs["help"] = f"{help_text} {note}" if help_text else note
        if required and default:
            required = False
        super().__init__(default=default, required=required, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):  # type: ignore
        setattr(namespace, self.dest, values)
1284
1285
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(prog="nixos-test-driver")
    arg_parser.add_argument(
        "-K",
        "--keep-vm-state",
        help="re-use a VM state coming from a previous run",
        action="store_true",
    )
    arg_parser.add_argument(
        "-I",
        "--interactive",
        help="drop into a python repl and run the tests interactively",
        action="store_true",
    )
    arg_parser.add_argument(
        "--start-scripts",
        metavar="START-SCRIPT",
        action=EnvDefault,
        envvar="startScripts",
        nargs="*",
        help="start scripts for participating virtual machines",
    )
    arg_parser.add_argument(
        "--vlans",
        metavar="VLAN",
        action=EnvDefault,
        envvar="vlans",
        nargs="*",
        help="vlans to span by the driver",
    )
    arg_parser.add_argument(
        "testscript",
        action=EnvDefault,
        envvar="testScript",
        help="the test script to run",
        type=Path,
    )

    args = arg_parser.parse_args()

    if not args.keep_vm_state:
        rootlog.info("Machine state will be reset. To keep it, pass --keep-vm-state")

    # --start-scripts / --vlans stay None when given neither on the command
    # line nor via the environment; treat that as "none" instead of letting
    # Driver fail while iterating over None.
    with Driver(
        args.start_scripts or [],
        args.vlans or [],
        args.testscript.read_text(),
        args.keep_vm_state,
    ) as driver:
        if args.interactive:
            ptpython.repl.embed(driver.test_symbols(), {})
        else:
            # perf_counter is monotonic, so the reported duration is not
            # skewed by wall-clock adjustments during a long test run.
            tic = time.perf_counter()
            driver.run_tests()
            toc = time.perf_counter()
            rootlog.info(f"test script finished in {(toc-tic):.2f}s")