🚀 Feature
It is a common issue when extending pytorch for general usage that its functions allow only tensor inputs. However, some objects that represent a tensor are better stored not as a tensor, but as a custom Python object.
Motivation
Riemannian Optimization
Here is a brief motivating example. Consider an X = U d Vᵀ matrix decomposition (SVD). Its minimal description is two matrices on the Stiefel manifold and a vector with positive entries. One would like to work with this representation directly, as it forms a nice way to represent low-rank matrices. It is currently impossible to pass custom objects to pytorch functions:
```python
>>> torch.svd(myobject)
Traceback (most recent call last):
  File "<input>", line 1, in <module>
TypeError: svd(): argument 'input' (position 1) must be Tensor, not object
```
The use case in the above example is interesting for the geoopt project, where I plan to support such manifolds in the future. I have no idea how to do it without effort from the pytorch-dev side.
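To make the geoopt use case concrete, here is a minimal sketch (using numpy; `LowRankMatrix` is a hypothetical class, not geoopt API) of the kind of object one would like to hand to `torch.svd` directly instead of a dense tensor:

```python
import numpy as np

class LowRankMatrix:
    """Hypothetical low-rank representation X = U @ diag(d) @ V.T,
    with U and V on the Stiefel manifold (orthonormal columns)
    and d a vector with positive entries."""
    def __init__(self, U, d, V):
        self.U, self.d, self.V = U, d, V

    def to_dense(self):
        # Reconstruct the full matrix from its minimal description.
        # A conversion function registered with pytorch could call this.
        return self.U @ np.diag(self.d) @ self.V.T

# Orthonormal columns via QR give points on the Stiefel manifold.
U, _ = np.linalg.qr(np.random.randn(5, 2))
V, _ = np.linalg.qr(np.random.randn(4, 2))
d = np.array([3.0, 1.5])  # positive entries

X = LowRankMatrix(U, d, V)
print(X.to_dense().shape)  # (5, 4), but only rank 2 is actually stored
```

With a conversion-function registry, `torch.svd(X)` could call `to_dense()` (or a smarter structured path) behind the scenes instead of raising a `TypeError`.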
Probabilistic Programming
Another interesting example comes from probabilistic programming languages (like pymc3, edward, pyro). Distribution objects describe a belief over all possible values, and a good user API requires some abstraction there. In pymc3 (with the help of theano) we mixed tensors and distributions together, producing things like:
```python
import pymc3 as pm

X, y = linear_training_data()
with pm.Model() as linear_model:
    weights = pm.Normal('weights', mu=0, sd=1)
    noise = pm.Gamma('noise', alpha=2, beta=1)
    y_observed = pm.Normal('y_observed',
                           mu=X.dot(weights),
                           sd=noise,
                           observed=y)

    prior = pm.sample_prior_predictive()
    posterior = pm.sample()
    posterior_pred = pm.sample_posterior_predictive(posterior)
```
As you can see, calling a distribution object yields a tensor-compatible object that can be used downstream. To be honest, we used a lot of hacks with `__new__` there, due to the absence of an API like the one TensorFlow provides:
https://www.tensorflow.org/api_docs/python/tf/register_tensor_conversion_function
We also used this particular API in the pymc4 design, as follows. Depending on the context (inference or MCMC sampling, for example), it changes behaviour, producing either a conditional or an unconditional distribution.
https://github.com/pymc-devs/pymc4/blob/master/pymc4/random_variables/random_variable.py#L119
Pyro, as far as I know, still relies on explicit `.sample(...)` calls:
```python
def scale(guess):
    weight = pyro.sample("weight", dist.Normal(guess, 1.0))
    return pyro.sample("measurement", dist.Normal(weight, 0.75))
```
And I expect pyro developers would be happy to join this discussion.
So you can see how much benefit tensor conversion functions would bring: they allow third-party developers to build very neat pytorch-compatible APIs, improving the user experience.
Pitch
There should be a convention and a developer API for registering tensor conversion functions, like TensorFlow does (they have a nice API for this at this point).
This is the TensorFlow convention:
```python
tf.register_tensor_conversion_function(
    base_type,
    conversion_func,
    priority=100
)

def conversion_func(value, dtype=None, name=None, as_ref=False):
    ...
```
where
- base_type: The base type or tuple of base types for all objects that conversion_func accepts.
- conversion_func: A function that converts instances of base_type to Tensor.
- priority: Optional integer that indicates the priority for applying this conversion function. Conversion functions with smaller priority values run earlier than conversion functions with larger priority values. Defaults to 100.
(full page)
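For concreteness, the dispatch mechanics behind such a registry can be sketched in plain Python. This is a hypothetical stand-in, not TensorFlow's or pytorch's actual implementation; the names mirror the TF signature above, and plain lists stand in for tensors:

```python
# Registry of (priority, base_type, conversion_func) entries.
_registry = []

def register_tensor_conversion_function(base_type, conversion_func, priority=100):
    _registry.append((priority, base_type, conversion_func))
    # Smaller priority values run earlier, matching the TF convention.
    _registry.sort(key=lambda entry: entry[0])

def convert_to_tensor(value, dtype=None):
    # Walk registered converters in priority order; first applicable one wins.
    for _, base_type, func in _registry:
        if isinstance(value, base_type):
            result = func(value, dtype=dtype)
            if result is not NotImplemented:
                return result
    raise TypeError(f"no conversion function registered for {type(value).__name__}")

# Example: an object that knows how to materialize itself as a "tensor"
# (a nested list here, standing in for torch.Tensor).
class Diagonal:
    def __init__(self, diag):
        self.diag = diag

def diagonal_to_tensor(value, dtype=None):
    n = len(value.diag)
    return [[value.diag[i] if i == j else 0 for j in range(n)]
            for i in range(n)]

register_tensor_conversion_function(Diagonal, diagonal_to_tensor)
print(convert_to_tensor(Diagonal([1, 2])))  # [[1, 0], [0, 2]]
```

The priority parameter lets a library override a default converter for a subclass without monkey-patching, which is exactly the extension point third-party projects like geoopt would need.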
Pytorch does not use names, so it may omit some of these arguments. The `as_ref` argument controls whether a view or a copy is returned: an object may own a tensor and may optionally provide itself as the target of in-place operations (if passed as `out=myobject`, for example).
Pytorch also has so-called tensor options that include `dtype`, `device`, `stride`, etc. Thus the conventions from the Tensor creation guide may make for a better API.
So, finally, this may look like:
```python
torch.register_tensor_conversion_function(
    base_type,
    conversion_func,
    priority=100
)

def conversion_func(
    value,
    # each option should either be inherited from the tensor options
    # of `value`, or set explicitly by the conversion function
    dtype=None,
    device=None,
    stride=None,
    requires_grad=None,
):
    ...
```
Shape inference might be nice as well (since tensor creation is a heavy operation), to check whether the resulting tensor is compatible with the other arguments before materializing it. This implies:
```python
torch.register_tensor_shape_conversion_function(
    base_type,
    shape_conversion_func,
    priority=100
)

def shape_conversion_func(value):
    ...
```
Shape conversion does not seem to be compulsory, since the shape can always be inferred via `register_tensor_conversion_function`, but it would be a good way to debug faster.
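The shape-inference idea can also be sketched in plain Python (again with hypothetical names, not an actual pytorch API): a registered shape function answers "what shape would the converted tensor have?" without paying the cost of building it.

```python
# Registry mapping base types to shape functions.
_shape_registry = {}

def register_tensor_shape_conversion_function(base_type, shape_func, priority=100):
    # priority kept for symmetry with the conversion API; unused in this sketch.
    _shape_registry[base_type] = shape_func

def infer_shape(value):
    for base_type, func in _shape_registry.items():
        if isinstance(value, base_type):
            return func(value)
    raise TypeError(f"no shape function registered for {type(value).__name__}")

class Outer:
    """Rank-1 matrix a @ b.T, stored as the two vectors only."""
    def __init__(self, a, b):
        self.a, self.b = a, b

def outer_shape(value):
    # The shape is known without computing len(a) * len(b) entries.
    return (len(value.a), len(value.b))

register_tensor_shape_conversion_function(Outer, outer_shape)
m = Outer([1.0, 2.0, 3.0], [4.0, 5.0])
print(infer_shape(m))  # (3, 2)
```

Shape checks against the other arguments of an op could then run before conversion, so a mismatch fails fast with a clear error rather than after an expensive materialization.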
