Skip to content

Commit dfcefd8

Browse files
authored
Use distributed default clients even if no config is set (#9808)
1 parent 6d639d2 commit dfcefd8

3 files changed

Lines changed: 50 additions & 20 deletions

File tree

dask/base.py

Lines changed: 19 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,6 @@
2626

2727
from dask import config, local
2828
from dask.compatibility import _EMSCRIPTEN, _PY_VERSION
29-
from dask.context import thread_state
3029
from dask.core import flatten
3130
from dask.core import get as simple_get
3231
from dask.core import literal, quote
@@ -1362,14 +1361,25 @@ def get_scheduler(get=None, scheduler=None, collections=None, cls=None):
13621361
elif isinstance(scheduler, str):
13631362
scheduler = scheduler.lower()
13641363

1364+
try:
1365+
from distributed import default_client
1366+
1367+
default_client()
1368+
client_available = True
1369+
except (ImportError, ValueError):
1370+
client_available = False
13651371
if scheduler in named_schedulers:
1366-
if config.get("scheduler", None) in ("dask.distributed", "distributed"):
1372+
if client_available:
13671373
warnings.warn(
13681374
"Running on a single-machine scheduler when a distributed client "
13691375
"is active might lead to unexpected results."
13701376
)
13711377
return named_schedulers[scheduler]
13721378
elif scheduler in ("dask.distributed", "distributed"):
1379+
if not client_available:
1380+
raise RuntimeError(
1381+
f"Requested {scheduler} scheduler but no Client active."
1382+
)
13731383
from distributed.worker import get_client
13741384

13751385
return get_client().get
@@ -1397,14 +1407,16 @@ def get_scheduler(get=None, scheduler=None, collections=None, cls=None):
13971407
if config.get("get", None):
13981408
raise ValueError(get_err_msg)
13991409

1400-
if getattr(thread_state, "key", False):
1401-
from distributed.worker import get_worker
1402-
1403-
return get_worker().client.get
1404-
14051410
if cls is not None:
14061411
return cls.__dask_scheduler__
14071412

1413+
try:
1414+
from distributed import get_client
1415+
1416+
return get_client().get
1417+
except (ImportError, ValueError):
1418+
pass
1419+
14081420
if collections:
14091421
collections = [c for c in collections if c is not None]
14101422
if collections:

dask/tests/test_base.py

Lines changed: 0 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1557,19 +1557,6 @@ def test_get_scheduler():
15571557
assert get_scheduler() is None
15581558

15591559

1560-
def test_get_scheduler_with_distributed_active():
1561-
1562-
with dask.config.set(scheduler="dask.distributed"):
1563-
warning_message = (
1564-
"Running on a single-machine scheduler when a distributed client "
1565-
"is active might lead to unexpected results."
1566-
)
1567-
with pytest.warns(UserWarning, match=warning_message) as user_warnings_a:
1568-
get_scheduler(scheduler="threads")
1569-
get_scheduler(scheduler="sync")
1570-
assert len(user_warnings_a) == 2
1571-
1572-
15731560
def test_callable_scheduler():
15741561
called = [False]
15751562

dask/tests/test_distributed.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
import dask
2323
import dask.bag as db
2424
from dask import compute, delayed, persist
25+
from dask.base import get_scheduler
2526
from dask.blockwise import Blockwise
2627
from dask.delayed import Delayed
2728
from dask.distributed import futures_of, wait
@@ -804,3 +805,33 @@ def test_set_index_no_resursion_error(c):
804805
ddf.compute()
805806
except RecursionError:
806807
pytest.fail("dd.set_index triggered a recursion error")
808+
809+
810+
def test_get_scheduler_without_distributed_raises():
811+
msg = "no Client"
812+
with pytest.raises(RuntimeError, match=msg):
813+
get_scheduler(scheduler="dask.distributed")
814+
815+
with pytest.raises(RuntimeError, match=msg):
816+
get_scheduler(scheduler="distributed")
817+
818+
819+
def test_get_scheduler_with_distributed_active(c):
820+
assert get_scheduler() == c.get
821+
warning_message = (
822+
"Running on a single-machine scheduler when a distributed client "
823+
"is active might lead to unexpected results."
824+
)
825+
with pytest.warns(UserWarning, match=warning_message) as user_warnings_a:
826+
get_scheduler(scheduler="threads")
827+
get_scheduler(scheduler="sync")
828+
assert len(user_warnings_a) == 2
829+
830+
831+
def test_get_scheduler_with_distributed_active_reset_config(c):
832+
assert get_scheduler() == c.get
833+
with dask.config.set(scheduler="threads"):
834+
with pytest.warns(UserWarning):
835+
assert get_scheduler() != c.get
836+
with dask.config.set(scheduler=None):
837+
assert get_scheduler() == c.get

0 commit comments

Comments
 (0)