Adopt strict batch shape semantics for distributions #806

Merged
martinjankowiak merged 51 commits into dev from strict-shape
Mar 1, 2018

Conversation

@fritzo
Member

@fritzo fritzo commented Feb 23, 2018

See Design Doc | Blocking #780

This PR adds checks for stricter use of batch dimensions in Pyro. This has only recently become possible, since it relies on PyTorch's support for scalar tensors.

Why?

Pyro 0.2 will be able to automatically introduce batch dimensions for parallelizing enum_discrete and num_particles. To give Pyro space to create new batch dimensions, we need to restrict how dimensions are introduced by user code.

How?

The new requirements are:

  • users must declare non-batch dimensions using .reshape(extra_event_dims=n)
  • batch dimensions will only be allowed inside iarange (and maybe irange?)
  • to use auto-parallelization, users must request dims via SVI(..., max_iarange_nesting=n)

These requirements are checked by the new helper check_site_shape which uses the new .size field in cond_indep_stack frames.
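
The core of the new semantics can be illustrated without Pyro. The following sketch (not Pyro's actual implementation; `split_shape` is a hypothetical helper) shows what `.reshape(extra_event_dims=n)` declares: the rightmost `n` batch dimensions are reinterpreted as event dimensions, so `log_prob` sums over them and `check_site_shape` no longer treats them as batch dims.

```python
def split_shape(shape, extra_event_dims):
    """Move the rightmost `extra_event_dims` dims of a tensor shape
    out of the batch shape and into the event shape."""
    assert 0 <= extra_event_dims <= len(shape)
    cut = len(shape) - extra_event_dims
    return shape[:cut], shape[cut:]

# A Normal over a (20, 4) tensor has batch_shape (20, 4); declaring
# extra_event_dims=1 leaves a (20,) batch of 4-dimensional events.
print(split_shape((20, 4), 1))   # ((20,), (4,))
print(split_shape((20, 4), 2))   # ((), (20, 4))
```

Under the new rules, any dims remaining in the batch shape must be accounted for by an enclosing iarange.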

Tasks

  • check sites in Trace_ELBO
  • check sites in TraceGraph_ELBO
  • update tests/infer/test_valid_models.py
  • update other tests
  • update examples
  • update tutorials
  • document new batching semantics

@fritzo fritzo added the WIP label Feb 23, 2018
@fritzo
Member Author

fritzo commented Feb 27, 2018

@neerajprad I've merged dev with your #824 into this PR. Lots of tests still fail. Feel free to push commits to this branch. I'll ping you if I start working on it again (Wednesday at the earliest).

Comment thread pyro/util.py Outdated
'- .permute() data dimensions']))

# Check parallel dimensions on the left of max_iarange_nesting.
# TODO
Member

@fritzo : Do you want to add some more checks in this PR?

Member Author

No, they should be added in later PRs.

Comment thread tests/infer/test_inference.py Outdated
sigma = torch.pow(self.tau, -0.5)
pyro.observe("obs0", dist.LogNormal(mu_latent, sigma), obs=self.data[0])
pyro.observe("obs1", dist.LogNormal(mu_latent, sigma), obs=self.data[1])
pyro.observe("obs0", dist.LogNormal(mu_latent, sigma), obs=self.data[0].squeeze())
Member

@neerajprad neerajprad Feb 28, 2018

This test fails to meet the threshold if we use a vectorized "obs" inside an iarange.

@martinjankowiak, @fritzo - Is that expected? Is there a reason why the two obs are observed in separate statements?

Member

Just checked with @martinjankowiak that this works as expected inside iarange, so I am not sure why I was seeing a difference. Will make the change and update.

Member Author

@martinjankowiak can we just delete this flaky expensive test?

Collaborator

yes we can. although maybe keep the transformed_distribution bit?

raise NotImplementedError('alpha < 1 is not supported')
self.alpha = alpha
self._standard_gamma = Gamma(alpha, alpha.new([1]).expand_as(alpha))
self._standard_gamma = Gamma(alpha, torch.empty_like(alpha).fill_(1).expand_as(alpha))
Member

@fritzo - Had to make a small change to support scalars. Is there a more concise way to do this?

Member Author

How about alpha.new([1]).squeeze().expand_as(alpha)?
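
For reference, a small sketch of the alternatives being discussed here, assuming a recent PyTorch. All three build a tensor of ones with `alpha`'s dtype and shape; `torch.ones_like` is the concise spelling in current PyTorch (it may not have covered this case at the time of the PR).

```python
import torch

alpha = torch.tensor([0.5, 1.5, 2.5])

# Verbose form from the diff above: allocate, fill, broadcast.
ones_a = torch.empty_like(alpha).fill_(1).expand_as(alpha)

# fritzo's suggestion: build a 0-dim scalar, then broadcast it.
ones_b = alpha.new([1]).squeeze().expand_as(alpha)

# Concise equivalent in current PyTorch.
ones_c = torch.ones_like(alpha)

assert ones_a.shape == ones_b.shape == ones_c.shape == alpha.shape
```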

@neerajprad
Member

The times for integration tests seem to have substantially increased with either the checks or the iarange/reshaping operations.

@fritzo
Member Author

fritzo commented Mar 1, 2018

@neerajprad, is this ready to merge?

Member Author

@fritzo fritzo left a comment


LGTM (but we'll need another reviewer since I authored much of this PR)

@pytest.mark.init(rng_seed=161)
@pytest.mark.parametrize("batch_size", [3, 5, 7, 8, None])
@pytest.mark.parametrize("map_type", ["tensor", "list"])
@pytest.mark.parametrize("map_type", ["iarange", "irange", "range"])
Member Author

@neerajprad We're actually running more tests in integration_batch_1, so that might be part of the slowdown.

@neerajprad
Member

@martinjankowiak - could you take a look at this PR? It will be nice to get this merged soon before we accumulate merge conflicts.

@martinjankowiak
Collaborator

yeah i'll take a look now

martinjankowiak
martinjankowiak previously approved these changes Mar 1, 2018
Collaborator

@martinjankowiak martinjankowiak left a comment

lgtm!

my questions are more for my own edification than anything else

Comment thread examples/dmm/dmm.py
z_dist = TransformedDistribution(dist.Normal(z_mu, z_sigma), self.iafs)
else:
z_dist = dist.Normal(z_mu, z_sigma)
assert z_dist.event_shape == ()
Collaborator

when we clean up the tutorials we should probably put asserts like this everywhere as a way of helping users understand the code
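
The kind of assert being proposed can be shown with plain torch.distributions, whose batch/event shape semantics match the ones this PR adopts (`Independent` plays the role of Pyro's `.reshape(extra_event_dims=n)`; the `z_mu`/`z_sigma` names follow the diff above).

```python
import torch
from torch.distributions import Normal, Independent

z_mu = torch.zeros(20, 4)
z_sigma = torch.ones(20, 4)

# A plain Normal over a (20, 4) tensor is a (20, 4) batch of scalar events.
z_dist = Normal(z_mu, z_sigma)
assert z_dist.batch_shape == (20, 4)
assert z_dist.event_shape == ()

# Declaring the last dim as an event dim leaves a (20,) batch of
# 4-dimensional events; log_prob now sums over the last dim.
z_indep = Independent(z_dist, 1)
assert z_indep.batch_shape == (20,)
assert z_indep.event_shape == (4,)
```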

Comment thread examples/vae.py Outdated
prior_mu = Variable(torch.zeros([batch_size, self.z_dim]))
prior_sigma = Variable(torch.ones([batch_size, self.z_dim]))
zs = pyro.sample("z", dist.Normal(prior_mu, prior_sigma))
zs = pyro.sample("z", dist.Normal(prior_mu, prior_sigma).reshape(extra_event_dims=2))
Collaborator

why extra_event_dims=2? this isn't really necessary is it?

Member

I think that's an oversight on my part. We don't need any reshaping here. Will remove.

variance = self.get_param("variance").expand_as(f)

return pyro.sample("y", dist.Normal(f, variance), obs=obs)
event_dims = f.dim()
Collaborator

won't event_dims always be 1 here?

Member

I made the change because of the failing gp tutorial, but I'm not too familiar with this code. If f.dim() == 1 always holds, we can hard-code that.

Comment thread examples/vae.py
prior_mu = Variable(torch.zeros([batch_size, self.z_dim]))
prior_sigma = Variable(torch.ones([batch_size, self.z_dim]))
zs = pyro.sample("z", dist.Normal(prior_mu, prior_sigma).reshape(extra_event_dims=2))
zs = pyro.sample("z", dist.Normal(prior_mu, prior_sigma))
Member Author

Why don't you need to .reshape(extra_event_dims=1) here?

Member

From my understanding, model_sample is not used for inference but for generating samples separately using the trained decoder, for plotting, etc.

Member Author

@fritzo fritzo left a comment

LGTM (again, someone else must merge)

@martinjankowiak martinjankowiak merged commit fe32be3 into dev Mar 1, 2018
@martinjankowiak martinjankowiak deleted the strict-shape branch March 1, 2018 20:17
@neerajprad neerajprad mentioned this pull request Mar 2, 2018
16 tasks


3 participants