Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions nixos/lib/test-driver/test_driver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,22 @@ def __call__(self, parser, namespace, values, option_string=None): # type: igno
setattr(namespace, self.dest, values)


def writeable_dir(arg: str) -> Path:
"""Raises an ArgumentTypeError if the given argument isn't a writeable directory
Note: We want to fail as early as possible if a directory isn't writeable,
since an executed nixos-test could fail (very late) because of the test-driver
writing in a directory without proper permissions.
"""
path = Path(arg)
if not path.is_dir():
raise argparse.ArgumentTypeError("{0} is not a directory".format(path))
if not os.access(path, os.W_OK):
raise argparse.ArgumentTypeError(
"{0} is not a writeable directory".format(path)
)
return path


def main() -> None:
arg_parser = argparse.ArgumentParser(prog="nixos-test-driver")
arg_parser.add_argument(
Expand Down Expand Up @@ -63,6 +79,14 @@ def main() -> None:
nargs="*",
help="vlans to span by the driver",
)
arg_parser.add_argument(
"-o",
"--output_directory",
help="""The path to the directory where outputs copied from the VM will be placed.
By e.g. Machine.copy_from_vm or Machine.screenshot""",
default=Path.cwd(),
type=writeable_dir,
)
arg_parser.add_argument(
"testscript",
action=EnvDefault,
Expand All @@ -77,7 +101,11 @@ def main() -> None:
rootlog.info("Machine state will be reset. To keep it, pass --keep-vm-state")

with Driver(
args.start_scripts, args.vlans, args.testscript.read_text(), args.keep_vm_state
args.start_scripts,
args.vlans,
args.testscript.read_text(),
args.output_directory.resolve(),
args.keep_vm_state,
) as driver:
if args.interactive:
ptpython.repl.embed(driver.test_symbols(), {})
Expand All @@ -94,7 +122,7 @@ def generate_driver_symbols() -> None:
in user's test scripts. That list is then used by pyflakes to lint those
scripts.
"""
d = Driver([], [], "")
d = Driver([], [], "", Path())
test_symbols = d.test_symbols()
with open("driver-symbols", "w") as fp:
fp.write(",".join(test_symbols.keys()))
33 changes: 29 additions & 4 deletions nixos/lib/test-driver/test_driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,28 @@
from test_driver.polling_condition import PollingCondition


def get_tmp_dir() -> Path:
"""Returns a temporary directory that is defined by TMPDIR, TEMP, TMP or CWD
Raises an exception in case the retrieved temporary directory is not writeable
See https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir
"""
tmp_dir = Path(tempfile.gettempdir())
tmp_dir.mkdir(mode=0o700, exist_ok=True)
if not tmp_dir.is_dir():
raise NotADirectoryError(
"The directory defined by TMPDIR, TEMP, TMP or CWD: {0} is not a directory".format(
tmp_dir
)
)
if not os.access(tmp_dir, os.W_OK):
raise PermissionError(
"The directory defined by TMPDIR, TEMP, TMP, or CWD: {0} is not writeable".format(
tmp_dir
)
)
return tmp_dir


class Driver:
"""A handle to the driver that sets up the environment
and runs the tests"""
Expand All @@ -24,12 +46,13 @@ def __init__(
start_scripts: List[str],
vlans: List[int],
tests: str,
out_dir: Path,
keep_vm_state: bool = False,
):
self.tests = tests
self.out_dir = out_dir

tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
tmp_dir.mkdir(mode=0o700, exist_ok=True)
tmp_dir = get_tmp_dir()

with rootlog.nested("start all VLans"):
self.vlans = [VLan(nr, tmp_dir) for nr in vlans]
Expand All @@ -47,6 +70,7 @@ def cmd(scripts: List[str]) -> Iterator[NixStartScript]:
name=cmd.machine_name,
tmp_dir=tmp_dir,
callbacks=[self.check_polling_conditions],
out_dir=self.out_dir,
)
for cmd in cmd(start_scripts)
]
Expand Down Expand Up @@ -141,8 +165,8 @@ def create_machine(self, args: Dict[str, Any]) -> Machine:
"Using legacy create_machine(), please instantiate the"
"Machine class directly, instead"
)
tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
tmp_dir.mkdir(mode=0o700, exist_ok=True)

tmp_dir = get_tmp_dir()

if args.get("startCommand"):
start_command: str = args.get("startCommand", "")
Expand All @@ -154,6 +178,7 @@ def create_machine(self, args: Dict[str, Any]) -> Machine:

return Machine(
tmp_dir=tmp_dir,
out_dir=self.out_dir,
start_command=cmd,
name=name,
keep_vm_state=args.get("keep_vm_state", False),
Expand Down
9 changes: 5 additions & 4 deletions nixos/lib/test-driver/test_driver/machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ class Machine:
the machine lifecycle with the help of a start script / command."""

name: str
out_dir: Path
tmp_dir: Path
shared_dir: Path
state_dir: Path
Expand Down Expand Up @@ -325,13 +326,15 @@ def __repr__(self) -> str:

def __init__(
self,
out_dir: Path,
tmp_dir: Path,
start_command: StartCommand,
name: str = "machine",
keep_vm_state: bool = False,
allow_reboot: bool = False,
callbacks: Optional[List[Callable]] = None,
) -> None:
self.out_dir = out_dir
self.tmp_dir = tmp_dir
self.keep_vm_state = keep_vm_state
self.allow_reboot = allow_reboot
Expand Down Expand Up @@ -702,10 +705,9 @@ def connect(self) -> None:
self.connected = True

def screenshot(self, filename: str) -> None:
out_dir = os.environ.get("out", os.getcwd())
word_pattern = re.compile(r"^\w+$")
if word_pattern.match(filename):
filename = os.path.join(out_dir, "{}.png".format(filename))
filename = os.path.join(self.out_dir, "{}.png".format(filename))
tmp = "{}.ppm".format(filename)

with self.nested(
Expand Down Expand Up @@ -756,7 +758,6 @@ def copy_from_vm(self, source: str, target_dir: str = "") -> None:
all the VMs (using a temporary directory).
"""
# Compute the source, target, and intermediate shared file names
out_dir = Path(os.environ.get("out", os.getcwd()))
vm_src = Path(source)
with tempfile.TemporaryDirectory(dir=self.shared_dir) as shared_td:
shared_temp = Path(shared_td)
Expand All @@ -766,7 +767,7 @@ def copy_from_vm(self, source: str, target_dir: str = "") -> None:
# Copy the file to the shared directory inside VM
self.succeed(make_command(["mkdir", "-p", vm_shared_temp]))
self.succeed(make_command(["cp", "-r", vm_src, vm_intermediate]))
abs_target = out_dir / target_dir / vm_src.name
abs_target = self.out_dir / target_dir / vm_src.name
abs_target.parent.mkdir(exist_ok=True, parents=True)
# Copy the file from the shared directory outside VM
if intermediate.is_dir():
Expand Down