diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index d65e13d60eb..8d224e5195c 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -16,7 +16,7 @@ from distributed.utils import get_ip_interface, ignoring from distributed.cli.utils import (check_python_3, install_signal_handlers, uri_from_host_port) -from distributed.preloading import preload_modules +from distributed.preloading import preload_modules, validate_preload_argv from distributed.proctitle import (enable_proctitle_on_children, enable_proctitle_on_current) @@ -26,7 +26,7 @@ pem_file_option_type = click.Path(exists=True, resolve_path=True) -@click.command() +@click.command(context_settings=dict(ignore_unknown_options=True)) @click.option('--host', type=str, default='', help="URI, IP or hostname of this server") @click.option('--port', type=int, default=None, help="Serving port") @@ -58,11 +58,14 @@ "cluster is on a shared network file system.") @click.option('--local-directory', default='', type=str, help="Directory to place scheduler files") -@click.option('--preload', type=str, multiple=True, - help='Module that should be loaded by each worker process like "foo.bar" or "/path/to/foo.py"') +@click.option('--preload', type=str, multiple=True, is_eager=True, + help='Module that should be loaded by each worker process ' + 'like "foo.bar" or "/path/to/foo.py"') +@click.argument('preload_argv', nargs=-1, + type=click.UNPROCESSED, callback=validate_preload_argv) def main(host, port, bokeh_port, show, _bokeh, bokeh_whitelist, bokeh_prefix, - use_xheaders, pid_file, scheduler_file, interface, - local_directory, preload, tls_ca_file, tls_cert, tls_key): + use_xheaders, pid_file, scheduler_file, interface, + local_directory, preload, preload_argv, tls_ca_file, tls_cert, tls_key): enable_proctitle_on_current() enable_proctitle_on_children() @@ -119,7 +122,7 @@ def del_pid_file(): scheduler_file=scheduler_file, security=sec) scheduler.start(addr) 
- preload_modules(preload, parameter=scheduler, file_dir=local_directory) + preload_modules(preload, parameter=scheduler, file_dir=local_directory, argv=preload_argv) logger.info('Local Directory: %26s', local_directory) logger.info('-' * 47) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 5a06d78c219..592ec623d78 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -14,6 +14,7 @@ from distributed.cli.utils import (check_python_3, uri_from_host_port, install_signal_handlers) from distributed.comm import get_address_host_port +from distributed.preloading import validate_preload_argv from distributed.proctitle import (enable_proctitle_on_children, enable_proctitle_on_current) @@ -27,7 +28,7 @@ pem_file_option_type = click.Path(exists=True, resolve_path=True) -@click.command() +@click.command(context_settings=dict(ignore_unknown_options=True)) @click.argument('scheduler', type=str, required=False) @click.option('--tls-ca-file', type=pem_file_option_type, default=None, help="CA cert(s) file for TLS (in PEM format)") @@ -89,14 +90,16 @@ help="Seconds to wait for a scheduler before closing") @click.option('--bokeh-prefix', type=str, default=None, help="Prefix for the bokeh app") -@click.option('--preload', type=str, multiple=True, +@click.option('--preload', type=str, multiple=True, is_eager=True, help='Module that should be loaded by each worker process ' 'like "foo.bar" or "/path/to/foo.py"') +@click.argument('preload_argv', nargs=-1, + type=click.UNPROCESSED, callback=validate_preload_argv) def main(scheduler, host, worker_port, listen_address, contact_address, nanny_port, nthreads, nprocs, nanny, name, memory_limit, pid_file, reconnect, resources, bokeh, bokeh_port, local_directory, scheduler_file, interface, - death_timeout, preload, bokeh_prefix, tls_ca_file, + death_timeout, preload, preload_argv, bokeh_prefix, tls_ca_file, tls_cert, tls_key): enable_proctitle_on_current() 
enable_proctitle_on_children() @@ -212,7 +215,8 @@ def del_pid_file(): services=services, loop=loop, resources=resources, memory_limit=memory_limit, reconnect=reconnect, local_dir=local_directory, death_timeout=death_timeout, - preload=preload, security=sec, contact_address=contact_address, + preload=preload, preload_argv=preload_argv, + security=sec, contact_address=contact_address, name=name if nprocs == 1 else name + '-' + str(i), **kwargs) for i in range(nprocs)] diff --git a/distributed/cli/tests/test_dask_scheduler.py b/distributed/cli/tests/test_dask_scheduler.py index 874c0343620..88507429cdd 100644 --- a/distributed/cli/tests/test_dask_scheduler.py +++ b/distributed/cli/tests/test_dask_scheduler.py @@ -296,3 +296,64 @@ def check_scheduler(): c.scheduler.address finally: shutil.rmtree(tmpdir) + + +PRELOAD_COMMAND_TEXT = """ +import click +_config = {} + +@click.command() +@click.option("--passthrough", type=str, default="default") +def dask_setup(scheduler, passthrough): + _config["passthrough"] = passthrough + +def get_passthrough(): + return _config["passthrough"] +""" + + +def test_preload_command(loop): + + def check_passthrough(): + import passthrough_info + return passthrough_info.get_passthrough() + + tmpdir = tempfile.mkdtemp() + try: + path = os.path.join(tmpdir, 'passthrough_info.py') + with open(path, 'w') as f: + f.write(PRELOAD_COMMAND_TEXT) + + with tmpfile() as fn: + print(fn) + with popen(['dask-scheduler', '--scheduler-file', fn, + '--preload', path, "--passthrough", "foobar"]): + with Client(scheduler_file=fn, loop=loop) as c: + assert c.run_on_scheduler(check_passthrough) == \ + "foobar" + finally: + shutil.rmtree(tmpdir) + + +def test_preload_command_default(loop): + + def check_passthrough(): + import passthrough_info + return passthrough_info.get_passthrough() + + tmpdir = tempfile.mkdtemp() + try: + path = os.path.join(tmpdir, 'passthrough_info.py') + with open(path, 'w') as f: + f.write(PRELOAD_COMMAND_TEXT) + + with tmpfile() as 
fn2: + print(fn2) + with popen(['dask-scheduler', '--scheduler-file', fn2, + '--preload', path], stdout=sys.stdout, stderr=sys.stderr): + with Client(scheduler_file=fn2, loop=loop) as c: + assert c.run_on_scheduler(check_passthrough) == \ + "default" + + finally: + shutil.rmtree(tmpdir) diff --git a/distributed/nanny.py b/distributed/nanny.py index c0d6c2f51fe..166281badb4 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -42,7 +42,7 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, ncores=None, loop=None, local_dir=None, services=None, name=None, memory_limit='auto', reconnect=True, validate=False, quiet=False, resources=None, silence_logs=None, - death_timeout=None, preload=(), security=None, + death_timeout=None, preload=(), preload_argv=[], security=None, contact_address=None, listen_address=None, **kwargs): if scheduler_file: cfg = json_load_robust(scheduler_file) @@ -60,6 +60,8 @@ def __init__(self, scheduler_ip=None, scheduler_port=None, self.resources = resources self.death_timeout = death_timeout self.preload = preload + self.preload_argv = preload_argv + self.contact_address = contact_address self.memory_terminate_fraction = config.get('worker-memory-terminate', 0.95) @@ -203,6 +205,7 @@ def instantiate(self, comm=None): silence_logs=self.silence_logs, death_timeout=self.death_timeout, preload=self.preload, + preload_argv=self.preload_argv, security=self.security, contact_address=self.contact_address), worker_start_args=(start_arg,), diff --git a/distributed/preloading.py b/distributed/preloading.py index e8745db8d95..00fa4eaeae2 100644 --- a/distributed/preloading.py +++ b/distributed/preloading.py @@ -3,25 +3,79 @@ import os import shutil import sys +import filecmp from importlib import import_module +import click + from .utils import import_file logger = logging.getLogger(__name__) -def preload_modules(names, parameter=None, file_dir=None): - """ Imports modules, handles `dask_setup` and `dask_teardown` functions +def 
validate_preload_argv(ctx, param, value):
+    """Click option callback providing validation of preload subcommand arguments."""
+    if not value and not ctx.params.get("preload", None):
+        # No preload argv provided and no preload modules specified.
+        return value
+
+    if value and not ctx.params.get("preload", None):
+        # Report a usage error matching standard click error conventions.
+        unexpected_args = [v for v in value if v.startswith("-")]
+        for a in unexpected_args:
+            raise click.NoSuchOption(a)
+        raise click.UsageError(
+            "Got unexpected extra argument%s: (%s)" %
+            ("s" if len(value) > 1 else "", " ".join(value))
+        )
+
+    preload_modules = _import_modules(ctx.params.get("preload"))
+
+    preload_commands = [
+        m["dask_setup"] for m in preload_modules.values()
+        if isinstance(m["dask_setup"], click.Command)
+    ]
+
+    if len(preload_commands) > 1:
+        raise click.UsageError(
+            "Multiple --preload modules with click-configurable setup: %s" %
+            list(preload_modules.keys()))
+
+    if value and not preload_commands:
+        raise click.UsageError(
+            "Unknown argument specified: %r. Was a click-configurable "
+            "--preload target provided?" % (value,))
+    if not preload_commands:
+        return value
+    else:
+        preload_command = preload_commands[0]
+
+    ctx = click.Context(preload_command, allow_extra_args=False)
+    preload_command.parse_args(ctx, list(value))
+
+    return value
+
+
+def _import_modules(names, file_dir=None):
+    """ Imports modules and extracts preload interface functions.
+
+    Imports modules specified by names and extracts 'dask_setup'
+    and 'dask_teardown' if present.
+
+    Parameters
+    ----------
+    names: list of strings
+        Module names or file paths
+    file_dir: string
+        Path of a directory where files should be copied
+
+    Returns
+    -------
+    Nested dict of names to extracted module interface components if present
+    in imported module.
""" + result_modules = {} + for name in names: # import if name.endswith(".py"): @@ -30,7 +84,8 @@ def preload_modules(names, parameter=None, file_dir=None): basename = os.path.basename(name) copy_dst = os.path.join(file_dir, basename) if os.path.exists(copy_dst): - logger.error("File name collision: %s", basename) + if not filecmp.cmp(name, copy_dst): + logger.error("File name collision: %s", basename) shutil.copy(name, copy_dst) module = import_file(copy_dst)[0] else: @@ -42,10 +97,41 @@ def preload_modules(names, parameter=None, file_dir=None): import_module(name) module = sys.modules[name] - # handle special functions - dask_setup = getattr(module, 'dask_setup', None) - dask_teardown = getattr(module, 'dask_teardown', None) - if dask_setup is not None: - dask_setup(parameter) - if dask_teardown is not None: - atexit.register(dask_teardown, parameter) + result_modules[name] = { + attrname : getattr(module, attrname, None) + for attrname in ("dask_setup", "dask_teardown") + } + + return result_modules + + +def preload_modules(names, parameter=None, file_dir=None, argv=None): + """ Imports modules, handles `dask_setup` and `dask_teardown`. + + Parameters + ---------- + names: list of strings + Module names or file paths + parameter: object + Parameter passed to `dask_setup` and `dask_teardown` + argv: [string] + List of string arguments passed to click-configurable `dask_setup`. 
+    file_dir: string
+        Path of a directory where files should be copied
+    """
+
+    imported_modules = _import_modules(names, file_dir=file_dir)
+
+    for name, interface in imported_modules.items():
+        dask_setup = interface.get("dask_setup", None)
+        dask_teardown = interface.get("dask_teardown", None)
+
+        if dask_setup:
+            if isinstance(dask_setup, click.Command):
+                context = dask_setup.make_context("dask_setup", list(argv), allow_extra_args=False)
+                dask_setup.callback(parameter, *context.args, **context.params)
+            else:
+                dask_setup(parameter)
+
+        if interface["dask_teardown"]:
+            atexit.register(interface["dask_teardown"], parameter)
diff --git a/distributed/worker.py b/distributed/worker.py
index ce2e1629786..bec3606d362 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -82,7 +82,7 @@ def __init__(self, scheduler_ip=None, scheduler_port=None,
                  services=None, service_ports=None, name=None,
                  reconnect=True, memory_limit='auto',
                  executor=None, resources=None, silence_logs=None,
-                 death_timeout=None, preload=(), security=None,
+                 death_timeout=None, preload=(), preload_argv=(), security=None,
                  contact_address=None, memory_monitor_interval=200, **kwargs):
         self._setup_logging()
@@ -102,6 +102,7 @@ def __init__(self, scheduler_ip=None, scheduler_port=None,
         self.available_resources = (resources or {}).copy()
         self.death_timeout = death_timeout
         self.preload = preload
+        self.preload_argv = preload_argv
         self.contact_address = contact_address
         self.memory_monitor_interval = memory_monitor_interval
         if silence_logs:
@@ -345,7 +346,7 @@ def _start(self, addr_or_port=0):
             protocol, listen_host = listen_host.split('://')
         self.name = self.name or self.address
-        preload_modules(self.preload, parameter=self, file_dir=self.local_dir)
+        preload_modules(self.preload, parameter=self, file_dir=self.local_dir, argv=self.preload_argv)
         # Services listen on all addresses
         # Note Nanny is not a "real" service, just some metadata
         # passed in service_ports...
diff --git a/docs/source/setup.rst b/docs/source/setup.rst index dc3be93ceeb..13556f8a2b2 100644 --- a/docs/source/setup.rst +++ b/docs/source/setup.rst @@ -283,12 +283,19 @@ Customizing initialization -------------------------- Both ``dask-scheduler`` and ``dask-worker`` support a ``--preload`` option that -allows custom initialization of each scheduler/worker respectively. A module -or python file passed as a ``--preload`` value is guaranteed to be imported -before establishing any connection. A ``dask_setup(service)`` function is called -if found, with a ``Scheduler`` or ``Worker`` instance as the argument. As the +allows custom initialization of each scheduler/worker respectively. A module or +python file passed as a ``--preload`` value is guaranteed to be imported before +establishing any connection. A ``dask_setup(service)`` function is called if +found, with a ``Scheduler`` or ``Worker`` instance as the argument. As the service stops, ``dask_teardown(service)`` is called if present. +To support additional configuration a single ``--preload`` module may register +additional command-line arguments by exposing ``dask_setup`` as a Click_ +command. This command will be used to parse additional arguments provided to +``dask-worker`` or ``dask-scheduler`` and will be called before service +initialization. + +.. _Click: http://click.pocoo.org/ As an example, consider the following file that creates a @@ -297,17 +304,27 @@ As an example, consider the following file that creates a .. 
code-block:: python

     # scheduler-setup.py
+    import click
+
     from distributed.diagnostics.plugin import SchedulerPlugin

     class MyPlugin(SchedulerPlugin):
-        def add_worker(self, scheduler=None, worker=None, **kwargs):
-            print("Added a new worker at", worker)
+        def __init__(self, print_count):
+            self.print_count = print_count
+            SchedulerPlugin.__init__(self)

-    def dask_setup(scheduler):
-        plugin = MyPlugin()
+        def add_worker(self, scheduler=None, worker=None, **kwargs):
+            print("Added a new worker at:", worker)
+            if self.print_count and scheduler is not None:
+                print("Total workers:", len(scheduler.workers))
+
+    @click.command()
+    @click.option("--print-count/--no-print-count", default=False)
+    def dask_setup(scheduler, print_count):
+        plugin = MyPlugin(print_count)
         scheduler.add_plugin(plugin)

 We can then run this preload script by referring to its filename (or module
 name if it is on the path) when we start the scheduler::

-    dask-scheduler --preload scheduler-setup.py
+    dask-scheduler --preload scheduler-setup.py --print-count