1import argparse
2import os
3import time
4from pathlib import Path
5
6import ptpython.ipython
7
8from test_driver.debug import Debug, DebugAbstract, DebugNop
9from test_driver.driver import Driver
10from test_driver.logger import (
11 CompositeLogger,
12 JunitXMLLogger,
13 TerminalLogger,
14 XMLLogger,
15)
16
17
class EnvDefault(argparse.Action):
    """An argparse Action that takes its default from an environment variable.

    When no truthy ``default`` is given and ``envvar`` is set in the
    environment, the variable's value becomes the argument's default. For
    list-valued arguments (``nargs`` of ``"*"``, ``"+"`` or a count) the value
    is split on whitespace. The help text is extended to show the value that
    was picked up, and a ``required`` argument stops being required once a
    default is available.
    """

    def __init__(self, envvar, required=False, default=None, nargs=None, **kwargs):  # type: ignore
        """Resolve the default from ``envvar`` and initialise the Action.

        ``envvar`` is the name of the environment variable to consult; all
        other parameters are forwarded to ``argparse.Action``.
        """
        if not default and envvar and envvar in os.environ:
            raw_value = os.environ[envvar]
            # argparse takes counted nargs as an int; digit strings are still
            # recognised so existing callers keep working.
            wants_list = (
                nargs in ("*", "+")
                or isinstance(nargs, int)
                or (isinstance(nargs, str) and nargs.isdigit())
            )
            default = raw_value.split() if wants_list else raw_value

            help_text = kwargs.get("help")
            # Leave a deliberately suppressed help untouched, otherwise the
            # argument would show up in --help again.
            if help_text != argparse.SUPPRESS:
                # argparse %-formats help strings, so a literal "%" coming
                # from the environment has to be escaped.
                shown_default = str(default).replace("%", "%%")
                suffix = f"(default from environment: {shown_default})"
                kwargs["help"] = f"{help_text} {suffix}" if help_text else suffix

        # A usable default satisfies the requirement.
        if required and default:
            required = False
        super().__init__(default=default, required=required, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):  # type: ignore
        """Store the values given on the command line, overriding the default."""
        setattr(namespace, self.dest, values)
39
40
def writeable_dir(arg: str) -> Path:
    """Convert ``arg`` to a Path, insisting that it is a writeable directory.

    Intended as an argparse ``type=`` callable: an ArgumentTypeError is raised
    when the path is not a directory or cannot be written to.

    Note: failing this early is deliberate. Otherwise an executed nixos-test
    could fail (very late) because the test-driver writes into a directory
    without proper permissions.
    """
    candidate = Path(arg)
    if candidate.is_dir():
        if os.access(candidate, os.W_OK):
            return candidate
        message = f"{candidate} is not a writeable directory"
    else:
        message = f"{candidate} is not a directory"
    raise argparse.ArgumentTypeError(message)
53
54
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command line parser for the ``nixos-test-driver`` program."""
    arg_parser = argparse.ArgumentParser(prog="nixos-test-driver")
    arg_parser.add_argument(
        "-K",
        "--keep-vm-state",
        help="re-use a VM state coming from a previous run",
        action="store_true",
    )
    arg_parser.add_argument(
        "-I",
        "--interactive",
        help="drop into a python repl and run the tests interactively",
        action=argparse.BooleanOptionalAction,
    )
    arg_parser.add_argument(
        "--debug-hook-attach",
        help="Enable interactive debugging breakpoints for sandboxed runs",
    )
    arg_parser.add_argument(
        "--start-scripts",
        metavar="START-SCRIPT",
        action=EnvDefault,
        envvar="startScripts",
        nargs="*",
        help="start scripts for participating virtual machines",
    )
    arg_parser.add_argument(
        "--vlans",
        metavar="VLAN",
        action=EnvDefault,
        envvar="vlans",
        nargs="*",
        help="vlans to span by the driver",
    )
    arg_parser.add_argument(
        "--global-timeout",
        type=int,
        metavar="GLOBAL_TIMEOUT",
        action=EnvDefault,
        envvar="globalTimeout",
        help="Timeout in seconds for the whole test",
    )
    arg_parser.add_argument(
        "-o",
        "--output_directory",
        help="""The path to the directory where outputs copied from the VM will be placed.
 By e.g. Machine.copy_from_vm or Machine.screenshot""",
        default=Path.cwd(),
        type=writeable_dir,
    )
    arg_parser.add_argument(
        "--junit-xml",
        help="Enable JunitXML report generation to the given path",
        type=Path,
    )
    arg_parser.add_argument(
        "testscript",
        action=EnvDefault,
        envvar="testScript",
        help="the test script to run",
        type=Path,
    )
    arg_parser.add_argument(
        "--dump-vsocks",
        help="indicates that the interactive SSH backdoor is active and dumps information about it on start",
        type=int,
    )
    return arg_parser


def main() -> None:
    """Entry point of the test driver.

    Parses the command line, assembles the loggers and the debugger, and then
    either drops into an interactive ptpython shell or runs the test script
    and logs how long it took.
    """
    args = _build_arg_parser().parse_args()

    output_directory = args.output_directory.resolve()
    logger = CompositeLogger([TerminalLogger()])

    # An XML log is written in addition whenever LOGFILE is set.
    if "LOGFILE" in os.environ:
        logger.add_logger(XMLLogger(os.environ["LOGFILE"]))

    if args.junit_xml:
        # The report path is taken relative to the output directory.
        logger.add_logger(JunitXMLLogger(output_directory / args.junit_xml))

    if not args.keep_vm_state:
        logger.info("Machine state will be reset. To keep it, pass --keep-vm-state")

    # Debugging is a no-op unless --debug-hook-attach was given.
    debugger: DebugAbstract = DebugNop()
    if args.debug_hook_attach is not None:
        debugger = Debug(logger, args.debug_hook_attach)

    with Driver(
        args.start_scripts,
        args.vlans,
        args.testscript.read_text(),
        output_directory,
        logger,
        args.keep_vm_state,
        args.global_timeout,
        debug=debugger,
    ) as driver:
        if offset := args.dump_vsocks:
            driver.dump_machine_ssh(offset)
        if args.interactive:
            # The REPL history lives next to wherever the driver was started.
            history_path = os.path.join(os.getcwd(), ".nixos-test-history")
            ptpython.ipython.embed(
                user_ns=driver.test_symbols(),
                history_filename=history_path,
            )  # type:ignore
        else:
            # Use a monotonic clock: wall-clock time may jump (e.g. NTP
            # adjustments) while the test runs and would skew the duration.
            tic = time.monotonic()
            driver.run_tests()
            toc = time.monotonic()
            logger.info(f"test script finished in {(toc - tic):.2f}s")
165
166
def generate_driver_symbols() -> None:
    """
    Write the names of the test-driver symbols that are available to user
    test scripts into the file `driver-symbols`, as one comma-separated line.
    pyflakes uses that list when linting those scripts.
    """
    # A driver without start scripts, vlans or test code is enough here: it
    # is only asked for its symbol table.
    symbol_table = Driver([], [], "", Path(), CompositeLogger([])).test_symbols()
    with open("driver-symbols", "w") as symbols_file:
        joined_names = ",".join(symbol_table.keys())
        symbols_file.write(joined_names)