
Wrap torch.distributions.Gamma, Beta #630

Merged
neerajprad merged 18 commits into dev from wrap-torch-distributions on Dec 21, 2017
Conversation

@fritzo (Member) commented Dec 15, 2017

This refactors the torch.distributions wrapping layer and adds a new TorchDistribution class to implement common wrapper methods. The new class hierarchy is:

  • pyro.distributions.Distribution
    • pyro.distributions.Normal
    • ...
    • pyro.distributions.torch_wrapper.TorchDistribution
      • pyro.distributions.torch.Normal -> has a torch.distributions.Normal
      • ...
    • pyro.distributions.RandomPrimitive -> has a pyro.distributions.Distribution
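The hierarchy above centers on a has-a relationship: each wrapper holds a torch.distributions object and forwards the common Distribution methods to it. A minimal, hypothetical sketch of that pattern (class and method names are illustrative; a stub stands in for torch.distributions.Normal so the sketch runs without torch):

```python
class FakeTorchNormal:
    """Stand-in for torch.distributions.Normal so this sketch needs no torch."""
    def __init__(self, mu, sigma):
        self.mu, self.sigma = mu, sigma

    def sample(self):
        return self.mu  # deterministic stub, not a real draw

    def log_prob(self, x):
        # Unnormalized Gaussian log-density, enough to show delegation.
        return -((x - self.mu) / self.sigma) ** 2 / 2


class TorchDistribution:
    """Base wrapper: holds a torch-style distribution and delegates to it."""
    def __init__(self, torch_dist):
        self.torch_dist = torch_dist

    def sample(self):
        return self.torch_dist.sample()

    def log_pdf(self, x):
        return self.torch_dist.log_prob(x)


class Normal(TorchDistribution):
    """Wrapper class analogous to pyro.distributions.torch.Normal."""
    def __init__(self, mu, sigma):
        super().__init__(FakeTorchNormal(mu, sigma))


d = Normal(0.0, 1.0)
print(d.sample())      # 0.0
print(d.log_pdf(2.0))  # -2.0
```

The point of the common base class is that sample/log_pdf delegation is written once and shared by every wrapped distribution.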

PyTorch implementations are now available side-by-side with the older Pyro implementations, e.g.,

from pyro.distributions import normal as old_normal
from pyro.distributions.torch import normal as new_normal

Tested

Ran make test-torch-dist against the test-rsample branch of pytorch (which implements Beta).

@fritzo (Member, Author) commented Dec 15, 2017

@martinjankowiak torch.digamma is not yet in PyTorch master (pytorch/pytorch#3955), so the Gamma in this PR weirdly has reparameterized gradients but a non-differentiable score function 🙃

@neerajprad (Member) commented:
@fritzo - In light of what we discussed yesterday, are we looking to get this PR merged? I was thinking that once we start testing against PyTorch master (or some pinned version of master), we can just remove our implementation code and instead wrap around torch.distributions classes directly, i.e. TorchNormal --> Normal.

@fritzo (Member, Author) commented Dec 20, 2017

are we looking to get this PR merged?

Yes, I just haven't had a chance to update this PR in light of yesterday's #638 and #640 .

@fritzo (Member, Author) commented Dec 20, 2017

@neerajprad @martinjankowiak could you take another look? I've refactored to support side-by-side use of Pyro and PyTorch distributions:

from pyro.distributions import Normal as OldNormal
from pyro.distributions.torch import Normal as NewNormal

The PYRO_USE_TORCH_DISTRIBUTIONS environment variable is needed only in testing; at some point we can flip the switch to use those distributions by default.

It would be nice to merge this soon to facilitate experimentation with the PyTorch distributions.

USE_TORCH_DISTRIBUTIONS = int(os.environ.get('PYRO_USE_TORCH_DISTRIBUTIONS', 0))

# distribution classes with working torch versions
if USE_TORCH_DISTRIBUTIONS:
Member:
Nice!

from pyro.distributions.torch.normal import Normal

# function aliases
beta = RandomPrimitive(Beta)
Member:
Do we need to expose these functional forms here too (since they should get initialized in distributions/__init__.py)?

Member (Author):
It's convenient to expose those here for side-by-side testing (as Martin is doing today):

from pyro.distributions import normal as old_normal
from pyro.distributions.torch import normal as new_normal

Member (Author):
or

import pyro.distributions as old_dist
import pyro.distributions.torch as new_dist

Member:
I see, thanks for explaining.
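For context, a functional alias like `beta = RandomPrimitive(Beta)` can be understood as a thin callable that holds a distribution class and builds an instance per call. The following is a hypothetical minimal sketch of that idea, not Pyro's actual RandomPrimitive; the toy Delta class exists only to make the sketch self-contained:

```python
class RandomPrimitive:
    """Sketch: turn a distribution class into a function-style alias.
    Calling the alias constructs a fresh instance and draws a sample."""
    def __init__(self, dist_class):
        self.dist_class = dist_class

    def __call__(self, *args, **kwargs):
        return self.dist_class(*args, **kwargs).sample()

    def log_pdf(self, x, *args, **kwargs):
        # Helpers also build a temporary instance from the call's parameters.
        return self.dist_class(*args, **kwargs).log_pdf(x)


class Delta:
    """Toy distribution concentrated at a single point, for illustration."""
    def __init__(self, v):
        self.v = v

    def sample(self):
        return self.v

    def log_pdf(self, x):
        return 0.0 if x == self.v else float("-inf")


delta = RandomPrimitive(Delta)  # analogous to `beta = RandomPrimitive(Beta)`
print(delta(3.0))               # 3.0
print(delta.log_pdf(3.0, 3.0))  # 0.0
```

This is why exposing the aliases in both modules makes the side-by-side comparison above a one-line import swap.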

@martinjankowiak (Collaborator) left a comment:
lgtm (but didn't look at anything apart from high-level stuff)

Comment thread on pyro/distributions/util.py (outdated):
try:
source_attr = getattr(source_class, name)
destin_attr = getattr(destin_class, name)
destin_attr.__doc__ = source_attr.__doc__
Member:
Should we check if destin_attr.__doc__ already exists and only change if it doesn't, so that the subclasses don't have the docstrings replaced by the base class?

Member (Author):
Good idea, I'll do that.

Member (Author):
Done.

neerajprad previously approved these changes Dec 20, 2017

@neerajprad (Member) left a comment:
This is really nice, simplifies a lot of the wrapping logic!

@fritzo fritzo dismissed stale reviews from neerajprad and martinjankowiak via 908cd8a December 20, 2017 23:48


def torch_wrapper(pyro_dist):
def copy_docs_from(source_class):
Member:
heh this is cute. are we eventually going to migrate them completely so there's only one distribution? will it be pyro.distributions or pyro.distributions.torch?

Member (Author):
We're planning to eventually migrate docs. Thereafter we can still use this decorator to copy docs from Distributions to instance classes.
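A minimal, hypothetical sketch of such a docstring-copying decorator, including the guard discussed earlier so that subclass docstrings are not clobbered by the base class (this is an illustration of the idea, not Pyro's exact implementation):

```python
def copy_docs_from(source_class):
    """Class decorator: copy docstrings from same-named methods of
    source_class onto the decorated class, skipping any method that
    already has its own docstring."""
    def decorator(destin_class):
        for name in vars(source_class):
            if name.startswith("__"):
                continue  # leave dunder attributes alone
            source_attr = getattr(source_class, name)
            destin_attr = getattr(destin_class, name, None)
            # Only fill in missing docs; never overwrite an existing one.
            if destin_attr is not None and destin_attr.__doc__ is None:
                destin_attr.__doc__ = source_attr.__doc__
        return destin_class
    return decorator


class Base:
    def sample(self):
        """Draw a sample."""


@copy_docs_from(Base)
class Child(Base):
    def sample(self):  # overrides Base.sample without a docstring
        return 1


print(Child.sample.__doc__)  # Draw a sample.
```

Function `__doc__` attributes are writable in Python, which is what makes this copy-after-definition approach work.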

reparameterized = True

def __init__(self, mu, sigma, *args, **kwargs):
torch_dist = torch.distributions.Normal(mu, sigma)
Member:
so from what I can see, this returns the PyTorch distribution, which only provides sample and score methods. but batch and event shapes will eventually be pushed upstream as well, correct?

Member:
We have batch and event shapes in torch.distributions, but they are internal attributes and not exposed as an API. Even in Pyro, we only need them outside of the distribution instances for testing. We still need to add analytical mean/variance to the torch.distributions classes; otherwise, they should already be at feature parity.
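For readers unfamiliar with the distinction: batch_shape indexes independent parameterizations, event_shape indexes the dimensions of a single draw, and a sample has shape batch_shape + event_shape. A toy bookkeeping sketch of just that rule (illustrative only, not the torch.distributions internals):

```python
class ShapedDist:
    """Toy shape bookkeeping: a draw has shape batch_shape + event_shape."""
    def __init__(self, batch_shape, event_shape):
        self.batch_shape = tuple(batch_shape)
        self.event_shape = tuple(event_shape)

    def sample_shape(self):
        # Batch dims (independent distributions) come first,
        # then event dims (one draw's own dimensions).
        return self.batch_shape + self.event_shape


# e.g. a batch of 3 independent 2-dimensional distributions:
d = ShapedDist(batch_shape=(3,), event_shape=(2,))
print(d.sample_shape())  # (3, 2)
```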

@neerajprad neerajprad merged commit 9ba8838 into dev Dec 21, 2017
@fritzo fritzo deleted the wrap-torch-distributions branch December 21, 2017 01:55