Description
I have a 25 GiB xarray.DataArray with 220 MiB chunks, 117 top-level keys, and an underlying Dask graph of 74,500 keys in total (after optimization). Overall I would call it a sizeable computation, but not uncommon.
Calling .to_zarr(..., compute=False), writing to local SSD, takes several minutes to return the Delayed object. Not to compute it - just to define the graph!
Reproducer
```python
import xarray
import dask.array as da
from dask.core import flatten

# Build a synthetic array with a similar number of top-level keys
# and underlying graph size as the real one
a = da.zeros((10, 10), chunks=(1, 1))
for _ in range(30):
    a = a @ a
a = xarray.DataArray(a)

len(list(flatten(a.__dask_keys__())))  # 100
len(a.__dask_graph__())  # 72100

%timeit a.to_zarr("test.zarr", compute=False)  # 1min 30s
```

What is happening
`xarray.DataArray.to_zarr(..., compute=False)` generates one `Delayed` object per zarr write task, then collects all the `None` outputs of the write tasks with a single dummy `Delayed` to be later computed by the end user:

https://github.com/pydata/xarray/blob/a8732fc07bc98d92eacc5ed441c7fbdfce396c2f/xarray/backends/api.py#L132-L135

`Delayed.__call__` internally invokes `ProhibitReuse.__dask_graph__()`,
Lines 179 to 182 in 6ac1553:

```python
expr2 = ProhibitReuse(collections_to_expr(expr).finalize_compute())
finalized = expr2.optimize()
# FIXME: Make this also go away
dsk = finalized.__dask_graph__()
```

which in turn generates a `subs` dict of 72,000 elements and calls `Task.substitute(subs)` for each of its 72,000 tasks,
Lines 1391 to 1418 in 6ac1553:

```python
def __dask_graph__(self):
    try:
        from distributed.shuffle._core import P2PBarrierTask
    except ModuleNotFoundError:
        P2PBarrierTask = type(None)

    dsk = convert_legacy_graph(self.expr.__dask_graph__())
    subs = {old_key: self._modify_keys(old_key) for old_key in dsk}
    dsk2 = {}
    for old_key, new_key in subs.items():
        t = dsk[old_key]
        if isinstance(t, P2PBarrierTask):
            warnings.warn(
                "Cannot block reusing for graphs including a "
                "P2PBarrierTask. This may cause unexpected results. "
                "This typically happens when converting a dask "
                "DataFrame to delayed objects.",
                UserWarning,
            )
            return dsk
        dsk2[new_key] = Task(
            new_key,
            ProhibitReuse._identity,
            t.substitute(subs),
        )
    dsk2.update(dsk)
    return dsk2
```

Each of these 72,000 calls in turn performs a full scan of the 72,000-element `subs` dict:
Lines 794 to 796 in 6ac1553:

```python
subs_filtered = {
    k: v for k, v in subs.items() if k in self.dependencies and k != v
}
```
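To make the quadratic cost concrete, here is a small standalone sketch (illustrative only, not dask's actual code; all names are made up): filtering the substitution dict by scanning all of `subs` once per task costs O(n) per call, so O(n²) over n tasks, whereas looking up only each task's own (small) dependency set costs O(deps) per call.

```python
import time

# Hypothetical model of the pattern above: n tasks, each with a single
# dependency, and a rename map `subs` covering every key in the graph.
n = 3_000
subs = {("x", i): ("y", i) for i in range(n)}
deps = [{("x", i)} for i in range(n)]  # one dependency per task

def filter_by_scanning(task_deps):
    # Mirrors the `subs_filtered` comprehension: full scan of `subs` per task.
    return {k: v for k, v in subs.items() if k in task_deps and k != v}

def filter_by_deps(task_deps):
    # Alternative: index only the task's own dependencies -> O(len(task_deps)).
    return {k: subs[k] for k in task_deps if k in subs and subs[k] != k}

t0 = time.perf_counter()
scanned = [filter_by_scanning(d) for d in deps]
t_scan = time.perf_counter() - t0

t0 = time.perf_counter()
looked_up = [filter_by_deps(d) for d in deps]
t_deps = time.perf_counter() - t0

assert scanned == looked_up
print(f"full scans: {t_scan:.3f}s, dependency lookups: {t_deps:.3f}s")
```

With n = 3,000 the scanning variant already performs ~9 million dict iterations; at 72,000 tasks the same pattern balloons into the minutes observed above.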
History
This is a regression introduced in 2025.4.0 by #11881, which in turn started triggering, for this specific use case, the O(n²) behaviour introduced much earlier in 2024.12.0 by #11568. I expect there are more concrete use cases that trigger this issue, but I did not spend time researching them.
FYI @fjetter
| xarray | dask | time |
|---|---|---|
| 2026.2.0 | 2025.3.0 | 0.71 s |
| 2026.2.0 | 2025.4.0 | 1 min 28 s |
| 2026.2.0 | 2026.1.2 | 1 min 30 s |
| 2026.2.0 | git tip + #12299 | 1.00 s |