Generic object to tensor dispatching #17268

@ferrine

Description

🚀 Feature

It is a common issue when extending PyTorch for general usage that its functions allow only tensor inputs. However, some objects that represent a tensor are better stored not as a tensor, but as a custom Python object.

Motivation

Riemannian Optimization

Here is a brief motivating example. Consider a UdV=X matrix decomposition. Its minimal description consists of two matrices on the Stiefel manifold and a vector with positive entries. One would like to work with this representation directly, as it is a convenient way to represent low-rank matrices. Currently, it is impossible to pass custom objects to PyTorch functions:

>>> torch.svd(myobject)
Traceback (most recent call last):
  File "<input>", line 1, in <module>
TypeError: svd(): argument 'input' (position 1) must be Tensor, not object
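To make the motivation concrete, such a low-rank object could look like the sketch below. This is purely illustrative (the class name and methods are not from geoopt), and numpy stands in for torch so the snippet is self-contained:

```python
import numpy as np

class LowRankMatrix:
    """Hypothetical X = U @ diag(d) @ V^T representation.

    U and V have orthonormal columns (points on the Stiefel manifold),
    d is a vector with positive entries.
    """
    def __init__(self, u, d, v):
        self.u, self.d, self.v = u, d, v

    def materialize(self):
        # Build the dense matrix only when a consumer really needs it;
        # this is what a registered conversion function would call.
        return self.u @ np.diag(self.d) @ self.v.T

# A rank-2 representation of a 4x3 matrix.
u, _ = np.linalg.qr(np.random.randn(4, 2))
v, _ = np.linalg.qr(np.random.randn(3, 2))
x = LowRankMatrix(u, np.array([3.0, 1.0]), v)
assert x.materialize().shape == (4, 3)
```

The point is that the factors, not the dense matrix, are the natural storage format, so a hook that converts the object to a tensor on demand is exactly what is missing.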

The use case in the above example is relevant to the geoopt project, where I plan to support such manifolds in the future. I see no way to do it without effort from the pytorch-dev side.

Probabilistic Programming

Another interesting example comes from probabilistic programming languages (like pymc3, edward, pyro). Distribution objects describe a belief over all possible values, and a good user API assumes some abstraction there. In pymc3 (with the help of theano) we mixed tensors and distributions together, producing things like

import pymc3 as pm

X, y = linear_training_data()
with pm.Model() as linear_model:
    weights = pm.Normal('weights', mu=0, sd=1)
    noise = pm.Gamma('noise', alpha=2, beta=1)
    y_observed = pm.Normal('y_observed',
                           mu=X.dot(weights),
                           sd=noise,
                           observed=y)

    prior = pm.sample_prior_predictive()
    posterior = pm.sample()
    posterior_pred = pm.sample_posterior_predictive(posterior)

You see, calling a distribution object produces a tensor-compatible object that can be used downstream. To be honest, we used a lot of hacks with __new__ there, due to the absence of an API like the one tensorflow provides:
https://www.tensorflow.org/api_docs/python/tf/register_tensor_conversion_function

We also used this particular approach in the pymc4 design, as follows: depending on the context (which might be inference or MCMC sampling), it changes behaviour, producing either a conditional or an unconditional distribution.
https://github.com/pymc-devs/pymc4/blob/master/pymc4/random_variables/random_variable.py#L119

As far as I know, Pyro still relies on explicit .sample(...) calls:

import pyro
import pyro.distributions as dist

def scale(guess):
    weight = pyro.sample("weight", dist.Normal(guess, 1.0))
    return pyro.sample("measurement", dist.Normal(weight, 0.75))

And I expect the Pyro developers would be happy to join this discussion.


So you can see how much benefit tensor conversion functions would bring. They would allow third-party developers to build very neat, PyTorch-compatible APIs, improving the user experience.

Pitch

There should be a convention for a developer API to register tensor conversion functions, like tensorflow does (they have a nice API at this point).

This is the tensorflow convention:

tf.register_tensor_conversion_function(
    base_type,
    conversion_func,
    priority=100
)

def conversion_func(value, dtype=None, name=None, as_ref=False):
    ...

where

  • base_type: The base type or tuple of base types for all objects that conversion_func accepts.
  • conversion_func: A function that converts instances of base_type to Tensor.
  • priority: Optional integer that indicates the priority for applying this conversion function. Conversion functions with smaller priority values run earlier than conversion functions with larger priority values. Defaults to 100.

PyTorch does not use names, so it may omit some of these. The as_ref argument concerns whether a view or a copy is returned. An object may own the tensor and may optionally expose itself for in-place operations (if passed as out=myobject, for example).

PyTorch also has so-called tensor options that include dtype, device, stride, etc. Thus, the conventions from the Tensor creation guide may apply and make for a better API.

So, finally, this may look like:

torch.register_tensor_conversion_function(
    base_type,
    conversion_func,
    priority=100
)

def conversion_func(
    value,
    # each option should either be inherited from the value being converted,
    # or set explicitly by the conversion function
    dtype=None,
    device=None,
    stride=None,
    requires_grad=None
):
    ...
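As a rough sketch of how such a registry could dispatch, consider the pure-Python version below. A plain sorted list stands in for whatever internal mechanism PyTorch would use, and register_tensor_conversion_function / convert are illustrative names, not a real torch API:

```python
# Minimal sketch of a priority-ordered conversion registry.
_conversion_registry = []  # list of (priority, base_type, conversion_func)

def register_tensor_conversion_function(base_type, conversion_func, priority=100):
    _conversion_registry.append((priority, base_type, conversion_func))
    # Lower priority values run earlier, matching the tensorflow convention.
    _conversion_registry.sort(key=lambda entry: entry[0])

def convert(value, **tensor_options):
    # Walk the registered converters in priority order; a converter may
    # return NotImplemented to pass on a value it cannot handle.
    for priority, base_type, func in _conversion_registry:
        if isinstance(value, base_type):
            result = func(value, **tensor_options)
            if result is not NotImplemented:
                return result
    raise TypeError(f"no conversion function registered for {type(value)!r}")

# Example: register a conversion for a toy wrapper object.
class Wrapped:
    def __init__(self, data):
        self.data = data

register_tensor_conversion_function(Wrapped, lambda v, **opts: v.data)
assert convert(Wrapped([1, 2, 3])) == [1, 2, 3]
```

Every torch function taking a tensor argument would then route unknown input types through such a registry before raising a TypeError.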

Shape inference might be nice as well (if tensor creation is a heavy operation), to check whether the resulting tensor is compatible with the other arguments. This implies:

torch.register_tensor_shape_conversion_function(
    base_type,
    shape_func,
    priority=100
)

def shape_func(value):
    ...

Shape conversion does not seem to be compulsory, since the shape can be inferred via register_tensor_conversion_function, but it is a good way to debug faster.
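For a custom object, such a shape function can be cheap to call before any materialization happens. Sketched below for a hypothetical low-rank factor object (all names illustrative):

```python
class LowRankMatrix:
    # Minimal stand-in: X = U @ diag(d) @ V^T, stored as factors only.
    def __init__(self, u, d, v):
        self.u, self.d, self.v = u, d, v

def low_rank_shape(value):
    # The shape is known from the factor shapes alone -- no dense tensor
    # is built, so argument-compatibility checks can fail fast before
    # the (potentially heavy) conversion runs.
    return (len(value.u), len(value.v))

x = LowRankMatrix(u=[[1.0], [0.0], [0.0]], d=[2.0], v=[[1.0], [0.0]])
assert low_rank_shape(x) == (3, 2)
```

This is exactly the debugging benefit mentioned above: shape mismatches surface immediately instead of after an expensive conversion.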

CC @soumith, @fritzo

Labels

feature: A request for a proper, new feature.
triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module.
