Correctly handle event_dim in Transforms #116

Closed
fritzo wants to merge 5 commits into master from transform-event-dim

fritzo commented Jan 30, 2018

This adds an .event_dim attribute to all Transforms and correctly handles event shape in TransformedDistribution.log_prob(). Cases that we need to handle are:

  • When TransformedDistribution.base_dist has a larger event_dim than its transforms, we need to sum out some of the rightmost dimensions in the transform.log_abs_det_jacobian() results; otherwise there will be a shape error.
  • When TransformedDistribution.base_dist has a smaller event_dim than its transforms (e.g. when implementing MultivariateNormal as an AffineOperatorTransform of univariate Normal), we need to sum out the rightmost dimensions of base_dist.log_prob().
  • When transforms have differing event_dim, we need to sum out all but the largest dim.
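
The three cases above can be sketched in plain Python. This is only an illustration under stated assumptions, not the code in this PR: nested lists stand in for tensors, the helper names (ndim, sum_rightmost, subtract, transformed_log_prob) are hypothetical, and the sign convention assumes log_abs_det_jacobian returns the forward term log |det dY/dX|, which is subtracted from the base log-probability.

```python
def ndim(value):
    """Number of nested-list dimensions of value (a plain float has 0)."""
    depth = 0
    while isinstance(value, list):
        value = value[0]
        depth += 1
    return depth


def sum_rightmost(value, n):
    """Sum out the n rightmost dimensions of a nested list; n == 0 is a no-op."""
    if n == 0:
        return value
    if ndim(value) > n:
        # Still inside the batch dimensions: recurse into each slice.
        return [sum_rightmost(v, n) for v in value]
    # Every remaining dimension is an event dimension and is summed out.
    return sum(sum_rightmost(v, n - 1) for v in value)


def subtract(a, b):
    """Elementwise a - b for nested lists of identical shape."""
    if isinstance(a, list):
        return [subtract(x, y) for x, y in zip(a, b)]
    return a - b


def transformed_log_prob(base_log_prob, base_event_dim, jacobian_terms):
    """Combine terms whose event_dim differ (hypothetical helper).

    jacobian_terms is a list of (log_abs_det_jacobian, transform_event_dim).
    """
    # The result uses the largest event_dim among the base and all transforms.
    event_dim = max([base_event_dim] + [k for _, k in jacobian_terms])
    result = sum_rightmost(base_log_prob, event_dim - base_event_dim)
    for ladj, k in jacobian_terms:
        result = subtract(result, sum_rightmost(ladj, event_dim - k))
    return result


# Base with event_dim=0 (log_prob shape (2, 3)), transform with event_dim=1.
case_small_base = transformed_log_prob(
    [[-1.0, -2.0, -3.0], [-0.5, -0.5, -1.0]], 0, [([0.5, 1.0], 1)]
)

# Base with event_dim=1 (log_prob shape (2,)), elementwise transform (event_dim=0).
case_large_base = transformed_log_prob(
    [-6.0, -2.0], 1, [([[0.5, 0.25, 0.25], [1.0, 0.0, 0.0]], 0)]
)
```

In both calls every term is reduced to the batch shape (2,) before the terms are combined, which is what avoids the shape error described in the first bullet.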

This PR also includes fixes to ComposeTransform.event_dim and TransformedDistribution.event_shape.

This PR was the result of issues that came up in @rachtsingh's #113 and in @fritzo's refactoring of InverseAutoregressiveFlow in Pyro.

  • add tests

Author

fritzo commented Jan 30, 2018

cc @rachtsingh @tbrx

fritzo mentioned this pull request Jan 30, 2018
2 tasks

X ~ BaseDistribution
Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
log p(Y) = log p(X) + log det (dX/dY)

This should be log |det (dX/dY)| right?
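
For reference, the standard change-of-variables identity for an invertible, differentiable f includes the absolute value this comment asks for. The subscripted density names p_X and p_Y are notation introduced here, not taken from the diff:

```latex
\log p_Y(y) = \log p_X(x) + \log\left|\det\frac{\partial x}{\partial y}\right|
            = \log p_X(x) - \log\left|\det\frac{\partial y}{\partial x}\right|,
\qquad x = f^{-1}(y)
```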

shape = self.base_dist.batch_shape + self.base_dist.event_shape
event_dim = max([len(self.base_dist.event_shape)] + [t.event_dim for t in self.transforms])
batch_shape = shape[:len(shape) - event_dim]
event_shape = shape[len(shape) - event_dim:]

I think these can just be shape[:-event_dim] and shape[-event_dim:].

Author

Careful, negative indexing doesn't work when the argument is zero:

shape = [1,2,3]
assert shape[-3:] == [1, 2, 3]
assert shape[-2:] == [2, 3]
assert shape[-1:] == [3]
assert shape[-0:] == [1, 2, 3]  # <---- this has tripped me up so often
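
A minimal sketch of the safe pattern, assuming a hypothetical helper named split_shape (not part of the PR): computing the cut point from len(shape) keeps event_dim == 0 well defined, whereas shape[:-0] silently returns an empty batch shape.

```python
def split_shape(shape, event_dim):
    """Split shape into (batch_shape, event_shape); event_dim may be 0."""
    cut = len(shape) - event_dim
    return shape[:cut], shape[cut:]


shape = (2, 3, 4)
no_event = split_shape(shape, 0)   # everything is batch shape
one_event = split_shape(shape, 1)  # rightmost dim is the event shape
naive_batch = shape[:-0]           # the pitfall: this is shape[:0], i.e. empty
```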


Got it; I figured there was probably a reason you were doing it that way. Thanks for letting me know!

@rachtsingh

Looks great! I'll double check that this solves the earlier problem, though it certainly should.

One thing to note - we don't currently have any transforms with negative event_dim, but we could in the future. Not worth worrying about until we get there I think.

Author

fritzo commented Jan 30, 2018

@neerajprad Does this look reasonable to you? We've addressed similar issues in Pyro lately.

fritzo force-pushed the transform-event-dim branch from 82e48fb to 605cace on January 30, 2018 16:08

rachtsingh commented Jan 30, 2018

Hmm, actually, when I run test_distributions.py on your branch, I get:

AssertionError: Please add TransformedDistribution to the EXAMPLES list in test_distributions.py

so it looks like you'll need to add TransformedDistribution to the list of exceptions.

x = dist.rsample()
try:
    dist.log_prob(x)  # this should not crash
except NotImplementedError:

Collaborator

When do we need this?

Author

BoltzmannTransform does not implement .log_abs_det_jacobian(), but I wanted to use it as an example of a transform with event_dim=1. It will be nice when we have more multivariate examples to test with.
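
A rough sketch of the pattern being discussed: a smoke test that tolerates NotImplementedError but lets any other exception propagate. The classes and the smoke_test helper below are hypothetical stand-ins, not BoltzmannTransform or the test code in this PR.

```python
class NoJacobianTransform:
    """Hypothetical transform with event_dim=1 and no Jacobian implemented."""
    event_dim = 1

    def log_abs_det_jacobian(self, x, y):
        raise NotImplementedError


class IdentityTransform:
    """Hypothetical elementwise transform whose log-Jacobian is zero."""
    event_dim = 0

    def log_abs_det_jacobian(self, x, y):
        return 0.0


def smoke_test(transform, x, y):
    """Return 'ok' if the call succeeds, 'not implemented' if it is missing.

    Any other exception (for example a shape error) still propagates.
    """
    try:
        transform.log_abs_det_jacobian(x, y)
    except NotImplementedError:
        return "not implemented"
    return "ok"


results = [
    smoke_test(NoJacobianTransform(), 1.0, 1.0),
    smoke_test(IdentityTransform(), 1.0, 1.0),
]
```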

Collaborator

alicanb left a comment

LGTM. You might want to keep [] as the default value in TransformedDistribution's init, but let's be honest: who on earth would try to create a TransformedDistribution without a transform?

Author

fritzo commented Jan 30, 2018

@rachtsingh Tests should work now, thanks for pointing out the error!

Author

fritzo commented Jan 30, 2018

> You might want to keep [] as default value

I figured there are more users who would appreciate the feedback that "an additional argument is required" than there are who would appreciate the convenience of not having to type , [] to get the identity-transformed distribution. Users who want the latter convenience should simply type TheirBaseDistribution without using TransformedDistribution at all 😄
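
To illustrate the trade-off, here is a minimal sketch using a hypothetical stand-in class (TransformedSketch is not the real TransformedDistribution): with no default for transforms, forgetting the argument fails immediately with a TypeError instead of silently producing an identity-transformed distribution.

```python
class TransformedSketch:
    """Hypothetical stand-in; not the real TransformedDistribution."""

    def __init__(self, base_dist, transforms):  # deliberately no default
        self.base_dist = base_dist
        self.transforms = list(transforms)


def try_construct(*args):
    """Return the constructed object, or the TypeError raised by the call."""
    try:
        return TransformedSketch(*args)
    except TypeError as error:
        return error


missing = try_construct("base")        # transforms omitted
explicit = try_construct("base", [])   # identity case spelled out
```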

Author

fritzo commented Jan 30, 2018

Thanks for reviewing @rachtsingh and @alicanb! I'll move upstream to pytorch#4937
