Wrap torch.distributions.Gamma, Beta #630
Conversation
@martinjankowiak
@fritzo - In light of what we discussed yesterday, are we looking to get this PR merged? I was thinking that once we start testing against PyTorch master (or some pinned version of master), we can just remove our implementation code and instead wrap around torch.distributions.
@neerajprad @martinjankowiak could you take another look? I've refactored to support side-by-side use of Pyro and PyTorch distributions:

from pyro.distributions import Normal as OldNormal
from pyro.distributions.torch import Normal as NewNormal

It would be nice to merge this soon to facilitate experimentation with the PyTorch distributions.
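For example, a quick side-by-side check could look roughly like this (a sketch only; the sample/log_pdf call signatures and the parameter values are assumptions based on the existing Distribution interface):

import torch
from pyro.distributions import Normal as OldNormal
from pyro.distributions.torch import Normal as NewNormal

mu, sigma = torch.zeros(1), torch.ones(1)
old = OldNormal(mu, sigma)   # existing pure-Pyro implementation
new = NewNormal(mu, sigma)   # thin wrapper around torch.distributions.Normal

x = old.sample()
# The two implementations should agree (up to numerics) on the log density.
print(old.log_pdf(x), new.log_pdf(x))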
USE_TORCH_DISTRIBUTIONS = int(os.environ.get('PYRO_USE_TORCH_DISTRIBUTIONS', 0))

# distribution classes with working torch versions
if USE_TORCH_DISTRIBUTIONS:
    from pyro.distributions.torch.normal import Normal

# function aliases
beta = RandomPrimitive(Beta)
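The flag makes the torch-backed classes opt-in at import time; the pattern is roughly the following (a sketch, not the actual __init__.py contents; the non-torch import path is an assumption):

import os

# Opt in to the torch-backed classes via an environment variable;
# default to the existing pure-Pyro implementations.
USE_TORCH_DISTRIBUTIONS = int(os.environ.get('PYRO_USE_TORCH_DISTRIBUTIONS', 0))

if USE_TORCH_DISTRIBUTIONS:
    from pyro.distributions.torch.normal import Normal  # noqa: F401
else:
    from pyro.distributions.normal import Normal  # noqa: F401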
Do we need to expose these functional forms here too (since they should get initialized in distributions/__init__.py)?
It's convenient to expose those here for side-by-side testing (as Martin is doing today):

from pyro.distributions import normal as old_normal
from pyro.distributions.torch import normal as new_normal

or

import pyro.distributions as old_dist
import pyro.distributions.torch as new_dist
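Exercising the two sets of aliases against each other might then look roughly like this (a sketch; the RandomPrimitive call and log_pdf signatures are assumptions):

import torch
import pyro.distributions as old_dist
import pyro.distributions.torch as new_dist

mu, sigma = torch.zeros(1), torch.ones(1)

# Calling a RandomPrimitive alias draws a sample ...
x = old_dist.normal(mu, sigma)

# ... and log_pdf scores a value under the given parameters.
print(old_dist.normal.log_pdf(x, mu, sigma),
      new_dist.normal.log_pdf(x, mu, sigma))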
I see, thanks for explaining.
martinjankowiak
left a comment
lgtm (but didn't look at anything apart from high-level stuff)
try:
    source_attr = getattr(source_class, name)
    destin_attr = getattr(destin_class, name)
    destin_attr.__doc__ = source_attr.__doc__
Should we check if destin_attr.__doc__ already exists and only change if it doesn't, so that the subclasses don't have the docstrings replaced by the base class?
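One way to add that guard, sketched against the snippet above (the surrounding decorator and loop structure are assumptions):

def copy_docs_from(source_class):
    """Class decorator that fills in missing docstrings from source_class."""
    def decorator(destin_class):
        for name in dir(destin_class):
            try:
                source_attr = getattr(source_class, name)
                destin_attr = getattr(destin_class, name)
                # Only copy when the subclass has no docstring of its own,
                # so overridden methods keep their more specific docs.
                if destin_attr.__doc__ is None:
                    destin_attr.__doc__ = source_attr.__doc__
            except AttributeError:
                continue
        return destin_class
    return decorator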
neerajprad
left a comment
This is really nice, simplifies a lot of the wrapping logic!
def torch_wrapper(pyro_dist):
    ...

def copy_docs_from(source_class):
    ...
heh this is cute. are we eventually going to migrate them completely so there's only one distribution? will it be pyro.distributions or pyro.distributions.torch?
We're planning to eventually migrate docs. Thereafter we can still use this decorator to copy docs from Distributions to instance classes.
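Usage after such a migration could stay the same, e.g. (a sketch; the import paths for copy_docs_from and Distribution are assumptions):

from pyro.distributions.distribution import Distribution
from pyro.distributions.util import copy_docs_from

@copy_docs_from(Distribution)
class Normal(Distribution):
    def sample(self):
        # Picks up Distribution.sample's docstring via the decorator.
        ...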
reparameterized = True

def __init__(self, mu, sigma, *args, **kwargs):
    torch_dist = torch.distributions.Normal(mu, sigma)
so from what i can see, this returns the pt distribution which only provides sample and score methods. but batch and event shapes will eventually be pushed upstream as well, correct?
We have batch and event shapes in torch.distributions, but they are internal attributes and not exposed as an API. Even in Pyro, we only need to use them outside of the distribution instances for testing. We still need to add analytical mean/variance to the torch.distributions classes; otherwise, they should already be at feature parity.
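For reference, the wrapping pattern under discussion stays quite thin, roughly as follows (a sketch; mu/sigma and reparameterized come from the diff above, while the method names, delegation details, and the Distribution import path are assumptions):

import torch
from pyro.distributions.distribution import Distribution


class Normal(Distribution):
    """Thin wrapper that delegates to torch.distributions.Normal."""
    reparameterized = True

    def __init__(self, mu, sigma, *args, **kwargs):
        self.torch_dist = torch.distributions.Normal(mu, sigma)
        super().__init__(*args, **kwargs)

    def sample(self):
        return self.torch_dist.sample()

    def log_pdf(self, x):
        # torch.distributions exposes scoring as log_prob.
        return self.torch_dist.log_prob(x)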
This refactors the torch.distributions wrapping layer and adds a new TorchDistribution class to implement common wrapper methods. The new class hierarchy is:

pyro.distributions.Distribution
    pyro.distributions.Normal
    pyro.distributions.torch_wrapper.TorchDistribution
        pyro.distributions.torch.Normal -> has a torch.distributions.Normal
    pyro.distributions.RandomPrimitive -> has a pyro.distributions.Distribution

PyTorch implementations are now available side-by-side with older Pyro implementations, e.g.,

from pyro.distributions import Normal as OldNormal
from pyro.distributions.torch import Normal as NewNormal
Tested
Ran make test-torch-dist against the test-rsample branch of pytorch (which implements Beta).