
Incorrect shape from torch.distributions.kl.kl_divergence #34859

@emailweixu

Description


🐛 Bug

To Reproduce

Steps to reproduce the behavior:
Run the following code:

import torch
import torch.distributions as td

scale = torch.ones(2, 3)
loc = torch.zeros(2, 3)
normal = td.Normal(loc=loc, scale=scale)
# Reinterpret the last batch dim as an event dim: batch_shape [2], event_shape [3].
diag_normal = td.Independent(normal, reinterpreted_batch_ndims=1)
trans_dist = td.TransformedDistribution(diag_normal, transforms=td.AffineTransform(loc=0., scale=2.))
print(td.kl.kl_divergence(trans_dist, trans_dist), "Got incorrect shape")
print(td.kl.kl_divergence(diag_normal, diag_normal), "Correct shape")

Expected behavior

In the code above, the first kl_divergence() call should return a tensor of shape [2] (one KL value per batch element), matching the second call. Instead, it returns a scalar of shape [].
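
For reference, both distributions report the same batch and event shapes, which is why the two kl_divergence() calls should agree. A quick check (hypothetical inspection code, not part of the original report):

# Both distributions have batch_shape [2] and event_shape [3],
# so kl_divergence should return a tensor of shape [2] for both.
print(diag_normal.batch_shape, diag_normal.event_shape)  # torch.Size([2]) torch.Size([3])
print(trans_dist.batch_shape, trans_dist.event_shape)    # torch.Size([2]) torch.Size([3])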

Environment

Collecting environment information...
PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 18.04.2 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.10.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.1.168
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti

Nvidia driver version: 440.33.01
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5

Versions of relevant libraries:
[pip3] numpy==1.18.1
[pip3] torch==1.4.0
[pip3] torchvision==0.5.0
[conda] Could not collect

The bug seems to be caused by the following code (apparently from the TransformedDistribution KL rule in torch/distributions/kl.py). It seems to me the commented-out line is the correct one.

# extra_event_dim = len(p.event_shape) - len(p.base_dist.event_shape)
extra_event_dim = len(p.event_shape)
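
For this reproduction, len(p.event_shape) is 1 and len(p.base_dist.event_shape) is also 1, so the commented-out line would give extra_event_dim = 0 and preserve the batch dimension, while the current line gives 1 and sums the batch dimension away. A minimal sketch of the registered rule with the suggested fix applied, assuming the snippet comes from _kl_transformed_transformed in torch/distributions/kl.py (paraphrased, not verbatim):

import torch.distributions as td
from torch.distributions.kl import kl_divergence, register_kl
from torch.distributions.utils import _sum_rightmost

@register_kl(td.TransformedDistribution, td.TransformedDistribution)
def _kl_transformed_transformed(p, q):
    if p.transforms != q.transforms:
        raise NotImplementedError
    if p.event_shape != q.event_shape:
        raise NotImplementedError
    # kl_divergence(p.base_dist, q.base_dist) already sums over the base
    # distribution's event dims, so only the dims newly reinterpreted by
    # TransformedDistribution itself need to be summed out here.
    extra_event_dim = len(p.event_shape) - len(p.base_dist.event_shape)
    base_kl = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(base_kl, extra_event_dim)

With that change, the first print in the reproduction should also yield shape [2].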

cc @vincentqb @fritzo @neerajprad @alicanb @vishwakftw

Labels

module: distributions (Related to torch.distributions), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
