
Wrap torch Categorical and OneHotCategorical #649

Merged: neerajprad merged 7 commits into dev from wrap-torch-categorical on Dec 22, 2017

Conversation

@fritzo (Member) commented Dec 22, 2017

Addresses #606

This wraps torch.distributions.Categorical for use in Pyro as both Categorical and OneHotCategorical. PyTorch does not yet have a native OneHotCategorical. The advantage of migrating both distributions now is that the PyTorch version supports full broadcasting and will be easier to update to our new batching semantics.
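As a rough sketch of the idea (illustrative only; the class and method names below are hypothetical, not the PR's actual code), a one-hot wrapper can delegate sampling and scoring to torch's Categorical and convert between index and one-hot representations:

    import torch

    class OneHotCategoricalSketch:
        # Hypothetical sketch: one-hot samples built on torch's Categorical.
        def __init__(self, probs):
            self.probs = probs
            self.torch_dist = torch.distributions.Categorical(probs=probs)

        def sample(self):
            # Draw integer class indices, then scatter 1s along the last
            # (event) dimension of a zero tensor to get one-hot vectors.
            indices = self.torch_dist.sample()
            one_hot = torch.zeros_like(self.probs)
            return one_hot.scatter_(-1, indices.unsqueeze(-1), 1)

        def log_prob(self, value):
            # A one-hot value selects exactly one class, so its score is
            # the Categorical log-probability of that class index.
            return self.torch_dist.log_prob(value.argmax(-1))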

This also fixes and updates some of the tests in tests/infer/test_inference.py:

  • Correctly splits reparameterized/nonreparameterized tests now that Gamma and Beta are reparameterized
  • Marks one of the Beta tests as xfail (@martinjankowiak and @fritzo are debugging it, this is likely due to low-precision numerics in our PyTorch code and will soon be made more precise)
  • Eliminates map_data() from those tests, reducing cost by ~30%

Tested

  • Ran make test-torch-dist against test-rsample branch of PyTorch.

@fritzo (Member Author) commented Dec 22, 2017

@neerajprad the Categorical and OneHotCategorical wrappers appear to work now.

@martinjankowiak could you please review changes to test_inference.py?

    ps = self.torch_dist.probs.data
    zero = ps.new(self._sample_shape + ps.shape).zero_()
    indices = super(OneHotCategorical, self).sample().data
    if indices.dim() < zero.dim():

Member:
Is this for when we have scalar support in PyTorch? sample() currently should return a non-zero-dim tensor/variable.

Member Author:

I believe this is to compensate for Pyro distributions adding an extra dim in all cases (for consistency), whereas I believe PyTorch currently adds an extra dim only to tensors that would be scalars but can't yet be. But honestly I just type random stuff until tests pass 😊
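(For readers following the shape logic here, a minimal standalone illustration of the mismatch being compensated for; the concrete shapes are made up for the example:)

    import torch

    zero = torch.zeros(4, 3)                                   # batch of 4, 3 classes
    probs = torch.full((4, 3), 1.0 / 3)
    indices = torch.distributions.Categorical(probs).sample()  # shape (4,): one dim short
    if indices.dim() < zero.dim():
        indices = indices.unsqueeze(-1)                        # shape (4, 1)
    one_hot = zero.scatter_(-1, indices, 1)                    # shape (4, 3), rows one-hot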

@martinjankowiak (Collaborator) left a comment:

test_inference changes lgtm


    @pytest.mark.skipif(not dist.gamma.reparameterized, reason='not implemented')
    def test_elbo_reparameterized(self):
        self.do_elbo_test(True, 10000)

Collaborator:

did you change any of these step counts?

Member Author:

See below; the only change was a 10001 -> 10000 in another test.

    svi = SVI(model, guide, adam, loss="ELBO", trace_graph=False)

    - for k in range(10001):
    + for k in range(n_steps):

Member Author:

@martinjankowiak The only change I made to step counts was to change this from 10001 to 10000. What was the extra 1 for?

Collaborator:

Nothing important, just checking. The 1 was so it'd print on the last step... (so irrelevant)
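(A quick illustration of that off-by-one; the modulo-based progress print below is assumed, not quoted from the test:)

    # Looping to 10001 lets the final step (k == 10000) hit a
    # print-every-1000 check; looping to 10000 stops just before it,
    # which is why the extra 1 existed and why dropping it is harmless.
    for k in range(10001):
        if k % 1000 == 0:
            print("step", k)  # fires at k = 0, 1000, ..., 10000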


    def test_enum_discrete_global_local_error():
        if dist.USE_TORCH_DISTRIBUTIONS:
            pytest.xfail(reason="torch Bernoulli is too permissive?")

Member:

Is this failing because of some broadcasting issue?

Member Author:

I believe so. This is a failing death test, so the failure is due to the lack of an error being raised: Pyro's old distributions did not allow broadcasting, and PyTorch's do. I think it's safe to put this off until just before release.
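(For context, a hedged sketch of the death-test pattern described above; the distribution and shapes are made up for the example. The test passes only if an error is raised, and torch's broadcasting accepts shapes that Pyro's old distributions rejected:)

    import pytest
    import torch

    def test_mismatched_shapes_raise():
        d = torch.distributions.Bernoulli(probs=torch.tensor([0.5, 0.5]))
        with pytest.raises(ValueError):
            # Pyro's old Bernoulli rejected this shape mismatch; torch
            # broadcasts the value against probs instead, so no error is
            # raised, pytest.raises fails, and the test is marked xfail.
            d.log_prob(torch.zeros(3, 1))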

Member:

Sounds good! I just want to know about all broadcasting-related failures / unexpected outcomes, so that if there is a need to revisit them on the PyTorch side, we can do that, since it's still early days.

@neerajprad merged commit a8f24b5 into dev on Dec 22, 2017
@fritzo (Member Author) commented Dec 22, 2017

Thanks for reviewing!

@fritzo deleted the wrap-torch-categorical branch on January 5, 2018 at 16:42