Skip to content
17 changes: 10 additions & 7 deletions distributed/cli/dask_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from distributed.utils import get_ip_interface, ignoring
from distributed.cli.utils import (check_python_3, install_signal_handlers,
uri_from_host_port)
from distributed.preloading import preload_modules
from distributed.preloading import preload_modules, validate_preload_argv
from distributed.proctitle import (enable_proctitle_on_children,
enable_proctitle_on_current)

Expand All @@ -26,7 +26,7 @@
pem_file_option_type = click.Path(exists=True, resolve_path=True)


@click.command()
@click.command(context_settings=dict(ignore_unknown_options=True))
@click.option('--host', type=str, default='',
help="URI, IP or hostname of this server")
@click.option('--port', type=int, default=None, help="Serving port")
Expand Down Expand Up @@ -58,11 +58,14 @@
"cluster is on a shared network file system.")
@click.option('--local-directory', default='', type=str,
help="Directory to place scheduler files")
@click.option('--preload', type=str, multiple=True,
help='Module that should be loaded by each worker process like "foo.bar" or "/path/to/foo.py"')
@click.option('--preload', type=str, multiple=True, is_eager=True,
help='Module that should be loaded by each worker process '
'like "foo.bar" or "/path/to/foo.py"')
@click.argument('preload_argv', nargs=-1,
type=click.UNPROCESSED, callback=validate_preload_argv)
def main(host, port, bokeh_port, show, _bokeh, bokeh_whitelist, bokeh_prefix,
use_xheaders, pid_file, scheduler_file, interface,
local_directory, preload, tls_ca_file, tls_cert, tls_key):
use_xheaders, pid_file, scheduler_file, interface,
local_directory, preload, preload_argv, tls_ca_file, tls_cert, tls_key):

enable_proctitle_on_current()
enable_proctitle_on_children()
Expand Down Expand Up @@ -119,7 +122,7 @@ def del_pid_file():
scheduler_file=scheduler_file,
security=sec)
scheduler.start(addr)
preload_modules(preload, parameter=scheduler, file_dir=local_directory)
preload_modules(preload, parameter=scheduler, file_dir=local_directory, argv=preload_argv)

logger.info('Local Directory: %26s', local_directory)
logger.info('-' * 47)
Expand Down
12 changes: 8 additions & 4 deletions distributed/cli/dask_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from distributed.cli.utils import (check_python_3, uri_from_host_port,
install_signal_handlers)
from distributed.comm import get_address_host_port
from distributed.preloading import validate_preload_argv
from distributed.proctitle import (enable_proctitle_on_children,
enable_proctitle_on_current)

Expand All @@ -27,7 +28,7 @@
pem_file_option_type = click.Path(exists=True, resolve_path=True)


@click.command()
@click.command(context_settings=dict(ignore_unknown_options=True))
@click.argument('scheduler', type=str, required=False)
@click.option('--tls-ca-file', type=pem_file_option_type, default=None,
help="CA cert(s) file for TLS (in PEM format)")
Expand Down Expand Up @@ -89,14 +90,16 @@
help="Seconds to wait for a scheduler before closing")
@click.option('--bokeh-prefix', type=str, default=None,
help="Prefix for the bokeh app")
@click.option('--preload', type=str, multiple=True,
@click.option('--preload', type=str, multiple=True, is_eager=True,
help='Module that should be loaded by each worker process '
'like "foo.bar" or "/path/to/foo.py"')
@click.argument('preload_argv', nargs=-1,
type=click.UNPROCESSED, callback=validate_preload_argv)
def main(scheduler, host, worker_port, listen_address, contact_address,
nanny_port, nthreads, nprocs, nanny, name,
memory_limit, pid_file, reconnect, resources, bokeh,
bokeh_port, local_directory, scheduler_file, interface,
death_timeout, preload, bokeh_prefix, tls_ca_file,
death_timeout, preload, preload_argv, bokeh_prefix, tls_ca_file,
tls_cert, tls_key):
enable_proctitle_on_current()
enable_proctitle_on_children()
Expand Down Expand Up @@ -212,7 +215,8 @@ def del_pid_file():
services=services, loop=loop, resources=resources,
memory_limit=memory_limit, reconnect=reconnect,
local_dir=local_directory, death_timeout=death_timeout,
preload=preload, security=sec, contact_address=contact_address,
preload=preload, preload_argv=preload_argv,
security=sec, contact_address=contact_address,
name=name if nprocs == 1 else name + '-' + str(i),
**kwargs)
for i in range(nprocs)]
Expand Down
61 changes: 61 additions & 0 deletions distributed/cli/tests/test_dask_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,64 @@ def check_scheduler():
c.scheduler.address
finally:
shutil.rmtree(tmpdir)


