nixos/test-driver: improve error reporting and assertions (#390996)

Changed files
+118 -23
nixos
+3 -2
nixos/doc/manual/development/writing-nixos-tests.section.md
···
```py
machine.start()
machine.wait_for_unit("default.target")
-
if not "Linux" in machine.succeed("uname"):
-
raise Exception("Wrong OS")
```
The first line is technically unnecessary; machines are implicitly started
···
```py
start_all()
```
If the hostname of a node contains characters that can't be used in a
Python variable name, those characters will be replaced with
···
```py
machine.start()
machine.wait_for_unit("default.target")
+
t.assertIn("Linux", machine.succeed("uname"), "Wrong OS")
```
The first line is technically unnecessary; machines are implicitly started
···
```py
start_all()
```
+
+
Under the variable `t`, all assertions from [`unittest.TestCase`](https://docs.python.org/3/library/unittest.html) are available.
If the hostname of a node contains characters that can't be used in a
Python variable name, those characters will be replaced with
+1
nixos/lib/test-driver/default.nix
···
colorama
junit-xml
ptpython
]
++ extraPythonPackages python3Packages;
···
colorama
junit-xml
ptpython
+
ipython
]
++ extraPythonPackages python3Packages;
+1 -1
nixos/lib/test-driver/src/pyproject.toml
···
line-length = 88
lint.select = ["E", "F", "I", "U", "N"]
-
lint.ignore = ["E501"]
# xxx: we can import https://pypi.org/project/types-colorama/ here
[[tool.mypy.overrides]]
···
line-length = 88
lint.select = ["E", "F", "I", "U", "N"]
+
lint.ignore = ["E501", "N818"]
# xxx: we can import https://pypi.org/project/types-colorama/ here
[[tool.mypy.overrides]]
+4 -5
nixos/lib/test-driver/src/test_driver/__init__.py
···
import time
from pathlib import Path
-
import ptpython.repl
from test_driver.driver import Driver
from test_driver.logger import (
···
if args.interactive:
history_dir = os.getcwd()
history_path = os.path.join(history_dir, ".nixos-test-history")
-
ptpython.repl.embed(
-
driver.test_symbols(),
-
{},
history_filename=history_path,
-
)
else:
tic = time.time()
driver.run_tests()
···
import time
from pathlib import Path
+
import ptpython.ipython
from test_driver.driver import Driver
from test_driver.logger import (
···
if args.interactive:
history_dir = os.getcwd()
history_path = os.path.join(history_dir, ".nixos-test-history")
+
ptpython.ipython.embed(
+
user_ns=driver.test_symbols(),
history_filename=history_path,
+
) # type:ignore
else:
tic = time.time()
driver.run_tests()
+48 -2
nixos/lib/test-driver/src/test_driver/driver.py
···
import os
import re
import signal
import tempfile
import threading
from collections.abc import Callable, Iterator
from contextlib import AbstractContextManager, contextmanager
from pathlib import Path
from typing import Any
from test_driver.logger import AbstractLogger
from test_driver.machine import Machine, NixStartScript, retry
from test_driver.polling_condition import PollingCondition
from test_driver.vlan import VLan
SENTINEL = object()
def get_tmp_dir() -> Path:
···
try:
yield
except Exception as e:
-
self.logger.error(f'Test "{name}" failed with error: "{e}"')
raise e
def test_symbols(self) -> dict[str, Any]:
···
serial_stdout_on=self.serial_stdout_on,
polling_condition=self.polling_condition,
Machine=Machine, # for typing
)
machine_symbols = {pythonize_name(m.name): m for m in self.machines}
# If there's exactly one machine, make it available under the name
···
"""Run the test script"""
with self.logger.nested("run the VM test script"):
symbols = self.test_symbols() # call eagerly
-
exec(self.tests, symbols, None)
def run_tests(self) -> None:
"""Run the test script (for non-interactive test runs)"""
···
import os
import re
import signal
+
import sys
import tempfile
import threading
+
import traceback
from collections.abc import Callable, Iterator
from contextlib import AbstractContextManager, contextmanager
from pathlib import Path
from typing import Any
+
from unittest import TestCase
+
from test_driver.errors import MachineError, RequestedAssertionFailed
from test_driver.logger import AbstractLogger
from test_driver.machine import Machine, NixStartScript, retry
from test_driver.polling_condition import PollingCondition
from test_driver.vlan import VLan
SENTINEL = object()
+
+
+
class AssertionTester(TestCase):
+
"""
+
Subclass of `unittest.TestCase` which is used in the
+
`testScript` to perform assertions.
+
+
It throws a custom exception whose parent class
+
gets special treatment in the logs.
+
"""
+
+
failureException = RequestedAssertionFailed
def get_tmp_dir() -> Path:
···
try:
yield
except Exception as e:
+
self.logger.log_test_error(f'Test "{name}" failed with error: "{e}"')
raise e
def test_symbols(self) -> dict[str, Any]:
···
serial_stdout_on=self.serial_stdout_on,
polling_condition=self.polling_condition,
Machine=Machine, # for typing
+
t=AssertionTester(),
)
machine_symbols = {pythonize_name(m.name): m for m in self.machines}
# If there's exactly one machine, make it available under the name
···
"""Run the test script"""
with self.logger.nested("run the VM test script"):
symbols = self.test_symbols() # call eagerly
+
try:
+
exec(self.tests, symbols, None)
+
except MachineError:
+
for line in traceback.format_exc().splitlines():
+
self.logger.log_test_error(line)
+
sys.exit(1)
+
except RequestedAssertionFailed:
+
exc_type, exc, tb = sys.exc_info()
+
# We manually print the stack frames, keeping only the ones from the test script
+
# (note: because the script is not a real file, the frame filename is `<string>`)
+
filtered = [
+
frame
+
for frame in traceback.extract_tb(tb)
+
if frame.filename == "<string>"
+
]
+
+
self.logger.log_test_error("Traceback (most recent call last):")
+
+
code = self.tests.splitlines()
+
for frame, line in zip(filtered, traceback.format_list(filtered)):
+
self.logger.log_test_error(line.rstrip())
+
if lineno := frame.lineno:
+
self.logger.log_test_error(f" {code[lineno - 1].strip()}")
+
+
self.logger.log_test_error("") # blank line for readability
+
exc_prefix = exc_type.__name__ if exc_type is not None else "Error"
+
for line in f"{exc_prefix}: {exc}".splitlines():
+
self.logger.log_test_error(line)
+
+
sys.exit(1)
def run_tests(self) -> None:
"""Run the test script (for non-interactive test runs)"""
+20
nixos/lib/test-driver/src/test_driver/errors.py
···
···
+
class MachineError(Exception):
+
"""
+
Exception that indicates an error that is NOT the user's fault,
+
i.e. something went wrong without the test being necessarily invalid,
+
such as failing OCR.
+
+
To make it easier to spot, this exception (and its subclasses)
+
get a `!!!` prefix in the log output.
+
"""
+
+
+
class RequestedAssertionFailed(AssertionError):
+
"""
+
Special assertion that gets thrown on an assertion error,
+
e.g. a failing `t.assertEqual(...)` or `machine.succeed(...)`.
+
+
This gets special treatment in error reporting: i.e. it gets
+
`!!!` as prefix just as `MachineError`, but only stack frames coming
+
from `testScript` will show up in logs.
+
"""
+20 -1
nixos/lib/test-driver/src/test_driver/logger.py
···
pass
@abstractmethod
def log_serial(self, message: str, machine: str) -> None:
pass
···
self.tests[self.currentSubtest].stderr += args[0] + os.linesep
self.tests[self.currentSubtest].failure = True
def log_serial(self, message: str, machine: str) -> None:
if not self._print_serial_logs:
return
···
def warning(self, *args, **kwargs) -> None: # type: ignore
for logger in self.logger_list:
logger.warning(*args, **kwargs)
def error(self, *args, **kwargs) -> None: # type: ignore
for logger in self.logger_list:
···
tic = time.time()
yield
toc = time.time()
-
self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)")
def info(self, *args, **kwargs) -> None: # type: ignore
self.log(*args, **kwargs)
···
return
self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL)
class XMLLogger(AbstractLogger):
···
self.log(*args, **kwargs)
def error(self, *args, **kwargs) -> None: # type: ignore
self.log(*args, **kwargs)
def log(self, message: str, attributes: dict[str, str] = {}) -> None:
···
pass
@abstractmethod
+
def log_test_error(self, *args, **kwargs) -> None: # type:ignore
+
pass
+
+
@abstractmethod
def log_serial(self, message: str, machine: str) -> None:
pass
···
self.tests[self.currentSubtest].stderr += args[0] + os.linesep
self.tests[self.currentSubtest].failure = True
+
def log_test_error(self, *args, **kwargs) -> None: # type: ignore
+
self.error(*args, **kwargs)
+
def log_serial(self, message: str, machine: str) -> None:
if not self._print_serial_logs:
return
···
def warning(self, *args, **kwargs) -> None: # type: ignore
for logger in self.logger_list:
logger.warning(*args, **kwargs)
+
+
def log_test_error(self, *args, **kwargs) -> None: # type: ignore
+
for logger in self.logger_list:
+
logger.log_test_error(*args, **kwargs)
def error(self, *args, **kwargs) -> None: # type: ignore
for logger in self.logger_list:
···
tic = time.time()
yield
toc = time.time()
+
self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)", attributes)
def info(self, *args, **kwargs) -> None: # type: ignore
self.log(*args, **kwargs)
···
return
self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL)
+
+
def log_test_error(self, *args, **kwargs) -> None: # type: ignore
+
prefix = Fore.RED + "!!! " + Style.RESET_ALL
+
# NOTE: using `warning` instead of `error` to ensure it does not exit after printing the first log
+
self.warning(f"{prefix}{args[0]}", *args[1:], **kwargs)
class XMLLogger(AbstractLogger):
···
self.log(*args, **kwargs)
def error(self, *args, **kwargs) -> None: # type: ignore
+
self.log(*args, **kwargs)
+
+
def log_test_error(self, *args, **kwargs) -> None: # type: ignore
self.log(*args, **kwargs)
def log(self, message: str, attributes: dict[str, str] = {}) -> None:
+19 -12
nixos/lib/test-driver/src/test_driver/machine.py
···
from queue import Queue
from typing import Any
from test_driver.logger import AbstractLogger
from .qmp import QMPSession
···
)
if ret.returncode != 0:
-
raise Exception(
f"Image processing failed with exit code {ret.returncode}, stdout: {ret.stdout.decode()}, stderr: {ret.stderr.decode()}"
)
···
screenshot_path: str, model_ids: Iterable[int]
) -> list[str]:
if shutil.which("tesseract") is None:
-
raise Exception("OCR requested but enableOCR is false")
processed_image = _preprocess_screenshot(screenshot_path, negate=False)
processed_negative = _preprocess_screenshot(screenshot_path, negate=True)
···
capture_output=True,
)
if ret.returncode != 0:
-
raise Exception(f"OCR failed with exit code {ret.returncode}")
model_results.append(ret.stdout.decode("utf-8"))
return model_results
···
time.sleep(1)
if not fn(True):
-
raise Exception(f"action timed out after {timeout} seconds")
class StartCommand:
···
def check_active(_last_try: bool) -> bool:
state = self.get_unit_property(unit, "ActiveState", user)
if state == "failed":
-
raise Exception(f'unit "{unit}" reached state "{state}"')
if state == "inactive":
status, jobs = self.systemctl("list-jobs --full 2>&1", user)
if "No jobs" in jobs:
info = self.get_unit_info(unit, user)
if info["ActiveState"] == state:
-
raise Exception(
f'unit "{unit}" is inactive and there are no pending jobs'
)
···
def get_unit_info(self, unit: str, user: str | None = None) -> dict[str, str]:
status, lines = self.systemctl(f'--no-pager show "{unit}"', user)
if status != 0:
-
raise Exception(
f'retrieving systemctl info for unit "{unit}"'
+ ("" if user is None else f' under user "{user}"')
+ f" failed with exit code {status}"
···
user,
)
if status != 0:
-
raise Exception(
f'retrieving systemctl property "{property}" for unit "{unit}"'
+ ("" if user is None else f' under user "{user}"')
+ f" failed with exit code {status}"
···
info = self.get_unit_info(unit)
state = info["ActiveState"]
if state != require_state:
-
raise Exception(
f"Expected unit '{unit}' to to be in state "
f"'{require_state}' but it is in state '{state}'"
)
···
(status, out) = self.execute(command, timeout=timeout)
if status != 0:
self.log(f"output: {out}")
-
raise Exception(f"command `{command}` failed (exit code {status})")
output += out
return output
···
with self.nested(f"must fail: {command}"):
(status, out) = self.execute(command, timeout=timeout)
if status == 0:
-
raise Exception(f"command `{command}` unexpectedly succeeded")
output += out
return output
···
ret = subprocess.run(f"pnmtopng '{tmp}' > '{filename}'", shell=True)
os.unlink(tmp)
if ret.returncode != 0:
-
raise Exception("Cannot convert screenshot")
def copy_from_host_via_shell(self, source: str, target: str) -> None:
"""Copy a file from the host into the guest by piping it over the
···
from queue import Queue
from typing import Any
+
from test_driver.errors import MachineError, RequestedAssertionFailed
from test_driver.logger import AbstractLogger
from .qmp import QMPSession
···
)
if ret.returncode != 0:
+
raise MachineError(
f"Image processing failed with exit code {ret.returncode}, stdout: {ret.stdout.decode()}, stderr: {ret.stderr.decode()}"
)
···
screenshot_path: str, model_ids: Iterable[int]
) -> list[str]:
if shutil.which("tesseract") is None:
+
raise MachineError("OCR requested but enableOCR is false")
processed_image = _preprocess_screenshot(screenshot_path, negate=False)
processed_negative = _preprocess_screenshot(screenshot_path, negate=True)
···
capture_output=True,
)
if ret.returncode != 0:
+
raise MachineError(f"OCR failed with exit code {ret.returncode}")
model_results.append(ret.stdout.decode("utf-8"))
return model_results
···
time.sleep(1)
if not fn(True):
+
raise RequestedAssertionFailed(
+
f"action timed out after {timeout} tries with one-second pause in-between"
+
)
class StartCommand:
···
def check_active(_last_try: bool) -> bool:
state = self.get_unit_property(unit, "ActiveState", user)
if state == "failed":
+
raise RequestedAssertionFailed(f'unit "{unit}" reached state "{state}"')
if state == "inactive":
status, jobs = self.systemctl("list-jobs --full 2>&1", user)
if "No jobs" in jobs:
info = self.get_unit_info(unit, user)
if info["ActiveState"] == state:
+
raise RequestedAssertionFailed(
f'unit "{unit}" is inactive and there are no pending jobs'
)
···
def get_unit_info(self, unit: str, user: str | None = None) -> dict[str, str]:
status, lines = self.systemctl(f'--no-pager show "{unit}"', user)
if status != 0:
+
raise RequestedAssertionFailed(
f'retrieving systemctl info for unit "{unit}"'
+ ("" if user is None else f' under user "{user}"')
+ f" failed with exit code {status}"
···
user,
)
if status != 0:
+
raise RequestedAssertionFailed(
f'retrieving systemctl property "{property}" for unit "{unit}"'
+ ("" if user is None else f' under user "{user}"')
+ f" failed with exit code {status}"
···
info = self.get_unit_info(unit)
state = info["ActiveState"]
if state != require_state:
+
raise RequestedAssertionFailed(
f"Expected unit '{unit}' to to be in state "
f"'{require_state}' but it is in state '{state}'"
)
···
(status, out) = self.execute(command, timeout=timeout)
if status != 0:
self.log(f"output: {out}")
+
raise RequestedAssertionFailed(
+
f"command `{command}` failed (exit code {status})"
+
)
output += out
return output
···
with self.nested(f"must fail: {command}"):
(status, out) = self.execute(command, timeout=timeout)
if status == 0:
+
raise RequestedAssertionFailed(
+
f"command `{command}` unexpectedly succeeded"
+
)
output += out
return output
···
ret = subprocess.run(f"pnmtopng '{tmp}' > '{filename}'", shell=True)
os.unlink(tmp)
if ret.returncode != 0:
+
raise MachineError("Cannot convert screenshot")
def copy_from_host_via_shell(self, source: str, target: str) -> None:
"""Copy a file from the host into the guest by piping it over the
+2
nixos/lib/test-script-prepend.py
···
from typing import Callable, Iterator, ContextManager, Optional, List, Dict, Any, Union
from typing_extensions import Protocol
from pathlib import Path
class RetryProtocol(Protocol):
···
serial_stdout_off: Callable[[], None]
serial_stdout_on: Callable[[], None]
polling_condition: PollingConditionProtocol
···
from typing import Callable, Iterator, ContextManager, Optional, List, Dict, Any, Union
from typing_extensions import Protocol
from pathlib import Path
+
from unittest import TestCase
class RetryProtocol(Protocol):
···
serial_stdout_off: Callable[[], None]
serial_stdout_on: Callable[[], None]
polling_condition: PollingConditionProtocol
+
t: TestCase