Description
In #7093 we discussed a change to the collection protocol that allows rebuilding a collection with __dask_postpersist__ after its top-level keys have changed. The change formalizes the concept of a collection "name": the first element of every top-level key returned by __dask_keys__, which must be identical across all keys.
As it turns out, xarray.Dataset objects do not, and cannot, meet this requirement. They are dask collections that wrap multiple, independent da.Array objects:
>>> import xarray
>>> ds = xarray.Dataset(data_vars={"d1": ("x", [1, 2]), "d2": ("x", [3, 4])}).chunk(1)
>>> ds.__dask_keys__()
[[('xarray-d1-b79ee2b65d520e2a3ef23843be98dae8', 0),
('xarray-d1-b79ee2b65d520e2a3ef23843be98dae8', 1)],
[('xarray-d2-0f76c2d28d71ff0e28b317e3b9ee37b5', 0),
('xarray-d2-0f76c2d28d71ff0e28b317e3b9ee37b5', 1)]]
>>> ds.__dask_layers__()
('xarray-d1-b79ee2b65d520e2a3ef23843be98dae8',
'xarray-d2-0f76c2d28d71ff0e28b317e3b9ee37b5')
The same applies to xarray.DataArray objects with non-index coordinates.
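To make the conflict concrete, here is a minimal pure-Python sketch of the single-name rule applied to the keys from the session above. The helpers flatten_keys and single_collection_name are hypothetical illustrations, not dask's actual implementation; they assume __dask_keys__ returns arbitrarily nested lists whose leaves are either tuples (name, idx, ...) or bare string keys.

```python
from typing import Any, Iterator


def flatten_keys(keys: Any) -> Iterator[Any]:
    """Yield the individual keys from the (possibly nested) lists
    returned by __dask_keys__."""
    if isinstance(keys, list):
        for item in keys:
            yield from flatten_keys(item)
    else:
        yield keys


def single_collection_name(keys: Any) -> str:
    """Hypothetical check mirroring the single-name requirement:
    every top-level key must start with the same name."""
    names = {key[0] if isinstance(key, tuple) else key for key in flatten_keys(keys)}
    if len(names) != 1:
        raise ValueError(f"expected exactly one collection name, got {sorted(names)}")
    return names.pop()


# Keys copied from the xarray.Dataset session above.
dataset_keys = [
    [("xarray-d1-b79ee2b65d520e2a3ef23843be98dae8", 0),
     ("xarray-d1-b79ee2b65d520e2a3ef23843be98dae8", 1)],
    [("xarray-d2-0f76c2d28d71ff0e28b317e3b9ee37b5", 0),
     ("xarray-d2-0f76c2d28d71ff0e28b317e3b9ee37b5", 1)],
]

try:
    single_collection_name(dataset_keys)
except ValueError as exc:
    # Two distinct names (one per variable), so the requirement cannot hold.
    print(exc)
```

A single da.Array passes the check; a Dataset with two dask variables necessarily fails it.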
Under the current design, it is impossible to pass a Dataset with two or more dask-backed variables through any of the new functions in dask.graph_manipulation (#7109).
Proposed design
- Remove the newly introduced requirement that all keys returned by __dask_keys__ must share the same name.
- Change the signature of the callable returned by __dask_postpersist__ from
    def rebuild(dsk: Mapping, *args, name: str = None):
  to
    def rebuild(dsk: Mapping, *args, rename: Mapping[str, str] = None):
  where rename is a mapping of {old name: new name}.
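As a sketch of how a multi-name collection could implement the proposed rename signature, consider a toy Dataset-like class. ToyDataset and its _rebuild method are hypothetical stand-ins, not xarray or dask code; the sketch assumes the protocol convention that __dask_postpersist__ returns (callable, extra_args) and that the callable is invoked as callable(dsk, *extra_args, rename=...).

```python
from typing import Mapping, Optional


class ToyDataset:
    """Hypothetical stand-in for an xarray.Dataset: several variables,
    each backed by chunks whose keys are (array_name, i)."""

    def __init__(self, dsk: Mapping, variables: Mapping[str, tuple]):
        self.dask = dict(dsk)
        # variable label -> (array name, number of chunks)
        self.variables = dict(variables)

    def __dask_graph__(self):
        return self.dask

    def __dask_keys__(self):
        # One list of keys per variable, as in the xarray session.
        return [
            [(name, i) for i in range(nchunks)]
            for name, nchunks in self.variables.values()
        ]

    def __dask_postpersist__(self):
        return ToyDataset._rebuild, (self.variables,)

    @staticmethod
    def _rebuild(
        dsk: Mapping,
        variables: Mapping[str, tuple],
        rename: Optional[Mapping[str, str]] = None,
    ):
        # Proposed signature: rename maps {old name: new name};
        # names that are not in the mapping are kept unchanged.
        rename = rename or {}
        new_vars = {
            label: (rename.get(name, name), nchunks)
            for label, (name, nchunks) in variables.items()
        }
        return ToyDataset(dsk, new_vars)


ds = ToyDataset(
    {("d1", 0): 1, ("d1", 1): 2, ("d2", 0): 3, ("d2", 1): 4},
    {"d1": ("d1", 2), "d2": ("d2", 2)},
)
rebuild, args = ds.__dask_postpersist__()
# Pretend a graph-manipulation function rewrote every key under a new name.
new_dsk = {(f"{name}-new", i): v for (name, i), v in ds.__dask_graph__().items()}
renamed = rebuild(new_dsk, *args, rename={"d1": "d1-new", "d2": "d2-new"})
print(renamed.__dask_keys__())
# [[('d1-new', 0), ('d1-new', 1)], [('d2-new', 0), ('d2-new', 1)]]
```

A single name: str argument could not express this, because each wrapped array needs its own new name; a mapping handles any number of names uniformly.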
The function dask.base.get_collection_name(coll) -> str, which currently fetches the name from the first key returned by __dask_keys__, will change to get_collection_names(coll) -> set[str] and will iterate through all keys.
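A minimal sketch of the proposed get_collection_names, assuming the same key conventions as above (nested lists of keys whose leaves are tuples or bare strings). This illustrates the intended behaviour only; it is not dask's actual code, and TwoArrays is a hypothetical collection used for the demonstration.

```python
from typing import Any, Iterator, Set


def _flatten(keys: Any) -> Iterator[Any]:
    # __dask_keys__ may return arbitrarily nested lists of keys.
    if isinstance(keys, list):
        for item in keys:
            yield from _flatten(item)
    else:
        yield keys


def get_collection_names(collection: Any) -> Set[str]:
    """Collect the name (first element) of every top-level key,
    instead of looking only at the first key."""
    return {
        key[0] if isinstance(key, tuple) else key
        for key in _flatten(collection.__dask_keys__())
    }


class TwoArrays:
    """Hypothetical collection wrapping two independent arrays."""

    def __dask_keys__(self):
        return [[("a", 0), ("a", 1)], [("b", 0), ("b", 1)]]


print(sorted(get_collection_names(TwoArrays())))  # ['a', 'b']
```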
Discarded designs
- Change xarray.DataArray and xarray.Dataset so that they are no longer dask collections. Instead, add to them a hook for dask.base.unpack_collections that lets it extract the wrapped da.Array objects and rebuild them afterwards. This idea was discarded because it would be far more invasive than the above.
- Change all the dask keys in an xarray.DataArray and xarray.Dataset to the format (<Dataset name>, <da.Array name>, idx, idx, ...). This would break absolutely everything in the dask.array module, add gratuitous extra keys to the dask graph that exist only to rename the top-level keys, and add a great deal of complexity and overhead to xarray too.
- Add a new method to the collection protocol, __dask_names__, to save the effort of cycling through the whole output of __dask_keys__. I think it would be overkill: by my benchmark, it takes ~2 ms to iterate through the keys of a 2-dimensional da.Array of 75 GiB with 10k chunks.
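The cost of cycling through __dask_keys__ depends on the number of chunks, not on the size of the array in bytes. A rough way to check the order of magnitude is sketched below; it assumes a pure-Python stand-in for the nested key layout of a 2-D da.Array (100 x 100 = 10k chunks) rather than a real dask array, and the array name is arbitrary. Timings will vary by machine.

```python
import time
from typing import Any, Iterator


def flatten(keys: Any) -> Iterator[Any]:
    # __dask_keys__ of an N-dimensional array is an N-level nested list.
    if isinstance(keys, list):
        for item in keys:
            yield from flatten(item)
    else:
        yield keys


# Mimic the key layout of a 2-D array split into 100 x 100 chunks.
name = "array-0123456789abcdef"
keys = [[(name, i, j) for j in range(100)] for i in range(100)]

start = time.perf_counter()
names = {key[0] for key in flatten(keys)}
elapsed = time.perf_counter() - start

print(f"{len(names)} name(s) found in {elapsed * 1000:.2f} ms")
```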
XREF pydata/xarray#4884
CC @mrocklin @jrbourbeau @shoyer @jcrist @eriknw @keewis