Commit a347c74

fritzo authored and facebook-github-bot committed
Fix TransformedDistribution shaping logic (#50581)
Summary:
Fixes #50496
Fixes #34859
Fixes #21596

This fixes many bugs involving `TransformedDistribution` and `ComposeTransform` when the component transforms change their event shapes. Part of the fix is to introduce an `IndependentTransform` analogous to `distributions.Independent` and `constraints.independent`, and to introduce methods `Transform.forward_shape()` and `.inverse_shape()`. I have followed fehiepsi's suggestion and replaced `.input_event_dim` with `.domain.event_dim` and `.output_event_dim` with `.codomain.event_dim`. This allows us to deprecate `.event_dim` as an attribute.

## Summary of changes

- Fixes `TransformedDistribution` and `ComposeTransform` shape errors.
- Fixes a behavior bug in `LogisticNormal`.
- Fixes `kl_divergence(TransformedDistribution, TransformedDistribution)`.
- Adds methods `Transform.forward_shape()` and `.inverse_shape()`, which are required for correct shape computations in `TransformedDistribution` and `ComposeTransform`.
- Adds an `IndependentTransform`.
- Adds a `ReshapeTransform`, which is invaluable for testing shape logic in `ComposeTransform` and `TransformedDistribution` and which will be used by stefanwebb's flowtorch.
- Fixes incorrect default values in `constraints.dependent.event_dim`.
- Documents the `.event_dim` and `.is_discrete` attributes.

## Changes planned for follow-up PRs

- Memoize `constraints.dependent_property` as we do with `lazy_property`, since we now consult those properties much more often.

## Tested

- [x] Added a test for `Dist.support` vs `Dist(**params).support` to ensure static and dynamic attributes agree.
- [x] Refactoring is covered by existing tests.
- [x] Added test cases for `ReshapeTransform`.
- [x] Added a test for `TransformedDistribution` on a wide grid of input shapes.
- [x] Added a regression test for #34859.

cc fehiepsi feynmanliang stefanwebb

Pull Request resolved: #50581
Reviewed By: ezyang, glaringlee, jpchen
Differential Revision: D26024247
Pulled By: neerajprad
fbshipit-source-id: f0b9a296f780ff49659b132409e11a29985dde9b
1 parent 250c711 commit a347c74
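For orientation, a minimal sketch of the new shape API described in the summary, using only pieces this commit adds (`forward_shape()`, `inverse_shape()`, `IndependentTransform`, `ReshapeTransform`); illustrative usage, not part of the diff:

```python
import torch
from torch.distributions.transforms import (AffineTransform, IndependentTransform,
                                            ReshapeTransform)

# forward_shape()/inverse_shape() compute result shapes without running data,
# which is what TransformedDistribution and ComposeTransform need internally.
t = ReshapeTransform((4, 5), (20,))
assert t.forward_shape((2, 3, 4, 5)) == (2, 3, 20)
assert t.inverse_shape((2, 3, 20)) == (2, 3, 4, 5)

# IndependentTransform reinterprets batch dims as event dims, analogous to
# distributions.Independent: the Jacobian is summed over the event dims.
base = AffineTransform(torch.zeros(5), torch.ones(5))  # event_dim 0
t = IndependentTransform(base, reinterpreted_batch_ndims=1)
assert t.domain.event_dim == 1
x = torch.randn(2, 5)
assert t.log_abs_det_jacobian(x, t(x)).shape == (2,)
```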

15 files changed

Lines changed: 557 additions & 124 deletions

test/distributions/test_constraints.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -68,7 +68,7 @@ def test_biject_to(constraint_fn, args, is_cuda):
     assert torch.allclose(x, x2), "Error in biject_to({}) inverse".format(constraint)
 
     j = t.log_abs_det_jacobian(x, y)
-    assert j.shape == x.shape[:x.dim() - t.input_event_dim]
+    assert j.shape == x.shape[:x.dim() - t.domain.event_dim]
 
 
 @pytest.mark.parametrize('constraint_fn, args', [(c[0], c[1:]) for c in CONSTRAINTS])
```

test/distributions/test_distributions.py

Lines changed: 28 additions & 2 deletions
```diff
@@ -878,6 +878,22 @@ def test_has_examples(self):
         self.assertIn(Dist, distributions_with_examples,
                       "Please add {} to the EXAMPLES list in test_distributions.py".format(Dist.__name__))
 
+    def test_support_attributes(self):
+        for Dist, params in EXAMPLES:
+            for param in params:
+                d = Dist(**param)
+                event_dim = len(d.event_shape)
+                self.assertEqual(d.support.event_dim, event_dim)
+                try:
+                    self.assertEqual(Dist.support.event_dim, event_dim)
+                except NotImplementedError:
+                    pass
+                is_discrete = d.support.is_discrete
+                try:
+                    self.assertEqual(Dist.support.is_discrete, is_discrete)
+                except NotImplementedError:
+                    pass
+
     def test_distribution_expand(self):
         shapes = [torch.Size(), torch.Size((2,)), torch.Size((2, 1))]
         for Dist, params in EXAMPLES:
@@ -1620,8 +1636,8 @@ def test_logisticnormal(self):
         self.assertEqual(LogisticNormal(mean, std).sample((7,)).size(), (7, 5, 6))
         self.assertEqual(LogisticNormal(mean_1d, std_1d).sample((1,)).size(), (1, 2))
         self.assertEqual(LogisticNormal(mean_1d, std_1d).sample().size(), (2,))
-        self.assertEqual(LogisticNormal(0.2, .6).sample((1,)).size(), (2,))
-        self.assertEqual(LogisticNormal(-0.7, 50.0).sample((1,)).size(), (2,))
+        self.assertEqual(LogisticNormal(0.2, .6).sample().size(), (2,))
+        self.assertEqual(LogisticNormal(-0.7, 50.0).sample().size(), (2,))
 
         # sample check for extreme value of mean, std
         set_rng_seed(1)
@@ -3832,6 +3848,16 @@ def test_kl_shape(self):
                 'Actual {}'.format(kl.shape),
             ]))
 
+    def test_kl_transformed(self):
+        # Regression test for https://github.com/pytorch/pytorch/issues/34859
+        scale = torch.ones(2, 3)
+        loc = torch.zeros(2, 3)
+        normal = Normal(loc=loc, scale=scale)
+        diag_normal = Independent(normal, reinterpreted_batch_ndims=1)
+        trans_dist = TransformedDistribution(diag_normal, AffineTransform(loc=0., scale=2.))
+        self.assertEqual(kl_divergence(diag_normal, diag_normal).shape, (2,))
+        self.assertEqual(kl_divergence(trans_dist, trans_dist).shape, (2,))
+
     def test_entropy_monte_carlo(self):
         set_rng_seed(0)  # see Note [Randomized statistical tests]
         for Dist, params in EXAMPLES:
```
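The new `test_kl_transformed` above is the regression test for #34859; for reference, a minimal sketch of the now-working call outside the test harness (mirrors the test exactly):

```python
import torch
from torch.distributions import (AffineTransform, Independent, Normal,
                                 TransformedDistribution, kl_divergence)

# A batch of 2 diagonal normals, each with a 3-dimensional event.
diag_normal = Independent(Normal(torch.zeros(2, 3), torch.ones(2, 3)),
                          reinterpreted_batch_ndims=1)
trans_dist = TransformedDistribution(diag_normal, AffineTransform(loc=0., scale=2.))

# KL is reduced over the event dim, leaving one value per batch element.
print(kl_divergence(trans_dist, trans_dist).shape)  # torch.Size([2])
```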

test/distributions/test_transforms.py

Lines changed: 114 additions & 8 deletions
```diff
@@ -4,10 +4,10 @@
 
 import torch
 from torch.autograd.functional import jacobian
-from torch.distributions import Dirichlet, Normal, TransformedDistribution, constraints
+from torch.distributions import Dirichlet, Independent, Normal, TransformedDistribution, constraints
 from torch.distributions.transforms import (AbsTransform, AffineTransform, ComposeTransform,
-                                            CorrCholeskyTransform, ExpTransform,
-                                            LowerCholeskyTransform, PowerTransform,
+                                            CorrCholeskyTransform, ExpTransform, IndependentTransform,
+                                            LowerCholeskyTransform, PowerTransform, ReshapeTransform,
                                             SigmoidTransform, TanhTransform, SoftmaxTransform,
                                             StickBreakingTransform, identity_transform, Transform,
                                             _InverseTransform)
@@ -22,6 +22,8 @@ def get_transforms(cache_size):
                        cache_size=cache_size),
         PowerTransform(exponent=torch.tensor(5.).normal_(),
                        cache_size=cache_size),
+        PowerTransform(exponent=torch.tensor(5.).normal_(),
+                       cache_size=cache_size),
         SigmoidTransform(cache_size=cache_size),
         TanhTransform(cache_size=cache_size),
         AffineTransform(0, 1, cache_size=cache_size),
@@ -57,6 +59,12 @@ def get_transforms(cache_size):
                             torch.randn(4, 5),
                             cache_size=cache_size),
         ]),
+        ReshapeTransform((4, 5), (2, 5, 2)),
+        IndependentTransform(
+            AffineTransform(torch.randn(5),
+                            torch.randn(5),
+                            cache_size=cache_size),
+            1),
     ]
     transforms += [t.inv for t in transforms]
     return transforms
@@ -92,7 +100,16 @@ def transform_id(x):
 
 def generate_data(transform):
     torch.manual_seed(1)
+    while isinstance(transform, IndependentTransform):
+        transform = transform.base_transform
+    if isinstance(transform, ReshapeTransform):
+        return torch.randn(transform.in_shape)
+    if isinstance(transform.inv, ReshapeTransform):
+        return torch.randn(transform.inv.out_shape)
     domain = transform.domain
+    while (isinstance(domain, constraints.independent) and
+           domain is not constraints.real_vector):
+        domain = domain.base_constraint
     codomain = transform.codomain
     x = torch.empty(4, 5)
     if domain is constraints.lower_cholesky or codomain is constraints.lower_cholesky:
@@ -170,13 +187,15 @@ def test_forward_inverse(transform, test_cached):
         y = transform(x)
     except NotImplementedError:
         pytest.skip('Not implemented.')
+    assert y.shape == transform.forward_shape(x.shape)
     if test_cached:
         x2 = transform.inv(y)  # should be implemented at least by caching
     else:
         try:
             x2 = transform.inv(y.clone())  # bypass cache
         except NotImplementedError:
             pytest.skip('Not implemented.')
+    assert x2.shape == transform.inverse_shape(y.shape)
     y2 = transform(x2)
     if transform.bijective:
         # verify function inverse
@@ -316,25 +335,29 @@ def test_jacobian(transform):
     except NotImplementedError:
         pytest.skip('Not implemented.')
     # Test shape
-    target_shape = x.shape[:x.dim() - transform.input_event_dim]
+    target_shape = x.shape[:x.dim() - transform.domain.event_dim]
     assert actual.shape == target_shape
 
     # Expand if required
     transform = reshape_transform(transform, x.shape)
     ndims = len(x.shape)
-    event_dim = ndims - transform.input_event_dim
+    event_dim = ndims - transform.domain.event_dim
     x_ = x.view((-1,) + x.shape[event_dim:])
     n = x_.shape[0]
     # Reshape to squash batch dims to a single batch dim
     transform = reshape_transform(transform, x_.shape)
 
-    # 1. Transforms with 0 off-diagonal elements
-    if transform.input_event_dim == 0:
+    # 1. Transforms with unit jacobian
+    if isinstance(transform, ReshapeTransform) or isinstance(transform.inv, ReshapeTransform):
+        expected = x.new_zeros(x.shape[x.dim() - transform.domain.event_dim])
+        expected = x.new_zeros(x.shape[x.dim() - transform.domain.event_dim])
+    # 2. Transforms with 0 off-diagonal elements
+    elif transform.domain.event_dim == 0:
         jac = jacobian(transform, x_)
         # assert off-diagonal elements are zero
         assert torch.allclose(jac, jac.diagonal().diag_embed())
         expected = jac.diagonal().abs().log().reshape(x.shape)
-    # 2. Transforms with non-0 off-diagonal elements
+    # 3. Transforms with non-0 off-diagonal elements
     else:
         if isinstance(transform, CorrCholeskyTransform):
             jac = jacobian(lambda x: tril_matrix_to_vec(transform(x), diag=-1), x_)
@@ -361,5 +384,88 @@ def test_jacobian(transform):
     assert torch.allclose(actual, expected, atol=1e-5)
 
 
+@pytest.mark.parametrize("event_dims",
+                         [(0,), (1,), (2, 3), (0, 1, 2), (1, 2, 0), (2, 0, 1)],
+                         ids=str)
+def test_compose_affine(event_dims):
+    transforms = [AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims]
+    transform = ComposeTransform(transforms)
+    assert transform.codomain.event_dim == max(event_dims)
+    assert transform.domain.event_dim == max(event_dims)
+
+    base_dist = Normal(0, 1)
+    if transform.domain.event_dim:
+        base_dist = base_dist.expand((1,) * transform.domain.event_dim)
+    dist = TransformedDistribution(base_dist, transform.parts)
+    assert dist.support.event_dim == max(event_dims)
+
+    base_dist = Dirichlet(torch.ones(5))
+    if transform.domain.event_dim > 1:
+        base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1))
+    dist = TransformedDistribution(base_dist, transforms)
+    assert dist.support.event_dim == max(1, max(event_dims))
+
+
+@pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)], ids=str)
+def test_compose_reshape(batch_shape):
+    transforms = [ReshapeTransform((), ()),
+                  ReshapeTransform((2,), (1, 2)),
+                  ReshapeTransform((3, 1, 2), (6,)),
+                  ReshapeTransform((6,), (2, 3))]
+    transform = ComposeTransform(transforms)
+    assert transform.codomain.event_dim == 2
+    assert transform.domain.event_dim == 2
+    data = torch.randn(batch_shape + (3, 2))
+    assert transform(data).shape == batch_shape + (2, 3)
+
+    dist = TransformedDistribution(Normal(data, 1), transforms)
+    assert dist.batch_shape == batch_shape
+    assert dist.event_shape == (2, 3)
+    assert dist.support.event_dim == 2
+
+
+@pytest.mark.parametrize("sample_shape", [(), (7,)], ids=str)
+@pytest.mark.parametrize("transform_dim", [0, 1, 2])
+@pytest.mark.parametrize("base_batch_dim", [0, 1, 2])
+@pytest.mark.parametrize("base_event_dim", [0, 1, 2])
+@pytest.mark.parametrize("num_transforms", [0, 1, 2, 3])
+def test_transformed_distribution(base_batch_dim, base_event_dim, transform_dim,
+                                  num_transforms, sample_shape):
+    shape = torch.Size([2, 3, 4, 5])
+    base_dist = Normal(0, 1)
+    base_dist = base_dist.expand(shape[4 - base_batch_dim - base_event_dim:])
+    if base_event_dim:
+        base_dist = Independent(base_dist, base_event_dim)
+    transforms = [AffineTransform(torch.zeros(shape[4 - transform_dim:]), 1),
+                  ReshapeTransform((4, 5), (20,)),
+                  ReshapeTransform((3, 20), (6, 10))]
+    transforms = transforms[:num_transforms]
+    transform = ComposeTransform(transforms)
+
+    # Check validation in .__init__().
+    if base_batch_dim + base_event_dim < transform.domain.event_dim:
+        with pytest.raises(ValueError):
+            TransformedDistribution(base_dist, transforms)
+        return
+    d = TransformedDistribution(base_dist, transforms)
+
+    # Check sampling is sufficiently expanded.
+    x = d.sample(sample_shape)
+    assert x.shape == sample_shape + d.batch_shape + d.event_shape
+    num_unique = len(set(x.reshape(-1).tolist()))
+    assert num_unique >= 0.9 * x.numel()
+
+    # Check log_prob shape on full samples.
+    log_prob = d.log_prob(x)
+    assert log_prob.shape == sample_shape + d.batch_shape
+
+    # Check log_prob shape on partial samples.
+    y = x
+    while y.dim() > len(d.event_shape):
+        y = y[0]
+    log_prob = d.log_prob(y)
+    assert log_prob.shape == d.batch_shape
+
+
 if __name__ == '__main__':
     pytest.main([__file__])
```
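A small sketch of how the new `ReshapeTransform` composes and feeds into `TransformedDistribution`, mirroring `test_compose_reshape` above (illustrative, not part of the diff):

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import ComposeTransform, ReshapeTransform

# Reshape a (3, 2) event to (6,) and then to (2, 3); batch dims pass through.
transform = ComposeTransform([ReshapeTransform((3, 2), (6,)),
                              ReshapeTransform((6,), (2, 3))])
data = torch.randn(5, 4, 3, 2)
assert transform(data).shape == (5, 4, 2, 3)

# The composed transform has event_dim 2, so the rightmost two dims of the
# base Normal are absorbed into the event shape of the result.
dist = TransformedDistribution(Normal(data, 1), transform.parts)
assert dist.batch_shape == (5, 4)
assert dist.event_shape == (2, 3)
```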

torch/distributions/binomial.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -67,7 +67,7 @@ def expand(self, batch_shape, _instance=None):
     def _new(self, *args, **kwargs):
         return self._param.new(*args, **kwargs)
 
-    @constraints.dependent_property(is_discrete=True)
+    @constraints.dependent_property(is_discrete=True, event_dim=0)
     def support(self):
         return constraints.integer_interval(0, self.total_count)
```

torch/distributions/categorical.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -76,7 +76,7 @@ def expand(self, batch_shape, _instance=None):
     def _new(self, *args, **kwargs):
         return self._param.new(*args, **kwargs)
 
-    @constraints.dependent_property(is_discrete=True)
+    @constraints.dependent_property(is_discrete=True, event_dim=0)
     def support(self):
         return constraints.integer_interval(0, self._num_events - 1)
```
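`binomial.py` and `categorical.py` get the same one-line change: declaring `event_dim=0` on `dependent_property` lets the class-level `support` carry the same static metadata as an instance's support, which is what `test_support_attributes` checks. A minimal sketch, assuming the post-patch API:

```python
import torch
from torch.distributions import Binomial

d = Binomial(total_count=10, probs=torch.tensor([0.3, 0.7]))

# The instance-level support is a concrete integer_interval constraint...
assert d.support.is_discrete
assert d.support.event_dim == 0

# ...and the class-level support is a dependent_property exposing the same
# static metadata, usable without constructing an instance.
assert Binomial.support.is_discrete
assert Binomial.support.event_dim == 0
```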

torch/distributions/constraint_registry.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -160,12 +160,16 @@ def _transform_to_real(constraint):
 
 @biject_to.register(constraints.independent)
 def _biject_to_independent(constraint):
-    return biject_to(constraint.base_constraint)
+    base_transform = biject_to(constraint.base_constraint)
+    return transforms.IndependentTransform(
+        base_transform, constraint.reinterpreted_batch_ndims)
 
 
 @transform_to.register(constraints.independent)
 def _transform_to_independent(constraint):
-    return transform_to(constraint.base_constraint)
+    base_transform = transform_to(constraint.base_constraint)
+    return transforms.IndependentTransform(
+        base_transform, constraint.reinterpreted_batch_ndims)
 
 
 @biject_to.register(constraints.positive)
```
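A brief sketch of what this registry change means downstream: transforms built from an `independent` constraint are now wrapped in `IndependentTransform`, so their Jacobians are reduced over the reinterpreted event dims (illustrative, assuming the post-patch behavior):

```python
import torch
from torch.distributions import biject_to, constraints
from torch.distributions.transforms import IndependentTransform

# Reinterpret the rightmost dim of a positive constraint as an event dim.
constraint = constraints.independent(constraints.positive, 1)
t = biject_to(constraint)
assert isinstance(t, IndependentTransform)

# log_abs_det_jacobian sums over the event dim, so the result drops
# one trailing dimension relative to the input.
x = torch.randn(2, 3)
y = t(x)
assert t.log_abs_det_jacobian(x, y).shape == (2,)
```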
