
use __array_function__ on functions outside numpy #13872

@rgommers


I think this is a bug, but I'm not 100% sure. The `ndarray.__array_function__` implementation seems to be special: it is not recognized when `array_function_dispatch` is applied to a function outside of NumPy (which NEP 18 suggests is possible).

To try this, I added the following lines at the end of PyTorch's `torch/__init__.py`:

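```python
import numpy as _np  # harmless if torch/__init__.py already imports NumPy as _np

def _sum_dispatcher(input, dtype=None):
    # Return the arguments that may carry an __array_function__ override
    return (input, dtype)

_sum = sum  # keep a reference to the original torch.sum
@_np.core.overrides.array_function_dispatch(_sum_dispatcher)
def sum(input, dtype=None):
    return _sum(input)  # don't worry about the missing `dtype` here, that's a torch issue
```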

Then I ran the following script (on NumPy 1.16.4 with the `NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1` environment variable set; EDIT: I rebuilt and tried current master, with the same result):

```python
import numpy as np
import torch
import sparse
import dask.array

t = torch.Tensor([1, 2])
x = t.numpy()
s = sparse.as_coo(x)
d = dask.array.from_array(x)

print("Sum of tensor t: ", torch.sum(t))
print("Sum of dask array d: ", torch.sum(d))
# Okay, let's add a compute()
print("Sum of dask array d (evaluated): ", torch.sum(d).compute())
print("Sum of sparse array s: ", torch.sum(s))
print("Sum of ndarray x: ", torch.sum(x))
```

This gives:

```
Sum of tensor t:  tensor(3.)
Sum of dask array d:  dask.array<sum-aggregate, shape=(), dtype=float32, chunksize=()>
Sum of dask array d (evaluated):  3.0
Sum of sparse array s:  3.0
Traceback (most recent call last):
  File "try_torch_array_function.py", line 18, in <module>
    print("Sum of ndarray x: ", torch.sum(x))
  File "/Users/rgommers/anaconda3/envs/pytorch/lib/python3.7/site-packages/numpy/core/overrides.py", line 165, in public_api
    implementation, public_api, relevant_args, args, kwargs)
  File "/Users/rgommers/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/__init__.py", line 326, in sum
    return _sum(input)
TypeError: sum(): argument 'input' (position 1) must be Tensor, not numpy.ndarray
```

So it works fine with Dask and pydata/sparse, but fails with a plain NumPy array: the traceback shows that dispatching to `numpy.sum` is not happening at all, and torch's own `sum` is called with the ndarray instead. Not expected, I think?
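The same pattern can be reproduced without PyTorch or NumPy using a simplified, pure-Python model of NEP 18 dispatch. The model assumes, following NEP 18's description, that the default `ndarray.__array_function__` returns `NotImplemented` unless every overriding type is an ndarray, and otherwise calls `func._implementation`, the undecorated function body. `NdarrayLike`, `DuckArray`, `Tensor`, `torch_sum` and this `array_function_dispatch` are hypothetical stand-ins, and details such as subclass ordering are ignored.

```python
import functools


class NdarrayLike:
    """Stand-in for numpy.ndarray and its default __array_function__."""

    def __array_function__(self, func, types, args, kwargs):
        # Defer if any overriding type is not ndarray-like.
        if not all(issubclass(t, NdarrayLike) for t in types):
            return NotImplemented
        # Otherwise call the undecorated body of `func`. For a function
        # defined outside NumPy, that is the third-party code itself.
        return func._implementation(*args, **kwargs)


class DuckArray:
    """Stand-in for a library such as Dask or pydata/sparse."""

    def __array_function__(self, func, types, args, kwargs):
        return f"DuckArray handled {func.__name__}"


class Tensor:
    """Stand-in for torch.Tensor, which defines no __array_function__."""

    def __init__(self, values):
        self.values = list(values)


def array_function_dispatch(dispatcher):
    """Simplified NEP 18 dispatch; ignores subclass ordering and other details."""
    def decorator(implementation):
        @functools.wraps(implementation)
        def public_api(*args, **kwargs):
            relevant_args = dispatcher(*args, **kwargs)
            # One representative argument per type that defines the protocol.
            candidates = {}
            for arg in relevant_args:
                if hasattr(type(arg), "__array_function__"):
                    candidates.setdefault(type(arg), arg)
            if not candidates:
                # No overrides at all: run the implementation directly.
                return implementation(*args, **kwargs)
            types = tuple(candidates)
            for arg in candidates.values():
                result = type(arg).__array_function__(
                    arg, public_api, types, args, kwargs)
                if result is not NotImplemented:
                    return result
            raise TypeError(f"no implementation of {public_api.__name__} "
                            f"found for types {types}")

        public_api._implementation = implementation
        return public_api
    return decorator


def _sum_dispatcher(input, dtype=None):
    return (input, dtype)


@array_function_dispatch(_sum_dispatcher)
def torch_sum(input, dtype=None):
    # Mimics torch.sum: only Tensor inputs are accepted.
    if not isinstance(input, Tensor):
        raise TypeError("sum(): argument 'input' must be Tensor, "
                        f"not {type(input).__name__}")
    return sum(input.values)


print(torch_sum(Tensor([1, 2])))  # 3
print(torch_sum(DuckArray()))     # DuckArray handled torch_sum
try:
    torch_sum(NdarrayLike())
except TypeError as exc:
    print("TypeError:", exc)      # the ndarray-like case reaches the Tensor-only body
```

In this model the ndarray-like case never reaches anything like `numpy.sum`, because nothing links the third-party `torch_sum` to a NumPy function; the dispatch lands back in the Tensor-only body, matching the traceback above.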
