|
26 | 26 |
|
27 | 27 | from dask import config, local |
28 | 28 | from dask.compatibility import _EMSCRIPTEN, _PY_VERSION |
29 | | -from dask.context import thread_state |
30 | 29 | from dask.core import flatten |
31 | 30 | from dask.core import get as simple_get |
32 | 31 | from dask.core import literal, quote |
@@ -1362,14 +1361,25 @@ def get_scheduler(get=None, scheduler=None, collections=None, cls=None): |
1362 | 1361 | elif isinstance(scheduler, str): |
1363 | 1362 | scheduler = scheduler.lower() |
1364 | 1363 |
|
| 1364 | + try: |
| 1365 | + from distributed import default_client |
| 1366 | + |
| 1367 | + default_client() |
| 1368 | + client_available = True |
| 1369 | + except (ImportError, ValueError): |
| 1370 | + client_available = False |
1365 | 1371 | if scheduler in named_schedulers: |
1366 | | - if config.get("scheduler", None) in ("dask.distributed", "distributed"): |
| 1372 | + if client_available: |
1367 | 1373 | warnings.warn( |
1368 | 1374 | "Running on a single-machine scheduler when a distributed client " |
1369 | 1375 | "is active might lead to unexpected results." |
1370 | 1376 | ) |
1371 | 1377 | return named_schedulers[scheduler] |
1372 | 1378 | elif scheduler in ("dask.distributed", "distributed"): |
| 1379 | + if not client_available: |
| 1380 | + raise RuntimeError( |
| 1381 | + f"Requested {scheduler} scheduler but no Client active." |
| 1382 | + ) |
1373 | 1383 | from distributed.worker import get_client |
1374 | 1384 |
|
1375 | 1385 | return get_client().get |
@@ -1397,14 +1407,16 @@ def get_scheduler(get=None, scheduler=None, collections=None, cls=None): |
1397 | 1407 | if config.get("get", None): |
1398 | 1408 | raise ValueError(get_err_msg) |
1399 | 1409 |
|
1400 | | - if getattr(thread_state, "key", False): |
1401 | | - from distributed.worker import get_worker |
1402 | | - |
1403 | | - return get_worker().client.get |
1404 | | - |
1405 | 1410 | if cls is not None: |
1406 | 1411 | return cls.__dask_scheduler__ |
1407 | 1412 |
|
| 1413 | + try: |
| 1414 | + from distributed import get_client |
| 1415 | + |
| 1416 | + return get_client().get |
| 1417 | + except (ImportError, ValueError): |
| 1418 | + pass |
| 1419 | + |
1408 | 1420 | if collections: |
1409 | 1421 | collections = [c for c in collections if c is not None] |
1410 | 1422 | if collections: |
|
0 commit comments