use __array_function__ on functions outside numpy #13872
I think this is a bug, but I'm not 100% sure. The ndarray.__array_function__ implementation seems to be special: it is not recognized when array_function_dispatch is applied to a function outside of NumPy, which NEP 18 suggests is possible.
To test this, I added the following lines to PyTorch at the end of torch/__init__.py:
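import numpy as _np
def _sum_dispatcher(input, dtype=None):
    # Arguments NumPy should check for an __array_function__ override
    return (input, dtype)
_sum = sum  # keep a reference to the original torch.sum
@_np.core.overrides.array_function_dispatch(_sum_dispatcher)
def sum(input, dtype=None):
    return _sum(input)  # don't worry about the missing `dtype` here, that's a torch issue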
Then I ran the following (on NumPy 1.16.4 with the NUMPY_EXPERIMENTAL_ARRAY_FUNCTION environment variable enabled; trying master required a rebuild. EDIT: current master behaves the same):
import numpy as np
import torch
import sparse
import dask.array
t = torch.Tensor([1, 2])
x = t.numpy()
s = sparse.as_coo(x)
d = dask.array.from_array(x)
print("Sum of tensor t: ", torch.sum(t))
print("Sum of dask array d: ", torch.sum(d))
# Okay, let's add a compute()
print("Sum of dask array d (evaluated): ", torch.sum(d).compute())
print("Sum of sparse array s: ", torch.sum(s))
print("Sum of ndarray x: ", torch.sum(x))
This gives:
Sum of tensor t: tensor(3.)
Sum of dask array d: dask.array<sum-aggregate, shape=(), dtype=float32, chunksize=()>
Sum of dask array d (evaluated): 3.0
Sum of sparse array s: 3.0
Traceback (most recent call last):
  File "try_torch_array_function.py", line 18, in <module>
    print("Sum of ndarray x: ", torch.sum(x))
  File "/Users/rgommers/anaconda3/envs/pytorch/lib/python3.7/site-packages/numpy/core/overrides.py", line 165, in public_api
    implementation, public_api, relevant_args, args, kwargs)
  File "/Users/rgommers/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/__init__.py", line 326, in sum
    return _sum(input)
TypeError: sum(): argument 'input' (position 1) must be Tensor, not numpy.ndarray
So it works fine with Dask and pydata/sparse but fails with NumPy: the traceback shows that the call is never dispatched to numpy.sum at all; instead, the original torch sum is called with the ndarray. I don't think that's expected?
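For reference, here is a stripped-down, pure-Python model of the same pattern that runs without PyTorch, Dask or sparse. Everything in it is hypothetical: `array_function_dispatch` below is a simplified stand-in for NumPy's private decorator (not its actual code), `total` plays the role of the wrapped torch `sum`, `DuckArray` plays Dask/sparse by handling the call itself, and `BaseArray` follows the default `ndarray.__array_function__` as I read the sketch in NEP 18, which hands the call back to `func._implementation`. Under those assumptions the ndarray stand-in ends up running `total`'s own body, giving the same kind of `TypeError` as the traceback above.

```python
import functools


def array_function_dispatch(dispatcher):
    """Hypothetical, simplified model of NEP 18 dispatch (not NumPy's code)."""

    def decorator(implementation):
        @functools.wraps(implementation)
        def public_api(*args, **kwargs):
            # Ask the dispatcher which arguments may carry an override.
            relevant_args = dispatcher(*args, **kwargs)
            types, overloaded = [], []
            for arg in relevant_args:
                arg_type = type(arg)
                if hasattr(arg_type, "__array_function__") and arg_type not in types:
                    types.append(arg_type)
                    overloaded.append(arg)
            if not overloaded:
                # No overrides present: run the decorated function's own body.
                return implementation(*args, **kwargs)
            for arg in overloaded:
                result = arg.__array_function__(public_api, tuple(types), args, kwargs)
                if result is not NotImplemented:
                    return result
            raise TypeError(f"no implementation of {public_api.__name__} for {types}")

        # Keep a handle on the undecorated body for __array_function__ to use.
        public_api._implementation = implementation
        return public_api

    return decorator


class BaseArray:
    """Hypothetical ndarray stand-in, modeled on NEP 18's default method."""

    def __init__(self, data):
        self.data = list(data)

    def __array_function__(self, func, types, args, kwargs):
        if not all(issubclass(t, BaseArray) for t in types):
            return NotImplemented
        # Calls the *decorated function's* body, not a BaseArray-specific sum.
        return func._implementation(*args, **kwargs)


class DuckArray:
    """Hypothetical Dask/sparse stand-in that implements the call itself."""

    def __init__(self, data):
        self.data = list(data)

    def __array_function__(self, func, types, args, kwargs):
        if func.__name__ == "total":
            return sum(args[0].data)
        return NotImplemented


def _total_dispatcher(x):
    return (x,)


@array_function_dispatch(_total_dispatcher)
def total(x):
    # Hypothetical library function that only accepts lists,
    # the way torch.sum only accepts Tensors.
    if not isinstance(x, list):
        raise TypeError(f"total(): expected list, got {type(x).__name__}")
    return sum(x)


print(total([1, 2]))             # 3: no override, own body runs
print(total(DuckArray([1, 2])))  # 3: DuckArray handles the call
try:
    total(BaseArray([1, 2]))
except TypeError as exc:
    print(exc)                   # total(): expected list, got BaseArray
```

In this model the failure comes from `BaseArray.__array_function__` forwarding to `func._implementation`, which for a non-NumPy function is that library's own code rather than any NumPy routine.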