
Race condition in default client on multithreaded workers #3827

@crusaderky

Description


This is a cleanup of #3791

distributed 2.16.0, Linux

In short

The design of the "default client" is not thread-safe; this becomes apparent when the default client is used inside a multithreaded worker.

Severity

Minor: a straightforward workaround exists.

Use case and POC

A dask task is executed many times in parallel across different threads of the same worker process.
The task retrieves a dask.Collection containing Futures using distributed.get_client().get_dataset(<name>), and then invokes its compute() method without explicitly specifying a scheduler.

from distributed import Client, LocalCluster, get_client
import dask.array


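def task(_):
    # Inside a task, get_client() returns a client connected to the scheduler
    client = get_client()
    # The dataset's Futures are bound to the client that retrieved them
    arr = client.get_dataset("arr")
    # No scheduler is passed, so compute() falls back to the default client
    arr.compute()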


def main():
    with LocalCluster(1, threads_per_worker=2) as cluster:
        with Client(cluster) as client:
            arr = dask.array.ones(1).persist()
            client.publish_dataset(arr=arr)
            futures = client.map(task, list(range(8)))
            client.gather(futures)


if __name__ == "__main__":
    main()

Current behaviour

In thread 1, compute() picks up the client that get_client() created in thread 2 (and vice versa), which fails with ValueError: Inputs contain futures that were created by another client.
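The race can be illustrated without a cluster. The sketch below is a simplified, stdlib-only model that assumes the default client behaves like a single process-wide slot which every get_client() call overwrites; `FakeClient`, `fake_get_client` and `_default_client` are hypothetical stand-ins, not distributed's API. A `threading.Barrier` forces the bad interleaving deterministically: both threads register a client before either reads the slot.

```python
import threading

# Hypothetical model: _default_client mimics a process-wide "default client"
# slot. FakeClient and fake_get_client are illustrative, not distributed's API.
_default_client = None


class FakeClient:
    def __init__(self, name):
        self.name = name

    def compute(self, owner):
        # Mimics the check that raises in the issue: the futures being
        # computed must have been created by the client doing the computing.
        if owner is not self:
            raise ValueError(
                "Inputs contain futures that were created by another client."
            )
        return f"computed by {self.name}"


def fake_get_client(name):
    # Each call creates a client and registers it as the process-wide default.
    global _default_client
    client = FakeClient(name)
    _default_client = client
    return client


def task(name, barrier, results):
    client = fake_get_client(name)  # step 1: register "my" client
    barrier.wait()                  # both threads finish step 1 before step 2
    default = _default_client       # step 2: an implicit compute() reads the slot
    try:
        results[name] = default.compute(owner=client)
    except ValueError as exc:
        results[name] = exc


def run():
    barrier = threading.Barrier(2)
    results = {}
    threads = [
        threading.Thread(target=task, args=(name, barrier, results))
        for name in ("thread-1", "thread-2")
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

In this model exactly one of the two threads fails: both read whichever client was registered last, so only the thread that registered it computes successfully.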

The issue disappears if

  1. One sets threads_per_worker=1
  2. One creates a number of tasks that is much lower than the number of threads

Expected behaviour

dask.Collection.compute() always uses the Client that distributed.get_client() returned within the same task, without the client having to be passed explicitly.
This implies that either

  1. The default client global variable must be thread-local, or
  2. get_client() always returns the same client across all threads; such a client will need to use run_coroutine_threadsafe under the hood
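Option 1 can be sketched with `threading.local`, using hypothetical names (`FakeClient`, `set_default`, `get_default`) rather than distributed's real internals. Even when both threads register a client before either reads it back (forced here with a `threading.Barrier`), each thread sees only the client it registered itself.

```python
import threading

# Hypothetical sketch of option 1: the "default client" slot is thread-local.
# FakeClient, set_default and get_default are illustrative, not distributed's API.
_local = threading.local()


class FakeClient:
    def __init__(self, name):
        self.name = name


def set_default(client):
    # Only visible to the calling thread.
    _local.client = client


def get_default():
    # Returns None if the calling thread never registered a client.
    return getattr(_local, "client", None)


def task(name, barrier, results):
    client = FakeClient(name)
    set_default(client)  # register this thread's client
    barrier.wait()       # both threads register before either reads
    results[name] = get_default() is client


def run():
    barrier = threading.Barrier(2)
    results = {}
    threads = [
        threading.Thread(target=task, args=(name, barrier, results))
        for name in ("thread-1", "thread-2")
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

A thread that never registers a client (such as the one calling run()) gets None from get_default() instead of some other thread's client.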

Workaround

Use an explicit client in compute():

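from distributed import get_client


def task(_):
    client = get_client()
    arr = client.get_dataset("arr")
    # Compute on this task's own client; .result() blocks like arr.compute()
    client.compute(arr).result()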

Labels: bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions