Description
This is a cleanup of #3791
distributed 2.16.0, Linux
In short
The design of the "default client" is not thread-safe; this becomes apparent when one tries using the functionality inside a multithreaded worker.
Severity
Minor: a straightforward workaround exists.
Use case and POC
A dask task is executed many times in parallel across different threads of the same worker process.
The task retrieves a dask collection containing Futures using `distributed.get_client().get_dataset(<name>)`, and then invokes its `compute()` method without explicitly specifying a scheduler.
```python
from distributed import Client, LocalCluster, get_client
import dask.array


def task(_):
    client = get_client()
    arr = client.get_dataset("arr")
    arr.compute()


def main():
    with LocalCluster(1, threads_per_worker=2) as cluster:
        with Client(cluster) as client:
            arr = dask.array.ones(1).persist()
            client.publish_dataset(arr=arr)
            futures = client.map(task, list(range(8)))
            client.gather(futures)


if __name__ == "__main__":
    main()
```

Current behaviour
The `compute()` call in thread 1 acquires the client created by `get_client()` in thread 2, and vice versa, causing `ValueError: Inputs contain futures that were created by another client.`

The issue disappears if
- One sets `threads_per_worker=1`
- One creates a number of tasks that is much lower than the number of threads
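The race can be illustrated without dask at all. The sketch below (hypothetical names, not the actual distributed internals) models the default client as a single process-global slot: whichever thread registers last wins, so at least one thread ends up observing a client it did not create.

```python
import threading

# Hypothetical model of a non-thread-safe "default client" slot.
_default = None


def make_client(name):
    """Register `name` as the process-global default client."""
    global _default
    _default = name  # another thread can overwrite this at any time
    return _default


results = {}


def task(name, barrier):
    make_client(name)
    barrier.wait()            # both threads have registered by now
    results[name] = _default  # each thread reads the *last* writer's client


barrier = threading.Barrier(2)
threads = [threading.Thread(target=task, args=(n, barrier)) for n in ("A", "B")]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Both threads see the same final value, so at least one of them
# got a "client" created by the other thread.
```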
Expected behaviour
`dask.Collection.compute()` always uses the same `Client` returned by `distributed.get_client()` within the same task, without needing to specify it explicitly.
This implies that either
- The default client global variable must be thread-local, or
- `get_client()` always returns the same client across all threads; such a client would need to use `run_coroutine_threadsafe` under the hood
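The first option can be sketched in a few lines. This is a hypothetical illustration (the names `set_default_client` / `get_client` here are stand-ins, not distributed's actual implementation): storing the default client in a `threading.local` slot means each worker thread can only ever retrieve the client it registered itself.

```python
import threading

# Hypothetical thread-local registry for the default client.
_local = threading.local()


def set_default_client(client):
    """Register `client` as the default for the *current thread only*."""
    _local.client = client


def get_client():
    """Return this thread's default client, never another thread's."""
    client = getattr(_local, "client", None)
    if client is None:
        raise ValueError("no default client set in this thread")
    return client
```

With this scheme, thread 1 and thread 2 can each call `set_default_client()` concurrently without clobbering each other's slot.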
Workaround
Pass the collection to an explicit client instead of calling its bare `compute()`:

```python
def task(_):
    client = get_client()
    arr = client.get_dataset("arr")
    # Unlike arr.compute(), client.compute() returns a Future;
    # call .result() to block until the value is available.
    client.compute(arr).result()
```