Make dask.array.utils functions more generic to other Dask Arrays #10676
crusaderky merged 12 commits into dask:main from
This will be helpful for dask_expr.array, if we do that soon.
dask/base.py (outdated)
    if "xarray" in type(x).__module__:
        return x.__dask_graph__() is not None
    else:
        return hasattr(x, "__dask_graph__")
@phofl this change is actually kinda important. We were generating the task graph before whenever calling is_dask_collection. Maybe caching was saving us here, but this could be unpleasantly expensive for dask-expr.
The history here is that xarray collections have a __dask_graph__ attribute, but are not always dask collections (they sometimes hold dask arrays but sometimes just have numpy arrays).
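To make the cost difference concrete, here is a minimal sketch. `ExpensiveCollection` is a hypothetical stand-in (not a dask or dask-expr type) whose `__dask_graph__` counts how often the graph gets built. Checking for the attribute never builds the graph, whereas comparing the return value with `None` builds it every time.

```python
class ExpensiveCollection:
    """Hypothetical stand-in for a collection whose graph is costly to build."""

    def __init__(self):
        self.graph_builds = 0

    def __dask_graph__(self):
        # Pretend this materializes a large task graph.
        self.graph_builds += 1
        return {("x", i): i for i in range(1000)}


def check_by_attribute(x):
    # Cheap: only looks the attribute up, never calls it.
    return hasattr(x, "__dask_graph__")


def check_by_calling(x):
    # Expensive: builds the whole graph just to compare it with None.
    return x.__dask_graph__() is not None


coll = ExpensiveCollection()
assert check_by_attribute(coll)
builds_after_attribute_check = coll.graph_builds  # no graph built yet
assert check_by_calling(coll)
builds_after_call_check = coll.graph_builds  # one graph built
```

The attribute check alone cannot tell an eager wrapper from a lazy one, which is why the xarray case needs separate handling.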
Yeah I remember that we discussed this a while back and ended at revisiting if necessary because caching worked fine back then. That change is nice anyway.
This seems pretty expensive for the xarray code path which really should just be something like
any(variable._data for variable in xarray_obj.variables)
Can we add something like __dask_is_present__?
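One way such a protocol could look, as a rough sketch only: `__dask_is_present__` does not exist in dask, and the wrapper classes and `is_dask_collection_sketch` below are hypothetical names used for illustration. The idea is that a wrapper answers the question cheaply itself, and the graph is only consulted for objects that do not implement the new method.

```python
class EagerWrapper:
    """Hypothetical wrapper that holds only in-memory data."""

    def __dask_graph__(self):
        return None

    def __dask_is_present__(self):
        return False


class LazyWrapper:
    """Hypothetical wrapper that holds a lazy collection."""

    def __init__(self):
        self.graph_calls = 0

    def __dask_graph__(self):
        self.graph_calls += 1
        return {"x": 1}

    def __dask_is_present__(self):
        return True


def is_dask_collection_sketch(x):
    # Classes themselves are never collections, only their instances.
    if isinstance(x, type):
        return False
    probe = getattr(x, "__dask_is_present__", None)
    if callable(probe):
        # Assumed semantics: a cheap yes/no answer, no graph construction.
        return bool(probe())
    # Fallback for objects that do not implement the hypothetical protocol.
    graph = getattr(x, "__dask_graph__", None)
    return callable(graph) and graph() is not None
```

With this shape, `is_dask_collection_sketch(LazyWrapper())` returns `True` without ever calling `__dask_graph__`.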
We can add a new protocol. We'll have to do this everywhere, but it seems like an ok thing to do. In the meantime, let's just special-case something for Xarray. I think that it's the only project today where this matters.
I started doing this:
    if "xarray" in type(x).__module__:
        import xarray

        if isinstance(x, xarray.Dataset):
            return any(variable._data for variable in x.variables)
        if isinstance(x, xarray.DataArray):
            return bool(x._data)
But I'm not sure that _data exists? Maybe we want any(is_dask_collection(variable.data) ...)? Any suggestions?
I suspect this is it:
    from dask.base import is_dask_collection

    def is_dask_xarray(x):
        import xarray

        if isinstance(x, xarray.Dataset):
            return any(
                is_dask_collection(variable._data) for _, variable in x.variables.items()
            )
        elif isinstance(x, xarray.DataArray):
            return is_dask_collection(x.variable._data)
        elif isinstance(x, xarray.Variable):
            return is_dask_collection(x._data)

    import xarray as xr

    ds = xr.tutorial.open_dataset("air_temperature", chunks="auto")
    computed = ds.compute()
    assert is_dask_xarray(ds)
    assert is_dask_xarray(ds.air)
    assert is_dask_xarray(ds.air.variable)
    assert not is_dask_xarray(computed)
    assert not is_dask_xarray(computed.air)
    assert not is_dask_xarray(computed.air.variable)
I also ran into the issue of materialization on __dask_graph__ when playing with the scheduler integration and I strongly recommend moving dask-expr away from this in favor of an explicit method doing that (that's what I am proposing in dask/dask-expr#294)
@dcherian I've pushed up your changes. Thanks. Using the air_temperature dataset requires an optional dependency pooch, which I was hesitant to add to CI just for this (although maybe I should if it's very lightweight). Alternatively, can you recommend a way to construct a dataset using maybe da.random.random?
@fjetter I'm not sure yet what you mean in the comment above. We need a way to determine if an object is a dask collection. I looked through the PR you mentioned and didn't immediately see anything there that would solve this problem. Can you help me to understand?
@mrocklin: ds = xr.Dataset({"air": (("time", "lat", "lon"), dask.array.ones((120, 120, 120)))})
@fjetter can I ask your team to handle this sometime next week?
can you state what the expected outcome should be? I'm missing some context about what this change is enabling
These changes enable code sharing between They are, I think, innocuous and should be easy to merge.
dask/base.py (outdated)
        elif isinstance(x, xarray.Variable):
            return is_dask_collection(x._data)
        else:
            raise TypeError("Unfamiliar with xarray type", type(x))
This is suppressed a few lines below.
Can we reduce the scope of the try...except block to just the type(x).__module__ line?
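The difference can be sketched as follows. The helper names and the exact exception types are assumptions for illustration (the thread does not show the original `except` clause): a broad block silently swallows the `TypeError` raised by the handler, while a narrow block lets it surface.

```python
def handle_known_types(x):
    """Hypothetical handler that rejects types it does not know."""
    if isinstance(x, (list, tuple)):
        return True
    raise TypeError("Unfamiliar type", type(x))


def check_broad(x):
    # Broad scope: the handler's TypeError is caught along with
    # any failure of the module lookup, so the error is hidden.
    try:
        module = type(x).__module__
        if module == "builtins":
            return handle_known_types(x)
    except (AttributeError, TypeError):
        pass
    return False


def check_narrow(x):
    # Narrow scope: only the module lookup is guarded.
    try:
        module = type(x).__module__
    except AttributeError:
        return False
    if module == "builtins":
        # A TypeError raised here now propagates to the caller.
        return handle_known_types(x)
    return False
```

Calling `check_broad(3.5)` returns `False`, whereas `check_narrow(3.5)` raises the `TypeError`, so the unfamiliar type is no longer hidden.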
dask/base.py (outdated)
        and callable(x.__dask_graph__)
        and not isinstance(x, type)
    )
This PR breaks pint, because just like xarray it violates the dask collection protocol by defining a __dask_graph__() method that may return None:
    >>> import dask.array as da, pint
    >>> q = pint.Quantity(1, "m")
    >>> q.__dask_graph__()
    (None)
    >>> q = pint.Quantity(da.zeros(5), "m")
    >>> q.__dask_graph__()
    HighLevelGraph with 1 layers.
    <dask.highlevelgraph.HighLevelGraph object at 0x7f043c3889a0>
     0. zeros_like-902bf3bcabaabb3fedce3f347fb26081
There may be more third-party libraries in the wild that do the same.
A blocklist approach would be more robust:
    if (
        isinstance(x, type)
        or not hasattr(x, "__dask_graph__")
        or not callable(x.__dask_graph__)
    ):
        return False

    pkg_name = getattr(type(x), "__module__", "").split(".")[0]
    # Add here other third-party packages where you don't want
    # to call the `__dask_graph__` method - typically because it's expensive
    if pkg_name == "dask_expr":
        return True

    # xarray, pint, and possibly other libraries return None when they wrap a non-dask object
    return x.__dask_graph__() is not None
The above, however, will mean that we call the __dask_graph__() method of dask_expr.array objects when they're wrapped by xarray or pint. So I would suggest either:
- hybrid allowlist+blocklist, with final fallback on calling __dask_graph__()
- change the __dask_graph__() method in dask_expr to be trivial (e.g. write a Mapping subclass that is trivial to initialize and is lazy until you actually try to access its contents)
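The second option could look roughly like the sketch below. `LazyGraph` and its `build` callback are hypothetical names, not part of dask or dask-expr; the point is only that constructing the mapping is free and the expensive work happens on first access.

```python
from collections.abc import Mapping


class LazyGraph(Mapping):
    """Hypothetical mapping that defers building its contents until first access."""

    def __init__(self, build):
        # `build` is a zero-argument callable returning the real graph.
        self._build = build
        self._data = None

    @property
    def materialized(self):
        return self._data is not None

    def _materialize(self):
        # Build once, then reuse the cached dict.
        if self._data is None:
            self._data = dict(self._build())
        return self._data

    def __getitem__(self, key):
        return self._materialize()[key]

    def __iter__(self):
        return iter(self._materialize())

    def __len__(self):
        return len(self._materialize())
```

A check such as `graph is not None` leaves an object like this untouched. A truthiness test such as `bool(graph)` would call `__len__` and therefore trigger materialization, so callers would need to avoid it.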
    x = delayed(1) + 2
    assert is_dask_collection(x)
    assert not is_dask_collection(2)
dask/base.py (outdated)
    if isinstance(x, xarray.Dataset):
        return any(is_dask_collection(v._data) for _, v in x.variables.items())
    elif isinstance(x, xarray.DataArray):
        return is_dask_collection(x.variable._data)
This is returning a false negative when you have an eager variable and lazy non-index coordinates
>>> import dask.array as da, xarray
>>> xarray.DataArray([1,2], dims=["x"], coords={"x": [10, 20], "x2": ("x", da.zeros(2))})
<xarray.DataArray (x: 2)>
array([1, 2])
Coordinates:
* x (x) int64 10 20
x2 (x) float64 dask.array<chunksize=(2,), meta=np.ndarray>
dask/base.py (outdated)
    import xarray

    if isinstance(x, xarray.Dataset):
        return any(is_dask_collection(v._data) for _, v in x.variables.items())
Suggested change:
    - return any(is_dask_collection(v._data) for _, v in x.variables.items())
    + return any(is_dask_collection(v._data) for v in x.variables.values())
dask/base.py (outdated)
    elif isinstance(x, xarray.DataArray):
        return is_dask_collection(x.variable._data)
Suggested change:
    - elif isinstance(x, xarray.DataArray):
    -     return is_dask_collection(x.variable._data)
    + elif isinstance(x, xarray.DataArray):
    +     return is_dask_collection(x.variable._data) or any(is_dask_collection(var._data) for var in x._coords.values())
Can we go back to the design table for a bit?

What this PR tries to achieve

Chiefly, this PR tries to future-proof
In main,
For xarray, today it's a cheap-ish layer merge:
The problems start if you call it on a dask_expr collection, which triggers a very expensive materialization of the whole graph: https://github.com/dask-contrib/dask-expr/blob/0fab3c19de0e54085119e182de92cbe5ec25f479/dask_expr/_core.py#L429-L445
This PR wants to avoid materialization in dask-expr by changing

Why it's not trivial

The problem lies within wrapper libraries - xarray and pint are two, but there may be more I'm not aware of and potentially bespoke, unpublished ones too. Both xarray and pint define a
As of today, the only ways for
a. call their
b. special-case them in

I suggest two alternative designs:
d. define a new API endpoint in the dask protocol,

For example, for xarray.Dataset:
    def __dask_collections__(self) -> list[DaskCollection]:
        from dask.base import is_dask_collection
        return [v._data for v in self.variables.values() if is_dask_collection(v._data)]
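How a checker might consume such an endpoint can be sketched with stand-in classes. Only the method name `__dask_collections__` comes from the comment above; `FakeLazyArray`, `FakeDataset`, `has_dask_graph`, and `is_dask_collection_sketch` are hypothetical, and the semantics (a wrapper counts as a collection when the returned list is non-empty) are an assumption.

```python
def has_dask_graph(x):
    # Cheap structural check: the method exists, but is never called.
    return not isinstance(x, type) and callable(getattr(x, "__dask_graph__", None))


class FakeLazyArray:
    """Hypothetical stand-in for a lazy array with an expensive graph."""

    def __init__(self):
        self.graph_calls = 0

    def __dask_graph__(self):
        self.graph_calls += 1
        return {"x": 1}


class FakeDataset:
    """Hypothetical wrapper that may hold eager and lazy variables."""

    def __init__(self, variables):
        self.variables = dict(variables)

    def __dask_collections__(self):
        # Report the wrapped lazy objects without touching their graphs.
        return [v for v in self.variables.values() if has_dask_graph(v)]


def is_dask_collection_sketch(x):
    if isinstance(x, type):
        return False
    nested = getattr(x, "__dask_collections__", None)
    if callable(nested):
        # Assumed semantics: a wrapper is a collection if it wraps at least one.
        return len(nested()) > 0
    return has_dask_graph(x)
```

In this sketch neither the wrapper nor the wrapped array ever has `__dask_graph__` called, which is the property the design is after.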
Offline comment from @fjetter:
In light of the above PRs, I'm switching from allowlist to blocklist, with the expectation that the PRs will be merged before wrapping dask_expr.Array inside xarray or pint becomes important. We can always re-add the allowlist later on. @mrocklin just wanted to make sure you're happy with this.
phofl left a comment
Ideally you shouldn't reach optimise either, that's also expensive if you read from remote storage for example
lgtm otherwise
@mrocklin I'm merging this. If you have comments happy to open a follow-up PR to address them.
Grand. Thanks!
On Thu, Jan 4, 2024 at 9:00 AM crusaderky wrote:
Merged #10676 into main.
This will be helpful for dask_expr.array, if we do that soon.
I also ran py.test dask/array/tests/test_array_core.py on both this branch and main and noticed no slowdown.