-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
einsum for xarray #1968
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
einsum for xarray #1968
Conversation
shoyer
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice!
xarray/core/computation.py
Outdated
| subscripts = '' | ||
| for ds in input_core_dims: | ||
| subscripts += '...' + ''.join([dim_map[d] for d in ds]) + ',' | ||
| subscripts = subscripts[:-1] # remove last comma |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would probably be cleaner to build up subscripts as a list and use ','.join(subscripts_list) once at the end.
xarray/core/computation.py
Outdated
|
|
||
| result = apply_ufunc(np.einsum, subscripts, *arrays, | ||
| input_core_dims=[[]] + input_core_dims, | ||
| output_core_dims=output_core_dims, dask='allowed') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think dask='parallelized' is what you want here -- that generate the wrapper to do this with dask. This will require also determining the result data type, probably with dtypes.result_type or even np.result_type (we don't need support for non-numeric types in einsum, so I'm pretty sure NumPy's casting rules would work fine).
dask='allowed' would be appropriate if np.einsum already supported dask arrays (but it does not).
It's possible that a dask specific einsum could be much more efficient than the auto-generated wrapper here, but certainly this is good enough for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. I noticed that my current implementation is not very efficient for dask.
Maybe smaller number of input_core_dims is better for dask?
I think I need some improvement.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dask='parallelized' will only parallelize over broadcast dimensions, i.e., ones that don't appear in either input_core_dims or output_core_dims. So yes, it will probably be slow in many cases.
I'm still OK adding the non-optimal einsum for now and improving it later.
xarray/core/computation.py
Outdated
| if len(arrays) < 2: | ||
| raise TypeError('More than two arrays must be provided') | ||
|
|
||
| if any(not hasattr(arr, 'dims') for arr in arrays): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dataset also defines dims. It's probably better to explicitly use an isinstance() check.
xarray/core/dataarray.py
Outdated
| [d for d in other.dims if d not in dims]) | ||
|
|
||
| return type(self)(new_data, new_coords.variables, new_dims) | ||
| # backward compat: if there is no shared dimension, we rais an Errror |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be better to eliminate this special case. Then users can understand DataArray.dot as a simple short-cut for xarray.dot().
xarray/core/computation.py
Outdated
| arrays = args | ||
| if dims is None and isinstance(args[-1], (list, tuple, basestring)): | ||
| dims = args[-1] | ||
| arrays = args[:-1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is better to require specifying dims with a keyword argument.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Our previous dot does not require dim. This assumes to sum over along all the common dimensions.
I think dim=None is not surprising.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, the default dims=None should be OK. I meant that dims should be a keyword only argument, not a required argument.
Here you are supporting xr.dot(a, b, 'x'), where 'x' denotes a dimension. I would require writing xr.dot(a, b, dim='x') or omitting dim altogether.
|
|
||
|
|
||
| def dot(*args, **kwargs): | ||
| """ dot(*arrays, dims=None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dot(*arrays, *, dims=None) is the way to write this with Python 3's keyword only arguments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we would keep this as dot(*arrays, **kwargs) as we did not yet drop python 2 support?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was confused. def dot(*arrays, *, dims=None) is not valid syntax in Python 3, either. (There can only be one single *)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PEP3102 says we python 3 supports the form def dot(*arrays, dim=None).
| return apply_ufunc(duck_array_ops.tensordot, *arrays, dask='allowed', | ||
| input_core_dims=input_core_dims, | ||
| output_core_dims=output_core_dims, | ||
| kwargs={'axes': axes}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. I added a path for tensordot, which dask can compute more efficiently.
shoyer
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some feedback on the documentation (mostly grammar).
xarray/core/computation.py
Outdated
| ---------- | ||
| arrays: multiple DataArrays | ||
| arrays to compute. | ||
| dims: tuple of strings, optional |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
str or tuple of strings
xarray/core/computation.py
Outdated
| """ dot(*arrays, *, dims=None) | ||
| einsum for xarray object, but providing simpler interface based on | ||
| the array dimensions. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should lead with a more general description. Maybe:
Generalized dot product for xarray objects. Like np.einsum, but
provides a simpler interface based on array dimensions.
xarray/core/computation.py
Outdated
| Parameters | ||
| ---------- | ||
| arrays: multiple DataArrays |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
*arrays: DataArray objects
xarray/core/computation.py
Outdated
| Parameters | ||
| ---------- | ||
| arrays: multiple DataArrays | ||
| arrays to compute. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Arrays
xarray/core/computation.py
Outdated
| arrays: multiple DataArrays | ||
| arrays to compute. | ||
| dims: tuple of strings, optional | ||
| Along which dimensions to be summed over. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Which dimensions to sum over.
xarray/core/computation.py
Outdated
| Returns | ||
| ------- | ||
| dot: same type to input. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably should just "DataArray"?
xarray/core/computation.py
Outdated
|
|
||
| common_dims = set(arrays[0].dims) | ||
| for arr in arrays[1:]: | ||
| common_dims = common_dims.intersection(set(arr.dims)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a slightly different choice of default dimensions than np.einsum:
np.einsumsums over any dimensions that are defined in two over more inputs.- This sums only over dimensions that are defined on all inputs.
Should we switch this behavior to match einsum?
xarray/core/computation.py
Outdated
| dims=['a', 'b', 'c']) | ||
| >>> da_c = xr.DataArray(np.arange(5 * 6).reshape(5, 6), dims=['c', 'd']) | ||
| >>> dot(da_a, da_b, dims=['a', 'b']).dims |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These should use the full name xr.dot.
xarray/core/computation.py
Outdated
| dims = kwargs.pop('dims', None) | ||
|
|
||
| if len(arrays) < 2: | ||
| raise TypeError('More than one arrays must be provided') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need this special case? If not, let's remove this. For consistency, it is nice to use the same logic even for edge cases when possible. This makes it easier to think about the function.
In this case, I think a dot product of 1 array would consistently defined by summing over dimensions listed explicitly in dims.
xarray/core/computation.py
Outdated
| dims = kwargs.pop('dims', None) | ||
| if len(kwargs) > 0: | ||
| raise TypeError('Invalid keyward arguments {} are given'.format( | ||
| kwargs.keys())) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
W1655 dict.keys referenced when not iterating
xarray/core/computation.py
Outdated
| # find dimensions that exist in more than two arrays | ||
| whole_dims = [] | ||
| for arr in arrays: | ||
| whole_dims += [d for d in arr.dims] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might be a nice use for collections.Counter(), e.g.,
dim_counts = Counter():
for arr in arrays:
dim_counts.update(arr.dims)
| dims = [dims] | ||
|
|
||
| common_dims = set(arrays[0].dims) | ||
| all_dims = [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would it work to make all_dims a set instead of a list? I think that would be slightly more efficient.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I want to keep the occurrence order in all_dims, so that to move input_core_dims positions back to the original position.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, sounds good.
xarray/core/computation.py
Outdated
| if len(arrays) < 2: | ||
| raise TypeError('More than one arrays must be provided') | ||
| if len(arrays) < 2 and dims is None: | ||
| raise TypeError('dim must be provided for one array computation.') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If there's only one array, wouldn't dims just be any repeated dimensions on the single array?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
xarray objects do not have any repeated dimensions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not strictly true: #1378 . That said, we certainly don't support repeated dims well right now.
Even if we banned repeated dimensions, I still think there's no harm in supporting the trivial xr.dot(array) -> array.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK. Updated.
|
|
||
|
|
||
| def dot(*args, **kwargs): | ||
| """ dot(*arrays, dims=None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was confused. def dot(*arrays, *, dims=None) is not valid syntax in Python 3, either. (There can only be one single *)
xarray/core/computation.py
Outdated
| common_dims = set(arrays[0].dims) | ||
| all_dims = [] | ||
| for arr in arrays[1:]: | ||
| common_dims = common_dims.intersection(set(arr.dims)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be slightly more efficient to construct common_dims with a single call to intersection?
e.g.,
common_dims = set.intersection(*[set(arr.dims) for arr in arrays])
| if len(kwargs) > 0: | ||
| raise TypeError('Invalid keyward arguments {} are given'.format( | ||
| list(kwargs.keys()))) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens if you write xr.dot()? I suppose we still need to raise an error for 0 arguments.
shoyer
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's wait a little while to see if anyone else has feedback, e.g,. on the name. But this looks very nice to me!
|
|
||
| if any(not isinstance(arr, DataArray) for arr in arrays): | ||
| raise TypeError('Only xr.DataArray and xr.Variable are supported.') | ||
| raise TypeError('Only xr.DataArray and xr.Variable are supported.' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should either update the error message or isinstance() check here -- right now they are inconsistent.
| list(kwargs.keys()))) | ||
|
|
||
| if any(not isinstance(arr, DataArray) for arr in arrays): | ||
| raise TypeError('Only xr.DataArray and xr.Variable are supported.' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Either a type checking or a docstring issue:
In [8]: v=xr.Variable(data=np.random.rand(3,4), dims=('a','b'))
In [9]: xr.dot(v,v)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-9-fac8e1cb222a> in <module>()
----> 1 xr.dot(v,v)
~/drive/workspace/xarray/xarray/core/computation.py in dot(*arrays, **kwargs)
970 if any(not isinstance(arr, DataArray) for arr in arrays):
971 raise TypeError('Only xr.DataArray and xr.Variable are supported.'
--> 972 'Given {}.'.format([type(arr) for arr in arrays]))
973
974 if len(arrays) == 0:
TypeError: Only xr.DataArray and xr.Variable are supported.Given [<class 'xarray.core.variable.Variable'>, <class 'xarray.core.variable.Variable'>].
| raise TypeError('At least one array should be given.') | ||
|
|
||
| if isinstance(dims, basestring): | ||
| dims = (dims, ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FWIW you don't need the parentheses
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I personally like parentheses, as I think it is more descriptive.
xarray/core/computation.py
Outdated
| if isinstance(dims, basestring): | ||
| dims = (dims, ) | ||
| elif isinstance(dims, list): | ||
| dims = tuple(dims) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FWIW dims=tuple(dims) doesn't create any copies if dims is already a tuple, so you could skip the if isinstance check
|
Thanks, @maxim-lian |
|
This is awesome. Beautiful code, immediately impactful, and the API is so simple - a testament to the benefits of named dims Thank you @fujiisoup ! |
|
Do you know why the tests are failing? Do you want me to have a look? The arrays look the same: https://travis-ci.org/pydata/xarray/jobs/350640898#L5182. Would |
|
I just noticed the test failings. |
|
I'm going to merge this tomorrow if there are no further comments. |
whats-new.rstfor all changes andapi.rstfor new API (remove if this change should not be visible to users, e.g., if it is an internal clean-up, or if this is part of a larger project that will be documented later)Currently, lazy-einsum for dask is not yet working.
@shoyer
I think
apply_ufuncsupports lazy computation, but I did not yet figure out how to do this.Can you give me a help?