Conversation
|
cc @rachtsingh @tbrx |
|
|
||
| X ~ BaseDistribution | ||
| Y = f(X) ~ TransformedDistribution(BaseDistribution, f) | ||
| log p(Y) = log p(X) + log det (dX/dY) |
There was a problem hiding this comment.
This should be log |det (dX/dY)| right?
| shape = self.base_dist.batch_shape + self.base_dist.event_shape | ||
| event_dim = max([len(self.base_dist.event_shape)] + [t.event_dim for t in self.transforms]) | ||
| batch_shape = shape[:len(shape) - event_dim] | ||
| event_shape = shape[len(shape) - event_dim:] |
There was a problem hiding this comment.
I think these can just be shape[:-event_dim] and shape[-event_dim:].
There was a problem hiding this comment.
Careful, negative indexing doesn't work when the argument is zero:
shape = [1,2,3]
assert shape[-3:] == [1, 2, 3]
assert shape[-2:] == [2, 3]
assert shape[-1:] == [3]
assert shape[-0:] == [1, 2, 3] # <---- this has tripped me up so oftenThere was a problem hiding this comment.
Got it; I figured there was probably a reason you were doing it that way. Thanks for letting me know!
|
Looks great! I'll double check that this solves the earlier problem, though it certainly should. One thing to note - we don't currently have any transforms with negative event_dim, but we could in the future. Not worth worrying about until we get there I think. |
|
@neerajprad Does this look reasonable to you? We've addressed similar issues in Pyro lately. |
82e48fb to
605cace
Compare
|
Hmm, actually, when I run
so it looks like you'll need to add |
| x = dist.rsample() | ||
| try: | ||
| dist.log_prob(x) # this should not crash | ||
| except NotImplementedError: |
There was a problem hiding this comment.
BoltzmannTransform does not implement .log_abs_det_jacobian(), but I wanted to use it as an example of a transform with event_dim=1. It will be nice when we have more multivariate examples to test with.
alicanb
left a comment
There was a problem hiding this comment.
LGTM. You might want to keep [] as default value in TransformedDistribution's init, but let's be honest: who on earth would try to create a TransformedDistribution without the transform
|
@rachtsingh Tests should work now, thanks for pointing out the error! |
I figured there are more users who would appreciate the feedback that "an additional argument is required" than there are to would appreciate the convenience of not having to type |
|
Thanks for reviewing @rachtsingh and @alicanb! I'll move upstream to pytorch#4937 |
This adds an
.event_dimattribute to allTransforms and correctly handles event shape inTransformedDistribution.log_prob(). Cases that we need to handle are:TransformedDistribution.base_disthas a largerevent_dimthen its transforms, we need to sum out some of the rightmost dimensions in thetransform.log_abs_det_jacobian()results, otherwise there will be a shape error.TransformedDistribution.base_disthas a smallerevent_dimthan its transforms (e.g. when implementingMultivariateNormalas anAffineOperatorTransformof univariateNormal), we need to sum out the rightmost dimensions ofbase_dist.log_prob().event_dim, we need to sum out all but the largest dim.This PR also includes fixes to
ComposeTransform.event_dimandTransformedDistribution.event_shape.This PR was the result of issues that came up in @rachtsingh's #113 and in @fritzo's refactoring of
InverseAutoregressiveFlowin Pyro.