Wrap torch Categorical and OneHotCategorical #649
Conversation
@neerajprad and @martinjankowiak could you please review changes to …
```python
ps = self.torch_dist.probs.data
zero = ps.new(self._sample_shape + ps.shape).zero_()
indices = super(OneHotCategorical, self).sample().data
if indices.dim() < zero.dim():
```
Is this for when we have scalar support in PyTorch? `sample()` should currently return a non-zero-dim tensor/variable.
I believe this is to compensate for Pyro distributions adding an extra dim in all cases (for consistency), whereas I believe PyTorch currently adds an extra dim only to tensors that would be scalars but can't yet be. But honestly I just type random stuff until tests pass 😊
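The dim-padding trick being discussed can be illustrated with a small numpy sketch (this is a hypothetical analogue, not the actual Pyro code; `pad_to_rank` and the shapes are made up for illustration): Pyro always prepends a sample dim, so the wrapper pads the index tensor with leading singleton dims until it matches the preallocated zero tensor, then scatters ones along the event axis to build the one-hot sample.

```python
import numpy as np

def pad_to_rank(indices, target_rank):
    """Prepend singleton dims until `indices` has `target_rank` dims."""
    while indices.ndim < target_rank:
        indices = indices[np.newaxis]
    return indices

idx = np.array([2, 0, 1])                    # categorical indices, shape (3,)
one_hot = np.zeros((1, 3, 4))                # sample_shape + batch dims + event dim
padded = pad_to_rank(idx, one_hot.ndim - 1)  # shape (1, 3)
# scatter ones along the last (event) axis at the sampled indices
np.put_along_axis(one_hot, padded[..., np.newaxis], 1.0, axis=-1)
```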
martinjankowiak left a comment
test_inference changes lgtm
```python
@pytest.mark.skipif(not dist.gamma.reparameterized, reason='not implemented')
def test_elbo_reparameterized(self):
    self.do_elbo_test(True, 10000)
```
Did you change any of these step counts?
See below; the only change was a 10001 -> 10000 in another test.
```python
svi = SVI(model, guide, adam, loss="ELBO", trace_graph=False)
```

```diff
-for k in range(10001):
+for k in range(n_steps):
```
@martinjankowiak The only change I made to step counts was to change this from 10001 to 10000. What was the extra 1 for?
Nothing important, just checking. The 1 was so it'd print on the last step (so it's irrelevant).
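A tiny sketch of why the extra iteration only affected logging, assuming a print-every-1000-steps pattern (the actual logging interval is not shown in the diff): with `range(10001)`, the final index 10000 satisfies `k % 1000 == 0`, so the loss prints on the last step, while `range(10000)` stops at 9999 and does not.

```python
# Steps at which a log line would fire under a hypothetical
# print-every-1000-steps condition, with and without the extra step.
logged_with_extra = [k for k in range(10001) if k % 1000 == 0]
logged_without = [k for k in range(10000) if k % 1000 == 0]
```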
```python
def test_enum_discrete_global_local_error():
    if dist.USE_TORCH_DISTRIBUTIONS:
        pytest.xfail(reason="torch Bernoulli is too permissive?")
```
Is this failing because of some broadcasting issue?
I believe so. This is a failing death test, so the failure is that no error is raised. Pyro's old distributions did not allow broadcasting, whereas PyTorch's do. I think it's safe to put this off until just before release.
Sounds good! Just want to know about all broadcasting-related failures / unexpected outcomes, so that if there is a need to revisit it on the PyTorch side, we can do that since it's still early days.
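An illustrative numpy sketch of the permissiveness in question (not Pyro code; the shapes are invented for illustration): the death test expects a shape mismatch between a global parameter and local data to raise, but broadcasting semantics silently expand the size-1 parameter instead, so the expected error never fires.

```python
import numpy as np

probs = np.full((1,), 0.5)               # "global" Bernoulli parameter, shape (1,)
data = np.array([0., 1., 1., 0., 1.])    # "local" observations, shape (5,)
# Broadcasting expands probs to shape (5,) with no error raised,
# instead of rejecting the mismatched shapes.
log_prob = data * np.log(probs) + (1.0 - data) * np.log1p(-probs)
```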
Thanks for reviewing!
Addresses #606
This wraps `torch.distributions.Categorical` for use in Pyro as both `Categorical` and `OneHotCategorical`. PyTorch does not yet have a native `OneHotCategorical`. The advantage of migrating both distributions now is that the PyTorch version supports full broadcasting and will be easier to update to our new batching semantics.

This also fixes and updates some of the `tests/infer/test_inference.py` tests:
- `Gamma` and `Beta` are reparameterized
- marks the `Beta` tests as xfail (@martinjankowiak and @fritzo are debugging it; this is likely due to low-precision numerics in our PyTorch code and will soon be made more precise)
- removes `map_data()` from those tests, reducing cost by ~30%

Tested `make test-torch-dist` against the `test-rsample` branch of PyTorch.
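The core wrapping idea can be sketched in numpy (a hypothetical analogue, not the Pyro implementation): a `OneHotCategorical` sample is just a `Categorical` sample re-encoded as a one-hot vector, so a single native categorical distribution can back both Pyro wrappers.

```python
import numpy as np

rng = np.random.default_rng(0)
probs = np.array([0.1, 0.2, 0.7])
idx = rng.choice(len(probs), p=probs)  # Categorical sample: an integer index
one_hot = np.eye(len(probs))[idx]      # OneHotCategorical sample: one-hot vector
```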