Wrap torch.distributions.Normal for use in Pyro #607
```python
def batch_shape(self, x=None):
    x_shape = [] if x is None else x.size()
    shape = torch.Size(broadcast_shape(x_shape, self._param_shape, strict=True))
```
This might lead to some hard-to-find bugs like #414 if the event dimensions of the data and the parameters do not match. Should we allow this kind of broadcasting, or limit it to the batch dimensions only (i.e. any x's rightmost sizes must exactly agree with sample_shape + event_shape)?
I see that you have added a strict argument which should take care of this.
Right, that was actually needed to pass some of the expected-error tests.
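To make the distinction between the two broadcasting modes concrete, here is a minimal, hypothetical sketch of a `broadcast_shape` utility with a `strict` flag; the actual helper in this PR may be implemented differently:

```python
def broadcast_shape(*shapes, **kwargs):
    """Compute a broadcast shape from the given shapes.

    With strict=False this follows ordinary broadcasting rules (a size-1
    dimension broadcasts against any size). With strict=True, new
    dimensions may only be prepended on the left; overlapping rightmost
    sizes must agree exactly, so x's rightmost sizes must match
    sample_shape + event_shape. (Illustrative sketch, not Pyro's code.)
    """
    strict = kwargs.pop('strict', False)
    reversed_shape = []
    for shape in shapes:
        for i, size in enumerate(reversed(shape)):
            if i >= len(reversed_shape):
                reversed_shape.append(size)
            elif reversed_shape[i] == 1 and not strict:
                reversed_shape[i] = size
            elif reversed_shape[i] != size and (size != 1 or strict):
                raise ValueError('shape mismatch: objects cannot be broadcast '
                                 'to a single shape: {}'.format(list(shapes)))
    return tuple(reversed(reversed_shape))
```

For example, `broadcast_shape((2, 1), (2, 3))` broadcasts to `(2, 3)`, while the same call with `strict=True` raises a `ValueError` because the rightmost sizes `1` and `3` disagree; `broadcast_shape((2, 3), (3,), strict=True)` is still allowed since only a new left dimension is added.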
```python
    :rtype: tuple
    :raises: ValueError
    """
    strict = kwargs.pop('strict', False)
```
It would be nice to extend the tests for this utility function by parametrizing them over this argument as well.
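One way to parametrize the tests over `strict` is a table of cases with expected results for each setting. Everything below is a hypothetical sketch: the stand-in implementation exists only so the cases are runnable, and the case table and function names are illustrative, not Pyro's actual test suite.

```python
def broadcast_shape(*shapes, **kwargs):
    """Stand-in for the utility under test (minimal sketch)."""
    strict = kwargs.pop('strict', False)
    result = []
    for shape in shapes:
        for i, size in enumerate(reversed(shape)):
            if i >= len(result):
                result.append(size)
            elif result[i] != size and (strict or (result[i] != 1 and size != 1)):
                raise ValueError('shape mismatch: {}'.format(list(shapes)))
            else:
                result[i] = max(result[i], size)
    return tuple(reversed(result))

# Each case lists the input shapes and the expected result under
# strict=False and strict=True; ValueError marks an expected failure.
CASES = [
    (((2, 3), (3,)), (2, 3), (2, 3)),
    (((2, 1), (2, 3)), (2, 3), ValueError),
    (((4,), (3,)), ValueError, ValueError),
]

def check_broadcast_shape_cases():
    for shapes, loose, strict_expected in CASES:
        for strict, expected in ((False, loose), (True, strict_expected)):
            try:
                actual = broadcast_shape(*shapes, strict=strict)
            except ValueError:
                actual = ValueError
            assert actual == expected, (shapes, strict, actual, expected)
```

In a pytest-based suite the same table could be fed to `@pytest.mark.parametrize`, with `strict` as a second parametrized axis.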
Addresses #606
This creates an optional wrapper to use `torch.distributions.Normal` in Pyro. The torch version is only used if all of the following are satisfied:

- `PYRO_USE_TORCH_DISTRIBUTIONS=1` is set
- the `torch.distributions` module exists (it is missing in the PyTorch 0.2 release)
- the `torch.distributions.Normal` class exists
- `torch.distributions.Normal` is reparameterized (if `reparameterized=True`; also, `log_pdf_mask` is not supported)

If any of the previous conditions are not satisfied, Pyro falls back to the standard implementation.
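The fallback conditions above amount to a feature-detection check at wrapper-construction time. A hedged sketch of that logic follows; the function name and structure are illustrative, not the PR's actual code:

```python
import os

def use_torch_normal():
    """Return True only if every condition for using torch's Normal holds.

    Hypothetical sketch of the feature detection described above; the
    real wrapper also checks reparameterization support and the
    unsupported log_pdf_mask argument before opting in.
    """
    # the opt-in environment flag must be set
    if os.environ.get('PYRO_USE_TORCH_DISTRIBUTIONS') != '1':
        return False
    # the module is missing in the PyTorch 0.2 release
    try:
        import torch.distributions
    except ImportError:
        return False
    # the Normal class must exist in this torch build
    return hasattr(torch.distributions, 'Normal')
```

Because every check degrades to `False`, environments without a suitable `torch.distributions` silently fall back to Pyro's standard implementation rather than erroring out.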
Tested
The torch distribution will not be exercised on Travis. Tests pass locally, except for an unrelated bug in PyTorch master 0.4.0a0+709fcfd.