Implement Transforms#4771
Conversation
|
@pytorchbot retest this please |
apaszke
left a comment
There was a problem hiding this comment.
Looks good, but I'd like to discuss a few things before merging this.
| from torch.distributions.transformed_distribution import TransformedDistribution | ||
|
|
||
|
|
||
| class HalfNormal(TransformedDistribution): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
|
||
| @register_kl(TransformedDistribution, TransformedDistribution) | ||
| def _kl_transformed_transformed(p, q): | ||
| if p.transforms != q.transforms: |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
|
||
| @register_kl(LogNormal, LogNormal) | ||
| def _kl_lognormal_lognormal(p, q): | ||
| return kl_divergence(p.base_dist, q.base_dist) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| if isinstance(transforms, Transform): | ||
| self.transforms = [transforms, ] | ||
| elif isinstance(transforms, list): | ||
| for transform in transforms: |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| try: | ||
| return self.transforms[-1].codomain | ||
| except IndexError: | ||
| return self.base_dist.support |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| """ | ||
|
|
||
| def __init__(self): | ||
| self._cache = {} |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| def __init__(self): | ||
| self._cache = {} | ||
|
|
||
| def forward(self, x): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| shape = x.shape | ||
| for _ in range(self.event_dim): | ||
| result = result.sum(-1) | ||
| shape = shape[:-1] |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| class StickBreakingTransform(Transform): | ||
| """ | ||
| Transform from the simplex to unconstrained of one fewer dimension via a | ||
| stick-breaking process. |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| """ | ||
| Transform from the simplex to unconstrained space via `y = log(x)`. | ||
|
|
||
| This is not bijective and cannot be used for HMC. However this acts mostly |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| self.scale = scale | ||
| self.event_dim = event_dim | ||
|
|
||
| def __eq__(self, other): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
I fixed the low-hanging comments. I also realized we forgot to add tests for |
| self._cache = {} | ||
|
|
||
| def __eq__(self, other): | ||
| return type(other) is type(self) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| if self.event_dim: | ||
| # NOTE: no need for contiguous here | ||
| result = result.view(*result.size()[:-self.event_dim], -1).sum(-1) | ||
| shape = shape[:-self.event_dim] |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
apaszke
left a comment
There was a problem hiding this comment.
Looks good. A few minor questions + please revert submodule change.
| """ | ||
| return self._call(x) | ||
|
|
||
| def inverse(self, y): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| """ | ||
| Inverts a single :class:`Transform`. | ||
| """ | ||
| __slots__ = ['inv'] |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| return self.inv.inverse(x) | ||
|
|
||
| def _inverse(self, y): | ||
| return self.inv.__call__(y) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
fritzo
left a comment
There was a problem hiding this comment.
Thanks for the detailed review, Adam!
| """ | ||
| return self._call(x) | ||
|
|
||
| def inverse(self, y): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| """ | ||
| Inverts a single :class:`Transform`. | ||
| """ | ||
| __slots__ = ['inv'] |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| return self.inv.inverse(x) | ||
|
|
||
| def _inverse(self, y): | ||
| return self.inv.__call__(y) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
apaszke
left a comment
There was a problem hiding this comment.
Looks good, but let's fix reference cycles.
| else: | ||
| raise NotImplementedError('cache_size must be 0 or 1') | ||
|
|
||
| @lazy_property |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| pass # default behavior | ||
| elif cache_size == 1: | ||
| self._cached_x_y = None, None | ||
| self.__call__ = self._cached_call |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
|
||
| @property | ||
| def inv(self): | ||
| return self._inv |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| try: | ||
| x_old, y_old = self._cached_x_y | ||
| except AttributeError: | ||
| return self._call(x) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| return y_old | ||
| y = self._call(x) | ||
| self._cached_x_y = x, y | ||
| return y |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
Thanks for your patience in reviewing, @apaszke! |
|
Happy to help! |
|
BTW should I squash this, or do you guys want to recreate some meaningful history by manual rebasing (squashed commit can only have a single author). |
|
Squash! |
This implements a class hierarchy of
Transforms, aTransformedDistributionclass, andLogNormalandHalfNormalclasses as examples ofTransformedDistributions.Transformobjects have.forward()and.inverse()methods. Most transforms are invertible and also have.log_abs_det_jacobian()methods. The forward and inverse methods are bidirectionally memoized so that e.g. when.forward()is called, subsequent calls to both.forward()and.inverse()are free.Not all transforms are invertible, e.g.
AbsTransformis a 2-to-1 mapping, andLogprobTransformis projects an entire dimension away. However both of these have.inverse()implemented as pseudoinverses. cc @apaszke