Support multivariate TransformedDistributions #4937
Conversation
cc @apaszke
```python
result += part.log_abs_det_jacobian(x, y)
term = part.log_abs_det_jacobian(x, y)
for _ in range(self.event_dim - part.event_dim):
    term = term.sum(-1)
```
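The loop in the diff above reduces a per-element log-det-Jacobian term down to the composed transform's `event_dim` by summing out the extra rightmost dimensions. A minimal sketch of that reduction (`sum_rightmost` is a hypothetical helper name, not part of the PR):

```python
import torch

def sum_rightmost(term, num_dims):
    # Sum out the `num_dims` rightmost dimensions of `term`, mirroring
    # the `for _ in range(self.event_dim - part.event_dim)` loop above.
    for _ in range(num_dims):
        term = term.sum(-1)
    return term

term = torch.ones(2, 3, 4)           # elementwise log|det J| terms
print(sum_rightmost(term, 2).shape)  # torch.Size([2])
print(sum_rightmost(term, 2))        # tensor([12., 12.])
```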
```python
the codomain. Transforms that are not bijective should at least
maintain the weaker pseudoinverse properties
``t(t.inv(t(x))) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
    event_dim (int): Number of dimensions in the transform ``event_shape``.
```
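As a quick sketch of what the docstring above promises, `torch.distributions.transforms.AbsTransform` is not bijective (it discards the sign), yet the weaker pseudoinverse properties still hold:

```python
import torch
from torch.distributions.transforms import AbsTransform

t = AbsTransform()  # not bijective: t.inv picks the nonnegative branch
x = torch.tensor([-2.0, -0.5, 1.0])
y = torch.tensor([0.5, 2.0])

# t.inv(t(x)) != x in general, but the pseudoinverse properties hold:
assert torch.equal(t(t.inv(t(x))), t(x))
assert torch.equal(t.inv(t(t.inv(y))), t.inv(y))
```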
LGTM, can you please resolve the conflicts?
This reverts commit ca5071d.
* Revert "Clarify grad_input_mask documentation in derivatives.yaml (#4963)". This reverts commit 6f3266b.
* Revert "fix triu and tril for zero-strided inputs on gpu (#4962)". This reverts commit 6c197c2.
* Revert "Add mutex for CPU RNG and move TH to C++ (#4041)". This reverts commit 96239dd.
* Revert "Support multivariate TransformedDistributions (#4937)". This reverts commit ca5071d.
* Revert "Only check that arguments are Variables in VariableType (#4943)". This reverts commit d444379.
* Revert "torch.set_num_threads sets MKL option too (#4949)". This reverts commit 2aaeec0.
* Revert "Clarify grad_input_mask documentation in derivatives.yaml (pytorch#4963)". This reverts commit 37c8454.
* Revert "fix triu and tril for zero-strided inputs on gpu (pytorch#4962)". This reverts commit 9acb9be.
* Revert "Add mutex for CPU RNG and move TH to C++ (pytorch#4041)". This reverts commit 07b7fe2.
* Revert "Support multivariate TransformedDistributions (pytorch#4937)". This reverts commit 18bdb4a.
* Revert "Only check that arguments are Variables in VariableType (pytorch#4943)". This reverts commit a479ca0.
* Revert "torch.set_num_threads sets MKL option too (pytorch#4949)". This reverts commit 2cfb339.
Reviewed by @rachtsingh and @alicanb at probtorch#116
This adds an `.event_dim` attribute to all `Transform`s and correctly handles event shape in `TransformedDistribution.log_prob()` and `ComposeTransform.log_abs_det_jacobian()`. Cases we need to handle are:

* When `TransformedDistribution.base_dist` has a larger `event_dim` than its transforms, we need to sum out the rightmost dimensions in the `transform.log_abs_det_jacobian()`s, otherwise there will be a shape error.
* When `TransformedDistribution.base_dist` has a smaller `event_dim` than its transforms (e.g. when implementing `MultivariateNormal` as an `AffineOperatorTransform` of univariate `Normal`), we need to sum out the rightmost dimensions of `base_dist.log_prob()`.
* When the parts of a `ComposeTransform` differ in `event_dim`, we need to sum out all but the largest dim.

This PR also includes fixes to `ComposeTransform.event_dim` and `TransformedDistribution.event_shape` to support multivariate transforms.

This PR was the result of issues that came up in @rachtsingh's probtorch#113 and in @fritzo's refactoring of `InverseAutoregressiveFlow` in Pyro as we build on top of `torch.distributions.transforms`.

Tested: `TransformedDistribution` and `Transform` shapes.
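The "base_dist has a smaller event_dim than its transforms" case above can be sketched with `StickBreakingTransform`, whose codomain has `event_dim` 1: the rightmost dimension of a univariate `Normal` base becomes part of the event, so `log_prob` sums it out. (Shapes shown assume a recent PyTorch; `StickBreakingTransform` maps K unconstrained values onto the (K+1)-simplex.)

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import StickBreakingTransform

# Univariate Normal base with batch shape (3, 4); the transform's
# event_dim of 1 promotes the last base dimension to an event dimension.
base = Normal(torch.zeros(3, 4), torch.ones(3, 4))
d = TransformedDistribution(base, [StickBreakingTransform()])
x = d.sample()
print(d.event_shape)        # torch.Size([5])
print(d.batch_shape)        # torch.Size([3])
print(d.log_prob(x).shape)  # torch.Size([3]) -- event dim summed out
```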