
Wrap torch.distributions.Normal for use in Pyro #607

Merged: neerajprad merged 5 commits into dev from wrap-torch-distributions on Nov 28, 2017

Conversation

@fritzo (Member) commented Nov 28, 2017

Addresses #606

This creates an optional wrapper to use torch.distributions.Normal in Pyro. The torch version is only used if all of the following are satisfied:

  • The environment variable PYRO_USE_TORCH_DISTRIBUTIONS=1 is set
  • The torch.distributions module exists (it is missing in the PyTorch 0.2 release)
  • The torch.distributions.Normal class exists
  • All requested features are available (e.g. torch.distributions.Normal must be reparameterized if reparameterized=True; note that log_pdf_mask is not supported)

If any of these conditions is not met, Pyro falls back to its standard implementation.
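The gating described above can be sketched roughly as follows. This is an illustrative outline, not the actual Pyro code; the function name use_torch_normal and the has_rsample attribute check are assumptions for the sketch:

```python
import os

def use_torch_normal(reparameterized=False, log_pdf_mask=None):
    """Illustrative sketch: decide whether the torch-backed Normal can be used."""
    # 1. Opt-in via the environment variable.
    if os.environ.get("PYRO_USE_TORCH_DISTRIBUTIONS") != "1":
        return False
    # 2. torch.distributions must exist (missing in the PyTorch 0.2 release).
    try:
        import torch.distributions as tdist
    except ImportError:
        return False
    # 3. The Normal class must exist.
    if not hasattr(tdist, "Normal"):
        return False
    # 4. All requested features must be supported (attribute name assumed).
    if reparameterized and not getattr(tdist.Normal, "has_rsample", False):
        return False
    if log_pdf_mask is not None:  # not supported by the torch wrapper
        return False
    return True
```

If any check fails, the caller would fall back to the standard Pyro implementation.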

Tested

The torch distribution will not be exercised on Travis. Tests pass locally, except for an unrelated bug in PyTorch master (0.4.0a0+709fcfd).


def batch_shape(self, x=None):
    x_shape = [] if x is None else x.size()
    shape = torch.Size(broadcast_shape(x_shape, self._param_shape, strict=True))
Member:

This might lead to some hard-to-find bugs like #414 if the event dimensions do not match between the data and the parameters. Should we allow this kind of broadcasting, or limit it to the batch dimensions only (i.e. any x's rightmost sizes must exactly agree with sample_shape + event_shape)?

Member:

I see that you have added a strict argument which should take care of this.

Member Author:

Right, that was actually needed to pass some of the expected-error tests.
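The strict broadcasting semantics discussed here can be sketched as follows. This is an illustrative re-implementation for the discussion, not the actual Pyro utility: with strict=False, sizes of 1 broadcast as usual; with strict=True, a size may only broadcast against a missing dimension, so mismatches raise ValueError:

```python
def broadcast_shape(*shapes, **kwargs):
    """Compute the broadcast of the given sizes, right-aligned.

    :rtype: tuple
    :raises: ValueError
    """
    strict = kwargs.pop('strict', False)
    reversed_shape = []
    for shape in shapes:
        for i, size in enumerate(reversed(shape)):
            if i >= len(reversed_shape):
                # New leftmost dimension: always allowed.
                reversed_shape.append(size)
            elif reversed_shape[i] == 1 and not strict:
                # A size of 1 broadcasts up, but only in non-strict mode.
                reversed_shape[i] = size
            elif size != reversed_shape[i] and (size != 1 or strict):
                raise ValueError('shape mismatch: {}'.format(shapes))
    return tuple(reversed(reversed_shape))
```

For example, (3, 1) and (2,) broadcast to (3, 2) by default, but raise ValueError under strict=True, which is what makes the expected-error tests pass.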

    :rtype: tuple
    :raises: ValueError
    """
    strict = kwargs.pop('strict', False)
Member:

It would be nice to extend the tests for this utility function by specifying strict as another parameter.

Member Author:

Good idea, will do.


@neerajprad (Member) left a comment:

Looks good!

@neerajprad merged commit b772145 into dev on Nov 28, 2017
@martinjankowiak deleted the wrap-torch-distributions branch on November 29, 2017
