Adopt strict batch shape semantics for distributions #806
martinjankowiak merged 51 commits into dev from
Conversation
@neerajprad I've merged dev with your #824 into this PR. Lots of tests still fail. Feel free to push commits to this branch. I'll ping you if I start working on it again (Wednesday at the earliest).
'- .permute() data dimensions']))

# Check parallel dimensions on the left of max_iarange_nesting.
# TODO
@fritzo: Do you want to add some more checks in this PR?
No, they should be added in later PRs.
sigma = torch.pow(self.tau, -0.5)
pyro.observe("obs0", dist.LogNormal(mu_latent, sigma), obs=self.data[0])
pyro.observe("obs1", dist.LogNormal(mu_latent, sigma), obs=self.data[1])
pyro.observe("obs0", dist.LogNormal(mu_latent, sigma), obs=self.data[0].squeeze())
This test fails to meet the threshold if we use a vectorized "obs" inside an iarange.
@martinjankowiak, @fritzo - Is that expected? Is there a reason why the two obs are observed in separate statements?
Just checked with @martinjankowiak that this works as expected inside iarange, so I am not sure why I was seeing a difference. Will make the change and update.
@martinjankowiak can we just delete this flaky expensive test?
yes we can. although maybe keep the transformed_distribution bit?
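For reference, a minimal sketch of what "vectorized obs inside an iarange" means here, using made-up data and parameters and the pyro.sample(..., obs=...) form in place of the older pyro.observe:

import torch
import pyro
import pyro.distributions as dist

data = torch.tensor([2.0, 3.0])   # hypothetical stand-in for self.data
mu_latent = torch.tensor(0.5)
sigma = torch.tensor(1.0)

def model_separate():
    # two scalar observe statements, as in the original test
    pyro.sample("obs0", dist.LogNormal(mu_latent, sigma), obs=data[0])
    pyro.sample("obs1", dist.LogNormal(mu_latent, sigma), obs=data[1])

def model_vectorized():
    # one vectorized sample site; the iarange declares the batch dimension
    with pyro.iarange("data", len(data)):
        pyro.sample("obs", dist.LogNormal(mu_latent, sigma), obs=data)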
raise NotImplementedError('alpha < 1 is not supported')
self.alpha = alpha
self._standard_gamma = Gamma(alpha, alpha.new([1]).expand_as(alpha))
self._standard_gamma = Gamma(alpha, torch.empty_like(alpha).fill_(1).expand_as(alpha))
@fritzo - Had to make a small change to support scalars. Is there a more concise way to do this?
How about alpha.new([1]).squeeze().expand_as(alpha)?
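For context, a small sketch with a hypothetical 0-dim alpha showing why the extra .squeeze() matters: a shape-(1,) tensor cannot be expanded down to a scalar shape.

import torch

alpha = torch.tensor(2.5)   # 0-dim "scalar" tensor
# alpha.new([1]) has shape (1,) and cannot be expanded to alpha's 0-dim shape,
# so it is squeezed to 0-dim first:
ones_a = alpha.new([1]).squeeze().expand_as(alpha)
# the version in the diff instead starts from a tensor that already has alpha's shape:
ones_b = torch.empty_like(alpha).fill_(1).expand_as(alpha)
assert ones_a.shape == alpha.shape == ones_b.shape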
Integration test times seem to have increased substantially with either the checks or the iarange/reshaping operations.
@neeraj, is this ready to merge?
fritzo left a comment
LGTM (but we'll need another reviewer since I authored much of this PR)
@pytest.mark.init(rng_seed=161)
@pytest.mark.parametrize("batch_size", [3, 5, 7, 8, None])
@pytest.mark.parametrize("map_type", ["tensor", "list"])
@pytest.mark.parametrize("map_type", ["iarange", "irange", "range"])
@neerajprad We're actually running more tests in integration_batch_1, so that might be part of the slowdown.
@martinjankowiak - could you take a look at this PR? It would be nice to get this merged soon, before we accumulate merge conflicts.
yeah i'll take a look now
martinjankowiak left a comment
lgtm!
my questions are more for my own edification than anything else
z_dist = TransformedDistribution(dist.Normal(z_mu, z_sigma), self.iafs)
else:
z_dist = dist.Normal(z_mu, z_sigma)
assert z_dist.event_shape == ()
when we clean up the tutorials we should probably put asserts like this everywhere as a way of helping users understand the code
prior_mu = Variable(torch.zeros([batch_size, self.z_dim]))
prior_sigma = Variable(torch.ones([batch_size, self.z_dim]))
zs = pyro.sample("z", dist.Normal(prior_mu, prior_sigma))
zs = pyro.sample("z", dist.Normal(prior_mu, prior_sigma).reshape(extra_event_dims=2))
why extra_event_dims=2? this isn't really necessary, is it?
I think that's an oversight on my part. We don't need any reshaping here. Will remove.
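For readers following along, a small sketch (hypothetical shapes, and the Pyro 0.2-era .reshape API, later renamed .to_event) of what extra_event_dims changes:

import torch
import pyro.distributions as dist

batch_size, z_dim = 4, 3
mu = torch.zeros(batch_size, z_dim)
sigma = torch.ones(batch_size, z_dim)

d0 = dist.Normal(mu, sigma)
# d0.batch_shape == (4, 3), d0.event_shape == ()
# d0.log_prob(x) has shape (4, 3): one value per scalar

d1 = d0.reshape(extra_event_dims=1)
# d1.batch_shape == (4,), d1.event_shape == (3,)
# d1.log_prob(x) has shape (4,): one value per datapoint

d2 = d0.reshape(extra_event_dims=2)
# d2.batch_shape == (), d2.event_shape == (4, 3)
# d2.log_prob(x) is a scalar: the whole minibatch becomes one event, which is
# why the extra reshape is unnecessary for this prior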
variance = self.get_param("variance").expand_as(f)

return pyro.sample("y", dist.Normal(f, variance), obs=obs)
event_dims = f.dim()
won't event_dims always be 1 here?
I made the change because of the failing gp tutorial, but I'm not too familiar with this code. If f.dim() is always 1, we can hard-code that.
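To illustrate the question with a made-up input size (and assuming event_dims feeds a .reshape(extra_event_dims=...) call as elsewhere in this PR): f is the vector of GP function values, so f.dim() is 1 and the two spellings coincide.

import torch
import pyro.distributions as dist

f = torch.randn(20)                        # GP function values at 20 inputs
variance = torch.tensor(0.1).expand_as(f)
# f.dim() == 1 here, so these two are equivalent:
y_dist = dist.Normal(f, variance).reshape(extra_event_dims=f.dim())
y_dist = dist.Normal(f, variance).reshape(extra_event_dims=1)
# y_dist.batch_shape == (), y_dist.event_shape == (20,)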
prior_mu = Variable(torch.zeros([batch_size, self.z_dim]))
prior_sigma = Variable(torch.ones([batch_size, self.z_dim]))
zs = pyro.sample("z", dist.Normal(prior_mu, prior_sigma).reshape(extra_event_dims=2))
zs = pyro.sample("z", dist.Normal(prior_mu, prior_sigma))
Why don't you need to .reshape(extra_event_dims=1) here?
From my understanding, model_sample is not used for inference but for generating samples separately using the trained decoder, for plotting, etc.
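A brief sketch of why the reshape is immaterial for model_sample: extra_event_dims only changes how log_prob aggregates, not the shape of drawn samples (shapes here are hypothetical).

import torch
import pyro.distributions as dist

mu = torch.zeros(4, 3)
sigma = torch.ones(4, 3)

d = dist.Normal(mu, sigma)
d_event = d.reshape(extra_event_dims=2)
assert d.sample().shape == d_event.sample().shape == (4, 3)   # sampling is unaffected
# only scoring differs: d.log_prob(x) has shape (4, 3); d_event.log_prob(x) is a scalar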
fritzo left a comment
LGTM (again, someone else must merge)
See Design Doc | Blocking #780

This PR adds checks for stricter use of batch dimensions in Pyro. This is only recently possible since it relies on PyTorch support for scalars.

Why?

Pyro 0.2 will be able to automatically introduce batch dimensions for parallelizing enum_discrete and num_particles. To give Pyro space to create new batch dimensions, we need to restrict how dimensions are introduced by user code.

How?

The new requirements are (see the sketch below):

- dependent (event) dimensions are declared via .reshape(extra_event_dims=n)
- independent batch dimensions are declared via iarange (and maybe irange?)
- the maximum iarange nesting depth is declared via SVI(..., max_iarange_nesting=n)

These requirements are checked by the new helper check_site_shape, which uses the new .size field in cond_indep_stack frames.

Tasks

- Trace_ELBO
- TraceGraph_ELBO
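To make the conventions concrete, a minimal sketch under the PR-era API described above (iarange later became pyro.plate, .reshape(extra_event_dims=n) became .to_event(n), and in released Pyro the max_iarange_nesting argument is taken by the ELBO constructor rather than SVI):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

data = torch.randn(10, 3)

def model():
    # dependent (rightmost) dims are declared as event dims
    mu = pyro.sample("mu",
                     dist.Normal(torch.zeros(3), torch.ones(3)).reshape(extra_event_dims=1))
    # independent batch dims are declared with iarange
    with pyro.iarange("data", data.size(0)):
        pyro.sample("obs",
                    dist.Normal(mu, torch.ones(3)).reshape(extra_event_dims=1),
                    obs=data)

def guide():
    loc = pyro.param("loc", torch.zeros(3))
    pyro.sample("mu", dist.Normal(loc, torch.ones(3)).reshape(extra_event_dims=1))

# declaring the maximum iarange nesting depth up front tells Pyro which dims
# it may claim for enum_discrete / num_particles
svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=Trace_ELBO(),
          max_iarange_nesting=1)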