
Refactor TorchDistribution and wrap torch.distributions.Bernoulli #645

Merged

neerajprad merged 2 commits into dev from wrap-torch-bernoulli on Dec 21, 2017

Conversation

@fritzo (Member) commented Dec 21, 2017

This adds a working implementation of pyro.distributions.torch.Bernoulli. It also simplifies individual torch wrappers by moving the .batch_shape() and .event_shape() methods up to the parent class TorchDistribution.

Tested

  • Ran make test-torch-dist against the test-rsample branch of PyTorch.
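The shape-hoisting described above can be sketched in plain Python. Names and signatures here are illustrative, not the actual Pyro API; real shapes would be `torch.Size` objects rather than tuples:

```python
class TorchDistribution:
    """Hypothetical sketch of a shared wrapper base class: each subclass
    hands over the wrapped torch.distributions object plus shape info, and
    batch_shape()/event_shape() live here instead of being reimplemented
    in every individual wrapper."""

    def __init__(self, torch_dist, x_shape, event_dim):
        self.torch_dist = torch_dist
        self._x_shape = x_shape      # full shape of one sampled value
        self._event_dim = event_dim  # number of trailing dims forming one event

    def batch_shape(self):
        # All leading dims that are not part of a single event.
        cut = len(self._x_shape) - self._event_dim
        return self._x_shape[:cut]

    def event_shape(self):
        # The trailing event_dim dims.
        return self._x_shape[len(self._x_shape) - self._event_dim:]
```

A subclass such as a Bernoulli wrapper would then only need to build its `torch.distributions` object and pass `x_shape` and `event_dim` to `super().__init__`.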

@fritzo (Member, Author) commented Dec 21, 2017

@neerajprad After this and #647 are merged, I'll address Categorical.

ps = F.sigmoid(logits)
eps = get_clamping_buffer(ps)
ps = ps.clamp(min=eps, max=1-eps)
torch_dist = torch.distributions.Bernoulli(ps)
Member

Do any tests fail if we do not do this clamping here?

Member Author

Yes, your _flow tests fail. I'm glad you implemented them.

Member

Ahh, I see. Makes sense! Will add logit support in PyTorch soon.
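For context on the clamping question above, here is a plain-Python illustration (not Pyro code; EPS and the helper names are made up for this example) of why probabilities produced by a saturating sigmoid must be kept away from exactly 0 and 1 before constructing the Bernoulli:

```python
import math

EPS = 1e-7  # illustrative clamping buffer, analogous in spirit to get_clamping_buffer


def clamp(p, eps=EPS):
    # Keep p in the open interval (eps, 1 - eps) so log terms stay finite.
    return min(max(p, eps), 1 - eps)


def bernoulli_log_prob(p, x):
    # log p(x) = x*log(p) + (1-x)*log(1-p); blows up when p hits 0 or 1.
    return x * math.log(p) + (1 - x) * math.log(1 - p)
```

With clamping, `bernoulli_log_prob(clamp(0.0), 1)` returns a large negative but finite value; without it, `math.log(0.0)` raises a domain error (the tensor analogue is a `-inf` log-prob and `nan` gradients, which is what made the `_flow` tests fail).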

return self._param_shape[-1:]
x_shape = torch.Size(broadcast_shape(alpha.size(), beta.size(), strict=True))
event_dim = 1
super(Gamma, self).__init__(torch_dist, x_shape, event_dim, *args, **kwargs)
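The `broadcast_shape(..., strict=True)` call above combines the parameter shapes into one sample shape. A plausible pure-Python approximation of such a helper (the real one lives in Pyro's utilities and may differ in detail) is:

```python
def broadcast_shape(*shapes, strict=False):
    # Right-align the shapes and combine them dimension by dimension.
    # Non-strict mode follows NumPy-style broadcasting (a size of 1
    # stretches to match any size); strict mode requires sizes to agree
    # exactly once shorter shapes have been left-padded.
    result = []
    for shape in shapes:
        for i, size in enumerate(reversed(shape)):
            if i >= len(result):
                result.append(size)
            elif result[i] == 1 and not strict:
                result[i] = size
            elif result[i] != size and (size != 1 or strict):
                raise ValueError("shape mismatch: {}".format(shapes))
    return tuple(reversed(result))
```

Under this reading, `strict=True` still lets a shorter shape like `(3,)` align against `(2, 3)`, but refuses to stretch an explicit size-1 dimension, which is a reasonable sanity check for distribution parameters that should already agree.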
Member

It's good to see all of this getting absorbed into the wrapper class!

@neerajprad (Member) left a comment

Looks good! I am merging this.

@neerajprad neerajprad merged commit f5a51fe into dev Dec 21, 2017
@fritzo (Member, Author) commented Dec 21, 2017

Thanks for reviewing!

@fritzo fritzo deleted the wrap-torch-bernoulli branch January 5, 2018 16:42

2 participants