
PyTorch Tensor subclasses and protocols for NumPy interoperability #22402

@rgommers

🚀 Feature

This is a description of several related features that are best considered together.

  1. Allow subclassing Tensor and propagating subclass instances correctly with torch functions, operators, using views/slices/etc.
  2. Support the NumPy array protocols
  3. Allow other libraries to reuse the PyTorch API via a similar method as NumPy uses

Motivation

This issue/document is motivated by the attempts in PyTorch PR 17249 and follow-up PRs 18610, 22235 and 22247 to make torch.Tensor subclasses interact better with NumPy and Torch functions. Currently (June 2019), Tensor subclassing is not yet supported and while PyTorch in many cases follows the NumPy API, direct interoperability is limited (instead one needs to explicitly convert between torch.Tensor and numpy.ndarray).

Potential goals

These are potential goals that have been collected from the above referenced PRs, other PyTorch issues (referenced in the relevant sections), as well as from discussions with mainly Edward Yang, and also other PyTorch and NumPy maintainers:

  1. Support subclassing torch.Tensor in Python
  2. Preserve Tensor subclasses when calling torch functions on them
  3. Preserve Tensor subclasses when calling numpy functions on them
  4. Use the NumPy API with PyTorch tensors (i.e. NumPy API calls dispatch to torch functions)
  5. Use the PyTorch API with torch.Tensor-like objects that are not Tensor subclasses
  6. Reuse NumPy ufunc implementations directly from PyTorch
  7. Allow operations on mixed array types, e.g. tensor + ndarray

Important to keep in mind when implementing features that achieve any of the above goals:

  • The PyTorch team is planning more complex Tensor wrappers, an effort that should not be made significantly more difficult.
  • The PyTorch team may want to provide a (g)ufunc-like mechanism in PyTorch in the future. Also that should not be made unnecessarily complex.

Support subclassing torch.Tensor in Python

Note that Tensor seems to have been designed with subclassing in mind, at least that's what the comments at https://pytorch.org/docs/stable/_modules/torch/tensor.html (NB: If you subclass Tensor ...) indicate. Support seems incomplete though. The most basic way of subclassing just adds some new attributes, e.g. to carry around specific metadata like "this tensor represents voltages".

class AddedAttributeTensor(torch.Tensor):
    # metadata carried along with the tensor
    extra_attr = 'voltage'

t1 = torch.Tensor([1, 2])
t2 = AddedAttributeTensor([3, 6])

print("Extra attribute of subclass: ", t2.extra_attr)
print("Tensor + subclass gives back a Tensor instance: ", t1 + t2)
print("Is subclass preserved for operators?  ", isinstance(t2 + t2, AddedAttributeTensor))
print("Does slicing preserve subclass?  ", isinstance(t2[:1], AddedAttributeTensor))
print("Does taking a view preserve subclass?  ", isinstance(t2.view((2, 1)), AddedAttributeTensor))

Running this code shows that for a regular subclass, the subclass doesn't propagate (we always get a plain Tensor instance back):

Extra attribute of subclass:  voltage
Tensor + subclass gives back a Tensor instance:  tensor([4., 8.])
Is subclass preserved for operators?   False
Does slicing preserve subclass?   False
Does taking a view preserve subclass?   False

The NumPy subclassing docs discuss the ways in which new class instances can be created; the same will apply to Tensor instances. To deal with that, two methods are needed:

  1. A __new__ method for initialization in case of an explicit constructor call.
  2. A method to deal with creation in other ways, like slicing or taking a view (which will bypass __new__).
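NumPy's version of this pair of hooks can be demonstrated with a minimal subclass (the class name here is hypothetical): slicing bypasses __new__ but still triggers __array_finalize__.

```python
import numpy as np

class LoggedArray(np.ndarray):
    # Hypothetical subclass that records which creation hook runs.
    calls = []

    def __new__(cls, input_array):
        cls.calls.append('__new__')
        # view() creates the subclass instance and triggers __array_finalize__
        return np.asarray(input_array).view(cls)

    def __array_finalize__(self, obj):
        type(self).calls.append('__array_finalize__')

a = LoggedArray([1, 2, 3])   # explicit construction: __new__, then __array_finalize__
b = a[:2]                    # slicing bypasses __new__; only __array_finalize__ runs
print(LoggedArray.calls)     # ['__new__', '__array_finalize__', '__array_finalize__']
```

Any finalization hook for Tensor subclasses would have to cover the same creation paths (constructor call, view, slice).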

For (2) NumPy uses __array_finalize__. In gh-22247 the claim was that this method is very expensive because it gets called too often (that doesn't seem to be the case though; see the "Performance considerations" section further down). Instead, that PR introduced a THPVariable_result_ptype, which achieves the same thing (although it was confusingly mixed up with an incorrect use of __array_ufunc__ there).

Also note that torch.Tensor._make_subclass already exists (defined in torch/csrc/autograd/python_variable.cpp, according to the comment specifically for use with torch.nn.Parameter). It's unclear whether that or the code in gh-22235 works for views and slicing; it's not tested.

Preserve Tensor subclasses when calling torch functions on them

Note that this was the goal of gh-22235, which is useful as a reference.

For Tensor subclasses that do not implement __torch_function__ (assuming that gets implemented, see goal 5), this will work if the __array_finalize__ equivalent gets implemented (see previous section). For subclasses that do implement __torch_function__, all torch functions get overridden by the subclass, so it has more control over this (although in many cases it will still make use of the __array_finalize__ equivalent).

Preserve Tensor subclasses when calling numpy functions on them

Note that this was the goal of gh-22247, which is useful as a reference.

This should be done via implementation of __array_ufunc__ and __array_function__ on the Tensor class. At that point, all the NumPy functions that have a PyTorch equivalent will work (including subclass propagation if that's implemented for torch functions and operators), and other NumPy functions will error.

Use the NumPy API with PyTorch tensors (i.e. NumPy API calls dispatch to torch functions)

NumPy provides two protocols that ndarray-like objects (like torch.Tensor) can implement to make NumPy API calls dispatch to their own implementations. Those protocols are __array_ufunc__ (available since NumPy 1.13) and __array_function__ (available since NumPy 1.17 by default; in 1.16 one can enable it via an environment variable). These two protocols work in the same way:

  1. Pass Tensor instance to a numpy function (e.g. numpy.abs)
  2. NumPy detects the presence of Tensor.__array_ufunc__ (or Tensor.__array_function__) and delegates execution to it.
  3. The Tensor.__array_ufunc__ implementation then can forward that function call to the right implementation (torch.abs or Tensor.abs).
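These three steps can be sketched with a toy ndarray-like class (the class name is hypothetical, and plain NumPy stands in for the torch functions a real Tensor.__array_ufunc__ would forward to):

```python
import numpy as np

class TorchLikeArray:
    # Hypothetical stand-in for torch.Tensor; wraps an ndarray.
    def __init__(self, data):
        self._data = np.asarray(data, dtype=float)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Step 2: NumPy detected this method on an input and delegated here.
        if method != '__call__':
            return NotImplemented
        # Step 3: forward to "our own" implementation.  For torch.Tensor this
        # would be getattr(torch, ...); getattr(np, ...) stands in for it.
        unwrapped = [x._data if isinstance(x, TorchLikeArray) else x
                     for x in inputs]
        return TorchLikeArray(getattr(np, ufunc.__name__)(*unwrapped, **kwargs))

a = TorchLikeArray([1.0, -2.0])
result = np.abs(a)  # Step 1: pass the instance to a NumPy ufunc
print(type(result).__name__)  # TorchLikeArray, not ndarray
```

Note that the return type is preserved: the caller gets a TorchLikeArray back even though a NumPy function was called.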

The main benefit of implementing these functions is that users can prototype new code with NumPy, or reuse their existing code, and that code will then work unchanged when passing in PyTorch tensors (even if they live on a GPU). For more context, see e.g. NEP 18. Also note that CuPy, Dask and pydata/sparse already implement these protocols.

Note that there's a related discussion at gh-2228, "PyTorch with NumPy syntax?", which contains a fairly detailed plan to provide a new torch.np API. That plan seems less desirable than using __array_function__ though - why create a whole new API in a torch.np submodule when it's now possible to use the NumPy API itself?

There may be a backwards compatibility issue here. Because torch.Tensor already implements
__array__ and __array_wrap__, many (but not all) NumPy functions will already work with Tensor:

In [1]: import torch                                                                           

In [2]: t = torch.Tensor([1, -2])                                                              

In [3]: np.abs(t)                                                                              
Out[3]: tensor([1., 2.])

In [4]: np.sin(t)                                                                              
Out[4]: tensor([ 0.8415, -0.9093])

In [5]: np.dot(t, t)                                                                           
Out[5]: 5.0

In [6]: torch.dot(t, t)  # would be called if t had __array_function__                         
Out[6]: tensor(5.)

In [7]: np.mean(t)  # not all functions work ....                                              
...
TypeError: mean() missing 3 required positional arguments: "dim", "keepdim", "dtype"

So here the return value of np.dot(t, t) would change from 5.0 to tensor(5.). For NumPy functions that don't have a PyTorch equivalent, the PyTorch __array_function__ implementation should explicitly convert to ndarray with np.asarray and then call the NumPy function (this preserves current behavior). Returning NotImplemented instead would cause NumPy to raise a TypeError for functions that previously worked via __array__.
Likely this is not a major issue (Dask, CuPy and pydata/sparse all didn't consider this problematic), but it's good to explicitly think about this. The Partial implementation of NumPy's API section of NEP 18 provides a detailed discussion on this point.
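The fallback strategy described above can be sketched as follows (the class and the handled-functions table are hypothetical; in the real implementation the handled functions would forward to their torch equivalents):

```python
import numpy as np

class FallbackArray:
    # Hypothetical Tensor-like class demonstrating the fallback strategy.
    def __init__(self, data):
        self._data = np.asarray(data)

    def __array__(self, dtype=None):
        return np.asarray(self._data, dtype=dtype)

    # Functions with a native equivalent (torch.dot, ...) would go here.
    _HANDLED = {np.dot}

    def __array_function__(self, func, types, args, kwargs):
        unwrapped = [np.asarray(a) if isinstance(a, FallbackArray) else a
                     for a in args]
        if func in self._HANDLED:
            # Dispatch to "our" implementation (NumPy stands in for torch
            # here) and preserve the array type.
            return FallbackArray(func(*unwrapped, **kwargs))
        # No native equivalent: convert to ndarray and call NumPy directly,
        # preserving the behavior that __array__ used to provide.
        return func(*unwrapped, **kwargs)

t = FallbackArray([1.0, 2.0])
d = np.dot(t, t)   # handled: returns a FallbackArray
m = np.mean(t)     # not handled: falls back to a plain NumPy result
```

The fallback path keeps previously working NumPy calls working, at the cost of an implicit conversion for functions without a torch equivalent.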

A note of caution is probably warranted here: while __array_ufunc__ has been around for over 2 years and has generally worked very well, __array_function__ (which does work very similarly but has to deal with more flexible function signatures) is brand new. An alternative discussed in NEP 18 is to use multiple dispatch. That would be a more comprehensive solution (one can override anything, see e.g. uarray), however it's a more invasive change with likely larger overhead (3-5 function calls rather than 1). Adding a protocol now would not preclude adding a multiple dispatch layer on top later. If the larger overhead is acceptable though, the PyTorch team could also decide that a more complete multiple dispatch layer (perhaps in a separate namespace or project) would be the better solution.

Use the PyTorch API with torch.Tensor-like objects that are not Tensor subclasses

This would allow users to write their own tensor implementations and have users use the familiar PyTorch API with it. It can be implemented with a __torch_function__ protocol, which would work analogously to the NumPy __array_function__ protocol.

Providing such a __torch_function__ protocol will also help Tensor subclasses modify the behavior of individual torch functions, while forwarding directly to the torch functions that the subclass does not want to modify (for an example of how this can work, see the __array_ufunc__ section of the NumPy subclassing docs).

In issue 17268, "Generic object to tensor dispatching", there's both a proposal for an API to register tensor conversion functions and a response from the Pyro developers that they'd prefer support for __array_function__.

As an alternative, the __array_function__ protocol could be used directly on torch functions. The way to reuse __array_function__ would be to decorate functions in PyTorch with @array_function_dispatch (from numpy.core.overrides, which is currently still private). The upsides of that are that the mechanism exists already, and is already supported by other array libraries. The potential downsides are that:

  • the mechanism is still very new and marked in the NumPy docs as "may still change" (although changes other than bug fixes are unlikely). Therefore, vendoring the decorator would be the way to go.
  • it puts a more stringent requirement on functions in the torch namespace to be compatible in signature with the numpy ones - this is desirable, but may not be achievable due to backwards compatibility reasons.
  • it's still not clear if __array_function__ is actually supposed to be used like this (likely yes, but still under discussion in NumPy gh-13872)

In summary: it's probably better to go with __torch_function__ that works the same way as __array_function__ but has its own "domain". That requires other libraries that want to reuse the PyTorch API to implement __torch_function__ to explicitly opt in.
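A __torch_function__ protocol with its own domain could work along these lines (all names here are hypothetical; this is a sketch of the dispatch mechanism, analogous to NumPy's array_function_dispatch decorator, not a proposed implementation):

```python
import functools

def torch_function_dispatch(public_func):
    # Hypothetical decorator: check the arguments for an override before
    # running the default implementation.
    @functools.wraps(public_func)
    def wrapper(*args, **kwargs):
        for arg in args:
            overload = getattr(type(arg), '__torch_function__', None)
            if overload is not None:
                # Delegate to the argument's override, passing the
                # undecorated function so the override can reuse it.
                return overload(arg, public_func, args, kwargs)
        return public_func(*args, **kwargs)
    return wrapper

@torch_function_dispatch
def absolute(x):
    # Stand-in for a function in the torch namespace.
    return x if x >= 0 else -x

class DuckTensor:
    # Hypothetical Tensor-like class explicitly opting in to the protocol.
    def __init__(self, value):
        self.value = value

    def __torch_function__(self, func, args, kwargs):
        # Forward to the default implementation, re-wrapping the result.
        return DuckTensor(func(self.value))

print(absolute(-3))                    # plain input: 3
print(absolute(DuckTensor(-3)).value)  # DuckTensor preserved: 3
```

The explicit opt-in is the key property: types that don't define __torch_function__ are entirely unaffected, and libraries that want to reuse the PyTorch API choose to implement it.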

Reuse NumPy ufunc implementations directly from PyTorch

What the rationale for this goal would be is a little unclear. The number of NumPy functions that are ufuncs and are not already covered by equivalent PyTorch functionality is not that large (see a list of NumPy ufuncs here). The implementation of this feature in gh-22247 was partially motivated by the goal of using Tensor subclasses with NumPy functions; that is best done differently though.

In case users want to use NumPy functions (not just ufuncs) with Tensors today, this either may already work (it does for many functions) or can be done with explicit conversion:

x = tensor.numpy()  # to a numpy array (no-copy if tensor lives in CPU memory)
y = np.somefunction(x)  # use with NumPy
tensor2 = torch.from_numpy(y)  # convert back to a tensor

Adding extra design complexity to avoid these explicit casts does not seem worthwhile. In case there are NumPy functions that are popular, adding those functions to PyTorch itself seems like a better option, especially because that would work for tensors that live in GPU memory as well and be more performant.

The explicit casts would be a little more cumbersome for subclasses, but that's probably not a good enough reason to add design complexity.

Allow operations on mixed array types

In gh-22247 it was suggested that a good goal could be to make operators on mixed array/tensor types work better. Currently torch.Tensor + numpy.ndarray will call the PyTorch implementation of the + operator, while numpy.ndarray + torch.Tensor will call the NumPy implementation of +. This is simply the way Python operator support works, and is unaffected by any of the array protocols like __array_function__. The general advice should be "don't do that" - it's better to be explicit there and convert both left-hand and right-hand side of any expression to either Tensors or ndarrays.
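The left-hand-side rule can be illustrated with two plain Python classes (names hypothetical); nothing array-specific is involved, this is just how Python resolves binary operators:

```python
class TensorLike:
    def __add__(self, other):
        return 'TensorLike.__add__'
    def __radd__(self, other):
        return 'TensorLike.__radd__'

class NdarrayLike:
    def __add__(self, other):
        return 'NdarrayLike.__add__'
    def __radd__(self, other):
        return 'NdarrayLike.__radd__'

t, a = TensorLike(), NdarrayLike()
# The left operand's __add__ is tried first and wins unless it
# returns NotImplemented (or the right operand is a subclass).
print(t + a)  # TensorLike.__add__
print(a + t)  # NdarrayLike.__add__
```

So which library handles a mixed expression depends only on operand order, which is why explicit conversion is the safer advice.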

Performance considerations

The extra overhead of the __array_function__ and __array_ufunc__ protocols is that of a single Python function call. Typically that is 300-400 ns (see e.g. numpy/numpy#12830 (comment) for benchmarks). Adding a __torch_function__ protocol should give a similar extra overhead for calling torch.somefunction if a new check needs to be added. However, it's likely (says @ezyang) that a check and fast path for torch.Tensor input already exists - in that case there will be no extra overhead for Tensor input (and 300 ns for non-Tensor input seems less of an issue). This needs investigating.

For comparison, the current overhead of a NumPy ufunc call, including the __array_ufunc__ check, is of the same order (~400 ns), while the overhead of torch functions is significantly larger, ~3 µs:

In [1]: import torch                                                                           
In [2]: t = torch.Tensor([1, -2])                                                              

In [3]: %timeit torch.abs(t)                                                                   
3.2 µs ± 30.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In [4]: x = t.numpy()                                                                          

In [5]: %timeit np.abs(x)                                                                      
392 ns ± 5.83 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

That implies that this overhead should be acceptable.

In the PRs that triggered this subclassing discussion, it was said that __array_finalize__ has too much overhead so cannot be used (hence the alternative in gh-22247). To check this, let's implement a similar small array (taken from here, see that description for more extensive comments):

class InfoArray(np.ndarray):
    def __new__(subtype, shape, dtype=float, buffer=None, offset=0,
                strides=None, order=None, info=None):
        # Create the ndarray instance of our type, given the usual
        # ndarray input arguments.  This will call the standard
        # ndarray constructor, but return an object of our type.
        # It also triggers a call to InfoArray.__array_finalize__
        obj = super(InfoArray, subtype).__new__(subtype, shape, dtype,
                                                buffer, offset, strides,
                                                order)
        obj.info = info
        return obj

    def __array_finalize__(self, obj):
        # ``self`` is a new object resulting from
        # ndarray.__new__(InfoArray, ...), therefore it only has
        # attributes that the ndarray.__new__ constructor gave it -
        # i.e. those of a standard ndarray.
        
        if obj is None: return
        # Note that it is here, rather than in the __new__ method,
        # that we set the default value for 'info', because this
        # method sees all creation of default objects
        self.info = getattr(obj, 'info', None)

n = 3
x = np.arange(n)
i = InfoArray(shape=(n,), dtype=np.int64, buffer=x)

Now to test the performance:

In [2]: %timeit x + x                                                                          
436 ns ± 0.616 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [3]: %timeit x + i                                                                          
1.43 µs ± 10.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [4]: %timeit i + i                                                                          
2.4 µs ± 20.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

And for n = 30000:

In [6]: %timeit x + x                                                                          
10.9 µs ± 27.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In [7]: %timeit x + i                                                                          
12.5 µs ± 29.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In [8]: %timeit i + i                                                                          
13.8 µs ± 439 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

So the extra overhead of __array_finalize__ is ~2-3 us when implemented in pure Python (and a subclass author could decide to implement the method in C if that's a problem). There does seem to be a small design issue in NumPy, because x + i and i + i both require a single new subclass instance to be created, hence for operations of the same type i + i should not be more expensive (but this is a minor detail).

Some comments on the NumPy & PyTorch APIs

  • It's not desirable to copy all of the NumPy API; that API is way too large and many functions are of limited interest or have better alternatives.
  • See RNumPy for a work-in-progress attempt to define a sensible subset of the full NumPy API that other array libraries can target.
  • The PyTorch maintainers have expressed interest in adding functions to the torch namespace. Focusing first on matching the signatures of functions in the torch namespace and of methods on torch.Tensor would be nice. Right now the functions that are there often have different signatures (e.g. compare torch.sum and Tensor.sum).

Pitch / Possible plan forward

Implement the following (these don't depend on each other, no order implied):

  1. __array_ufunc__ and __array_function__ (small backwards compat impact, no performance impact)
  2. __torch_function__ (no backwards compat impact, likely no performance impact for Tensor input and 300-400 ns for non-Tensor input, needs investigating)
  3. A subclass finalization method (as in gh-22235 or __array_finalize__) (no backwards compat impact, 300 ns - 3 us performance impact for subclasses only)

This would close:

  • gh-2228, "PyTorch with numpy syntax?"
  • gh-20073, "numpy arg translation proof of concept"
  • gh-17249: "Proposal: Add __tensor_wrap__ method similar to numpy __array_wrap__"
  • gh-22235: "ptype propagation on torch functions"
  • gh-22247: "ptype propagation on numpy functions"
