
Implement Transforms #4771

Merged — apaszke merged 52 commits into pytorch:master from probtorch:bijector on Jan 28, 2018

Conversation

alicanb (Collaborator) commented on Jan 21, 2018

This implements a class hierarchy of Transforms, a TransformedDistribution class, and LogNormal and HalfNormal classes as examples of TransformedDistributions. Transform objects have .forward() and .inverse() methods. Most transforms are invertible and also have .log_abs_det_jacobian() methods. The forward and inverse methods are bidirectionally memoized so that e.g. when .forward() is called, subsequent calls to both .forward() and .inverse() are free.

Not all transforms are invertible: e.g. AbsTransform is a 2-to-1 mapping, and LogprobTransform projects an entire dimension away. However, both of these have .inverse() implemented as pseudoinverses. cc @apaszke
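The interface described above can be sketched concretely. In this sketch an exp transform stands in for a real Transform, and the dict-based memo is an illustrative assumption, not the PR's actual caching mechanism (which is discussed later in the review):

```python
import math

class ExpTransform:
    """Sketch of the Transform interface from the PR description,
    using exp as a stand-in bijection."""

    def __init__(self):
        self._cache = {}  # stores both x -> y and y -> x entries

    def forward(self, x):
        if ('fwd', x) in self._cache:
            return self._cache[('fwd', x)]
        y = math.exp(x)
        # bidirectional memoization: a later inverse(y) is now free
        self._cache[('fwd', x)] = y
        self._cache[('inv', y)] = x
        return y

    def inverse(self, y):
        if ('inv', y) in self._cache:
            return self._cache[('inv', y)]
        x = math.log(y)
        self._cache[('inv', y)] = x
        self._cache[('fwd', x)] = y
        return x

    def log_abs_det_jacobian(self, x, y):
        # |d exp(x) / dx| = exp(x) = y, so the log-Jacobian is just x
        return x

t = ExpTransform()
y = t.forward(2.0)
x = t.inverse(y)  # cache hit: returns the original 2.0 without calling log
```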

yf225 (Contributor) commented on Jan 22, 2018

@pytorchbot retest this please

apaszke (Contributor) left a comment


Looks good, but I'd like to discuss a few things before merging this.

Comment thread: torch/distributions/half_normal.py (outdated)
from torch.distributions.transformed_distribution import TransformedDistribution


class HalfNormal(TransformedDistribution):


Comment thread: torch/distributions/kl.py

@register_kl(TransformedDistribution, TransformedDistribution)
def _kl_transformed_transformed(p, q):
    if p.transforms != q.transforms:


Comment thread: torch/distributions/kl.py (outdated)

@register_kl(LogNormal, LogNormal)
def _kl_lognormal_lognormal(p, q):
    return kl_divergence(p.base_dist, q.base_dist)

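This delegation is sound because KL divergence is invariant under a shared invertible transform, so the KL between two LogNormals equals the KL between their base Normals. A standalone check using the standard closed-form normal KL (the helper name is mine):

```python
import math

def kl_normal_normal(mu_p, sig_p, mu_q, sig_q):
    # standard closed-form KL( N(mu_p, sig_p^2) || N(mu_q, sig_q^2) )
    return (math.log(sig_q / sig_p)
            + (sig_p ** 2 + (mu_p - mu_q) ** 2) / (2 * sig_q ** 2)
            - 0.5)

# Because exp is a fixed bijection, KL(LogNormal_p || LogNormal_q)
# equals kl_normal_normal on the base parameters.
```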

if isinstance(transforms, Transform):
    self.transforms = [transforms, ]
elif isinstance(transforms, list):
    for transform in transforms:


try:
    return self.transforms[-1].codomain
except IndexError:
    return self.base_dist.support


Comment thread: torch/distributions/transforms.py (outdated)
"""

def __init__(self):
    self._cache = {}


Comment thread: torch/distributions/transforms.py (outdated)
def __init__(self):
    self._cache = {}

def forward(self, x):


Comment thread: torch/distributions/transforms.py (outdated)
shape = x.shape
for _ in range(self.event_dim):
    result = result.sum(-1)
    shape = shape[:-1]


Comment thread: torch/distributions/transforms.py (outdated)
class StickBreakingTransform(Transform):
    """
    Transform from the simplex to unconstrained space of one fewer dimension
    via a stick-breaking process.

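For reference, the inverse direction of stick-breaking (unconstrained vector to simplex) can be sketched in plain Python; the real transform differs in numerical details (e.g. an offset applied before the sigmoid), which are omitted here:

```python
import math

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def stick_breaking(y):
    """Map an unconstrained vector of length K to a point on the
    K-simplex (K+1 nonnegative coordinates summing to 1)."""
    x, remainder = [], 1.0
    for yi in y:
        z = sigmoid(yi)          # fraction of the remaining stick to break off
        x.append(z * remainder)
        remainder *= 1.0 - z
    x.append(remainder)          # the last piece is whatever stick is left
    return x

probs = stick_breaking([0.0, 0.0])  # each sigmoid(0) = 0.5
```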

Comment thread: torch/distributions/transforms.py (outdated)
"""
Transform from the simplex to unconstrained space via `y = log(x)`.

This is not bijective and cannot be used for HMC. However this acts mostly


Comment thread: torch/distributions/transforms.py (outdated)
    self.scale = scale
    self.event_dim = event_dim

def __eq__(self, other):


alicanb (Collaborator, author) commented on Jan 22, 2018

I fixed the low-hanging comments. I also realized we forgot to add tests for LogNormal and HalfNormal, so I added them as well. I think scipy's lognorm.entropy() (link) is incorrect, so I omitted that test (it passes the Monte Carlo test).

Comment thread: torch/distributions/transforms.py (outdated)
    self._cache = {}

def __eq__(self, other):
    return type(other) is type(self)


Comment thread: torch/distributions/transforms.py (outdated)
if self.event_dim:
    # NOTE: no need for contiguous here
    result = result.view(*result.size()[:-self.event_dim], -1).sum(-1)
    shape = shape[:-self.event_dim]

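The reduction above sums the elementwise log-Jacobian over the trailing event dimensions. A plain-Python sketch of the same idea on nested lists (the function name and the single-batch-dimension assumption are mine):

```python
def sum_rightmost(result, event_dim):
    """Sum each batch element of `result` over its trailing `event_dim`
    dimensions, mirroring view(..., -1).sum(-1) in the snippet above.
    Assumes exactly one leading batch dimension when event_dim > 0."""
    if event_dim == 0:
        return result

    def total(x, depth):
        # recursively sum `depth` levels of nesting down to a scalar
        if depth == 0:
            return x
        return sum(total(xi, depth - 1) for xi in x)

    return [total(row, event_dim) for row in result]

# batch of 2 with event shape (2,): each row collapses to a scalar
sums = sum_rightmost([[1.0, 2.0], [3.0, 4.0]], 1)
```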

apaszke (Contributor) left a comment


Looks good. A few minor questions + please revert submodule change.

Comment thread: torch/distributions/transforms.py (outdated)
    """
    return self._call(x)

def inverse(self, y):


Comment thread: torch/distributions/transforms.py (outdated)
"""
Inverts a single :class:`Transform`.
"""
__slots__ = ['inv']


Comment thread: torch/distributions/transforms.py (outdated)
    return self.inv.inverse(x)

def _inverse(self, y):
    return self.inv.__call__(y)


fritzo (Collaborator) left a comment


Thanks for the detailed review, Adam!

Comment thread: torch/distributions/transforms.py (outdated)
    """
    return self._call(x)

def inverse(self, y):


Comment thread: torch/distributions/transforms.py (outdated)
"""
Inverts a single :class:`Transform`.
"""
__slots__ = ['inv']


Comment thread: torch/distributions/transforms.py (outdated)
    return self.inv.inverse(x)

def _inverse(self, y):
    return self.inv.__call__(y)


apaszke (Contributor) left a comment


Looks good, but let's fix reference cycles.

Comment thread: torch/distributions/transforms.py (outdated)
else:
    raise NotImplementedError('cache_size must be 0 or 1')

@lazy_property


Comment thread: torch/distributions/transforms.py (outdated)
    pass  # default behavior
elif cache_size == 1:
    self._cached_x_y = None, None
    self.__call__ = self._cached_call



@property
def inv(self):
    return self._inv


try:
    x_old, y_old = self._cached_x_y
except AttributeError:
    return self._call(x)


    return y_old
y = self._call(x)
self._cached_x_y = x, y
return y

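The fragments in this thread implement single-entry caching. A runnable sketch of that cache_size pattern, with exp standing in for the transform math (the constructor signature and default are assumptions):

```python
import math

class Exp:
    """Sketch of the cache_size=0/1 pattern from the review thread;
    the exp math is illustrative, not the PR's code."""

    def __init__(self, cache_size=0):
        if cache_size == 0:
            pass  # default: no caching, _cached_x_y is never set
        elif cache_size == 1:
            self._cached_x_y = None, None
        else:
            raise NotImplementedError('cache_size must be 0 or 1')

    def _call(self, x):
        return math.exp(x)

    def __call__(self, x):
        try:
            x_old, y_old = self._cached_x_y
        except AttributeError:
            return self._call(x)  # caching disabled
        if x is x_old:            # identity check: reuse the last result
            return y_old
        y = self._call(x)
        self._cached_x_y = x, y
        return y

t = Exp(cache_size=1)
x = 2.0
y1 = t(x)   # computed
y2 = t(x)   # served from the single-entry cache
```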

fritzo (Collaborator) commented on Jan 28, 2018

Thanks for your patience in reviewing, @apaszke!

apaszke (Contributor) commented on Jan 28, 2018

Happy to help!

apaszke (Contributor) commented on Jan 28, 2018

BTW, should I squash this, or do you want to recreate some meaningful history by manual rebasing? (A squashed commit can only have a single author.)

fritzo (Collaborator) commented on Jan 28, 2018

Squash!

apaszke merged commit 967bceb into pytorch:master on Jan 28, 2018
apaszke (Contributor) commented on Jan 28, 2018

Thanks @fritzo @alicanb!!
