xarray's to_zarr(compute=False) has quadratic definition time #12298

@crusaderky

Description

I have a 25 GiB xarray.DataArray with 220 MiB chunks, 117 top-level keys, and an underlying Dask graph of 74,500 keys in total (after optimization). Overall I would call it a sizeable computation, but not uncommon.

Calling `.to_zarr(..., compute=False)`, writing to a local SSD, takes several minutes just to return the `Delayed` object. Not to compute it - merely to define the graph!

Reproducer

```python
import xarray
import dask.array as da
from dask.core import flatten

# Build a synthetic array with a similar number of top-level keys
# and underlying graph size as the real one
a = da.zeros((10, 10), chunks=(1, 1))
for _ in range(30):
    a = a @ a
a = xarray.DataArray(a)

len(list(flatten(a.__dask_keys__())))  # 100
len(a.__dask_graph__())  # 72100
%timeit a.to_zarr("test.zarr", compute=False)  # 1min 30s
```

What is happening

  1. `xarray.DataArray.to_zarr(..., compute=False)` generates one `Delayed` object per zarr write task, then gathers the `None` outputs of all the write tasks under a single dummy `Delayed`, to be computed later by the end user:
    https://github.com/pydata/xarray/blob/a8732fc07bc98d92eacc5ed441c7fbdfce396c2f/xarray/backends/api.py#L132-L135
  2. `Delayed.__call__` internally invokes `ProhibitReuse.__dask_graph__()`
    (`dask/dask/delayed.py`, lines 179 to 182 at `6ac1553`):

    ```python
    expr2 = ProhibitReuse(collections_to_expr(expr).finalize_compute())
    finalized = expr2.optimize()
    # FIXME: Make this also go away
    dsk = finalized.__dask_graph__()
    ```
  3. which in turn builds a `subs` dict of 72,000 elements and calls `Task.substitute(subs)` for each of its 72,000 tasks
    (`dask/dask/_expr.py`, lines 1391 to 1418 at `6ac1553`):

    ```python
    def __dask_graph__(self):
        try:
            from distributed.shuffle._core import P2PBarrierTask
        except ModuleNotFoundError:
            P2PBarrierTask = type(None)
        dsk = convert_legacy_graph(self.expr.__dask_graph__())
        subs = {old_key: self._modify_keys(old_key) for old_key in dsk}
        dsk2 = {}
        for old_key, new_key in subs.items():
            t = dsk[old_key]
            if isinstance(t, P2PBarrierTask):
                warnings.warn(
                    "Cannot block reusing for graphs including a "
                    "P2PBarrierTask. This may cause unexpected results. "
                    "This typically happens when converting a dask "
                    "DataFrame to delayed objects.",
                    UserWarning,
                )
                return dsk
            dsk2[new_key] = Task(
                new_key,
                ProhibitReuse._identity,
                t.substitute(subs),
            )
        dsk2.update(dsk)
        return dsk2
    ```
  4. each of these 72,000 calls in turn performs a full scan of the 72,000-element `subs` dict
    (`dask/dask/_task_spec.py`, lines 794 to 796 at `6ac1553`):

    ```python
    subs_filtered = {
        k: v for k, v in subs.items() if k in self.dependencies and k != v
    }
    ```
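The pattern in steps 3 and 4 can be distilled into a minimal, self-contained sketch (plain Python, not actual dask code; `tasks`, `subs`, and the two function names are illustrative stand-ins for the real graph and substitution mapping). Each task filtering the full n-entry `subs` dict costs O(n) per task, so O(n^2) overall; looking up only each task's own dependencies keeps the total roughly linear in graph size:

```python
def substitute_quadratic(tasks, subs):
    # tasks: {key: set of dependency keys}; subs: {old key: new key}
    out = {}
    for key, deps in tasks.items():
        # Full scan of subs for every task -> O(n) per task, O(n^2) total
        filtered = {k: v for k, v in subs.items() if k in deps and k != v}
        out[subs[key]] = (key, filtered)
    return out


def substitute_linear(tasks, subs):
    out = {}
    for key, deps in tasks.items():
        # Look up only this task's dependencies -> O(len(deps)) per task
        filtered = {k: subs[k] for k in deps if subs.get(k, k) != k}
        out[subs[key]] = (key, filtered)
    return out


# A chain graph: each task depends on the previous one
n = 2000
tasks = {i: {i - 1} if i else set() for i in range(n)}
subs = {i: ("wrapped", i) for i in range(n)}

# Both strategies produce the same substituted graph
assert substitute_quadratic(tasks, subs) == substitute_linear(tasks, subs)
```

Iterating over the (typically small) dependency set instead of the whole mapping is presumably the direction of the fix linked below, though the actual patch may differ.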

History

This is a regression introduced in 2025.4.0 by #11881, which, for this specific use case, started triggering the O(n^2) behaviour introduced earlier in 2024.12.0 by #11568. I expect there are more concrete use cases that trigger this issue, but I did not spend time researching them.

FYI @fjetter

| xarray   | dask                  | time to define graph |
|----------|-----------------------|----------------------|
| 2026.2.0 | 2025.3.0              | 0.71 s               |
| 2026.2.0 | 2025.4.0              | 1 min 28 s           |
| 2026.2.0 | 2026.1.2              | 1 min 30 s           |
| 2026.2.0 | git tip + #12299      | 1.00 s               |
