xarray.Dataset can't respect the new collection protocol #7203

@crusaderky

Description

In #7093 we discussed a change in the collection protocol that allows rebuilding a collection with __dask_postpersist__ after a change in the top-level keys. It revolves around formalizing the concept of a collection "name": the first element of every top-level key returned by __dask_keys__, which must be identical across all keys.

As it turns out, xarray.Dataset objects don't, and can't, comply. They are dask collections that wrap multiple, independent da.Array objects:

>>> import xarray
>>> ds = xarray.Dataset(data_vars={"d1": ("x", [1, 2]), "d2": ("x", [3, 4])}).chunk(1)
>>> ds.__dask_keys__()
[[('xarray-d1-b79ee2b65d520e2a3ef23843be98dae8', 0),
  ('xarray-d1-b79ee2b65d520e2a3ef23843be98dae8', 1)],
 [('xarray-d2-0f76c2d28d71ff0e28b317e3b9ee37b5', 0),
  ('xarray-d2-0f76c2d28d71ff0e28b317e3b9ee37b5', 1)]]
>>> ds.__dask_layers__()
('xarray-d1-b79ee2b65d520e2a3ef23843be98dae8',
 'xarray-d2-0f76c2d28d71ff0e28b317e3b9ee37b5')

The same applies to xarray.DataArray with non-index coords.

Under the current design, it is impossible to pass a Dataset containing two or more dask variables through any of the new functions in dask.graph_manipulation (#7109).
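To make the failure mode concrete, here is a sketch of the single-name assumption (first_key_name is an illustrative helper, not dask API): it drills into the first key of the nested __dask_keys__ output and reads only that name, so in the Dataset above the d2 name would be silently ignored.

```python
def first_key_name(keys):
    """Return the name of the first key in a nested __dask_keys__ list.

    Dask keys are tuples (name, idx, ...) nested in lists; this reads
    the name from the very first tuple and assumes it holds for all keys.
    """
    while isinstance(keys, list):
        keys = keys[0]
    return keys[0]

# Keys shaped like the two-variable Dataset above (names shortened):
keys = [
    [("xarray-d1-b79e", 0), ("xarray-d1-b79e", 1)],
    [("xarray-d2-0f76", 0), ("xarray-d2-0f76", 1)],
]
# first_key_name(keys) only ever sees "xarray-d1-b79e".
```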

Proposed design

  1. Remove the newly-introduced requirement that all keys returned by __dask_keys__ must share the same name
  2. Change the signature of the callable returned by __dask_postpersist__ from
def rebuild(dsk: Mapping, *args, name: str = None):

to

def rebuild(dsk: Mapping, *args, rename: Mapping[str, str] = None):

where rename is a mapping of {old name: new name}. Names that do not appear in the mapping are left unchanged.
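As a sketch of how a rebuild callable could consume such a mapping (rename_dask_keys is a hypothetical helper, not part of the proposal), only the name element of each key tuple is rewritten, and names absent from the mapping pass through untouched:

```python
from typing import Mapping


def rename_dask_keys(keys, rename: Mapping[str, str]):
    """Apply {old name: new name} to nested dask-style keys.

    Dask keys are tuples (name, idx, ...) nested in lists that mirror
    the chunk structure; recurse through the lists and rewrite only the
    name element of each tuple.
    """
    if isinstance(keys, list):
        return [rename_dask_keys(k, rename) for k in keys]
    name, *indices = keys
    return (rename.get(name, name), *indices)
```

This per-name semantics is what lets a Dataset with multiple wrapped da.Array objects rebuild itself: each variable looks up its own name independently.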

The function dask.base.get_collection_name(coll) -> str, which currently fetches the name from the first key returned by __dask_keys__, will become get_collection_names(coll) -> set[str], which iterates through all keys.
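A minimal sketch of the proposed get_collection_names, assuming keys are tuples whose first element is the name, nested in lists as in the example above (FakeCollection is a stand-in implementing only __dask_keys__):

```python
def get_collection_names(coll) -> set:
    """Collect the distinct names across all of a collection's keys."""

    def flatten(keys):
        # Recursively walk the nested lists and yield the key tuples.
        for k in keys:
            if isinstance(k, list):
                yield from flatten(k)
            else:
                yield k

    return {key[0] for key in flatten(coll.__dask_keys__())}


class FakeCollection:
    # Illustrative stand-in for a two-variable Dataset (names shortened).
    def __dask_keys__(self):
        return [
            [("xarray-d1-b79e", 0), ("xarray-d1-b79e", 1)],
            [("xarray-d2-0f76", 0), ("xarray-d2-0f76", 1)],
        ]
```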

Discarded designs

  • Change xarray.DataArray and xarray.Dataset not to be dask collections anymore. Instead, add to them a hook for dask.base.unpack_collections to let it extract the wrapped da.Array objects and then rebuild them afterwards. This idea was discarded as it would be a lot more invasive than the above.
  • Change all the dask keys in a xarray.DataArray and xarray.Dataset to be in the format (<Dataset name>, <da.Array name>, idx, idx, ...). This would break absolutely everything in the dask.array module, add gratuitous extra keys to the dask graph that are just there to rename the top-level keys, and add a ton of complication and overhead in xarray too.
  • Add a new method to the collection protocol, __dask_names__, to save the effort of cycling through the whole output of __dask_keys__. I think it would be overkill; I benchmarked that iterating through the keys of a 2-dimensional, 75 GiB da.Array with 10k chunks takes ~2 ms.

XREF pydata/xarray#4884
CC @mrocklin @jrbourbeau @shoyer @jcrist @eriknw @keewis