PRELOAD_COMMAND_TEXT = """
import click
_config = {}

@click.command()
@click.option("--passthrough", type=str, default="default")
def dask_setup(scheduler, passthrough):
_config["passthrough"] = passthrough

def get_passthrough():
return _config["passthrough"]
"""


def test_preload_command(loop):

def check_passthrough():
import passthrough_info
return passthrough_info.get_passthrough()

tmpdir = tempfile.mkdtemp()
try:
path = os.path.join(tmpdir, 'passthrough_info.py')
with open(path, 'w') as f:
f.write(PRELOAD_COMMAND_TEXT)

with tmpfile() as fn:
print(fn)
with popen(['dask-scheduler', '--scheduler-file', fn,
'--preload', path, "--passthrough", "foobar"]):
with Client(scheduler_file=fn, loop=loop) as c:
assert c.run_on_scheduler(check_passthrough) == \
"foobar"
finally:
shutil.rmtree(tmpdir)


def test_preload_command_default(loop):

def check_passthrough():
import passthrough_info
return passthrough_info.get_passthrough()

tmpdir = tempfile.mkdtemp()
try:
path = os.path.join(tmpdir, 'passthrough_info.py')
with open(path, 'w') as f:
f.write(PRELOAD_COMMAND_TEXT)

with tmpfile() as fn2:
print(fn2)
with popen(['dask-scheduler', '--scheduler-file', fn2,
'--preload', path], stdout=sys.stdout, stderr=sys.stderr):
with Client(scheduler_file=fn2, loop=loop) as c:
assert c.run_on_scheduler(check_passthrough) == \
"default"

finally:
shutil.rmtree(tmpdir)
5 changes: 4 additions & 1 deletion distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, scheduler_ip=None, scheduler_port=None,
ncores=None, loop=None, local_dir=None, services=None,
name=None, memory_limit='auto', reconnect=True,
validate=False, quiet=False, resources=None, silence_logs=None,
death_timeout=None, preload=(), security=None,
death_timeout=None, preload=(), preload_argv=[], security=None,
contact_address=None, listen_address=None, **kwargs):
if scheduler_file:
cfg = json_load_robust(scheduler_file)
Expand All @@ -60,6 +60,8 @@ def __init__(self, scheduler_ip=None, scheduler_port=None,
self.resources = resources
self.death_timeout = death_timeout
self.preload = preload
self.preload_argv = preload_argv

self.contact_address = contact_address
self.memory_terminate_fraction = config.get('worker-memory-terminate', 0.95)

Expand Down Expand Up @@ -203,6 +205,7 @@ def instantiate(self, comm=None):
silence_logs=self.silence_logs,
death_timeout=self.death_timeout,
preload=self.preload,
preload_argv=self.preload_argv,
security=self.security,
contact_address=self.contact_address),
worker_start_args=(start_arg,),
Expand Down
110 changes: 98 additions & 12 deletions distributed/preloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,79 @@
import os
import shutil
import sys
import filecmp
from importlib import import_module

import click

from .utils import import_file

logger = logging.getLogger(__name__)


def preload_modules(names, parameter=None, file_dir=None):
""" Imports modules, handles `dask_setup` and `dask_teardown` functions
def validate_preload_argv(ctx, param, value):
"""Click option callback providing validation of preload subcommand arguments."""
if not value and not ctx.params.get("preload", None):
# No preload argv provided and no preload modules specified.
return value

if value and not ctx.params.get("preload", None):
# Report a usage error matching standard click error conventions.
unexpected_args = [v for v in value if v.startswith("-")]
for a in unexpected_args:
raise click.NoSuchOption(a)
raise click.UsageError(
"Got unexpected extra argument%s: (%s)" %
("s" if len(value) > 1 else "", " ".join(value))
)

preload_modules = _import_modules(ctx.params.get("preload"))

preload_commands = [
m["dask_setup"] for m in preload_modules.values()
if isinstance(m["dask_setup"], click.Command)
]

if len(preload_commands) > 1:
raise click.UsageError(
"Multiple --preload modules with click-configurable setup: %s" %
list(preload_modules.keys()))

if value and not preload_commands:
raise click.UsageError(
"Unknown argument specified: %r Was click-configurable --preload target provided?")
if not preload_commands:
return value
else:
preload_command = preload_commands[0]

ctx = click.Context(preload_command, allow_extra_args=False)
preload_command.parse_args(ctx, list(value))

return value


def _import_modules(names, file_dir=None):
""" Imports modules and extracts preload interface functions.

Imports modules specified by names and extracts 'dask_setup'
and 'dask_teardown' if present.


Parameters
----------
names: list of strings
Module names or file paths
parameter: object
Parameter passed to `dask_setup` and `dask_teardown`
file_dir: string
Path of a directory where files should be copied

Returns
-------
Nest dict of names to extracted module interface components if present
in imported module.
"""
result_modules = {}

for name in names:
# import
if name.endswith(".py"):
Expand All @@ -30,7 +84,8 @@ def preload_modules(names, parameter=None, file_dir=None):
basename = os.path.basename(name)
copy_dst = os.path.join(file_dir, basename)
if os.path.exists(copy_dst):
logger.error("File name collision: %s", basename)
if not filecmp.cmp(name, copy_dst):
logger.error("File name collision: %s", basename)
shutil.copy(name, copy_dst)
module = import_file(copy_dst)[0]
else:
Expand All @@ -42,10 +97,41 @@ def preload_modules(names, parameter=None, file_dir=None):
import_module(name)
module = sys.modules[name]

# handle special functions
dask_setup = getattr(module, 'dask_setup', None)
dask_teardown = getattr(module, 'dask_teardown', None)
if dask_setup is not None:
dask_setup(parameter)
if dask_teardown is not None:
atexit.register(dask_teardown, parameter)
result_modules[name] = {
attrname : getattr(module, attrname, None)
for attrname in ("dask_setup", "dask_teardown")
}

return result_modules


def preload_modules(names, parameter=None, file_dir=None, argv=None):
""" Imports modules, handles `dask_setup` and `dask_teardown`.

Parameters
----------
names: list of strings
Module names or file paths
parameter: object
Parameter passed to `dask_setup` and `dask_teardown`
argv: [string]
List of string arguments passed to click-configurable `dask_setup`.
file_dir: string
Path of a directory where files should be copied
"""

imported_modules = _import_modules(names, file_dir=file_dir)

for name, interface in imported_modules.items():
dask_setup = interface.get("dask_setup", None)
dask_teardown = interface.get("dask_teardown", None)

if dask_setup:
if isinstance(dask_setup, click.Command):
context = dask_setup.make_context("dask_setup", list(argv), allow_extra_args=False)
dask_setup.callback(parameter, *context.args, **context.params)
else:
dask_setup(parameter)

if interface["dask_teardown"]:
atexit.register(interface["dask_teardown"], parameter)
5 changes: 3 additions & 2 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(self, scheduler_ip=None, scheduler_port=None,
services=None, service_ports=None, name=None,
reconnect=True, memory_limit='auto',
executor=None, resources=None, silence_logs=None,
death_timeout=None, preload=(), security=None,
death_timeout=None, preload=(), preload_argv=[], security=None,
contact_address=None, memory_monitor_interval=200, **kwargs):

self._setup_logging()
Expand All @@ -102,6 +102,7 @@ def __init__(self, scheduler_ip=None, scheduler_port=None,
self.available_resources = (resources or {}).copy()
self.death_timeout = death_timeout
self.preload = preload
self.preload_argv = preload_argv,
self.contact_address = contact_address
self.memory_monitor_interval = memory_monitor_interval
if silence_logs:
Expand Down Expand Up @@ -345,7 +346,7 @@ def _start(self, addr_or_port=0):
protocol, listen_host = listen_host.split('://')

self.name = self.name or self.address
preload_modules(self.preload, parameter=self, file_dir=self.local_dir)
preload_modules(self.preload, parameter=self, file_dir=self.local_dir, argv=self.preload_argv)
# Services listen on all addresses
# Note Nanny is not a "real" service, just some metadata
# passed in service_ports...
Expand Down
Loading