🐛 Bug
To Reproduce
Steps to reproduce the behavior:
Run the following code
import torch
import torch.distributions as td
scale = torch.ones(2, 3)
loc = torch.zeros(2, 3)
normal = td.Normal(loc=loc, scale=scale)
diag_normal = td.Independent(normal, reinterpreted_batch_ndims=1)
trans_dist = td.TransformedDistribution(diag_normal, transforms=td.AffineTransform(loc=0., scale=2.))
print(td.kl.kl_divergence(trans_dist, trans_dist), "Got incorrect shape")
print(td.kl.kl_divergence(diag_normal, diag_normal), "Correct shape")
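To make the expected shapes explicit, here is a minimal torch-free sketch of the shape bookkeeping (the helper `reinterpret` is hypothetical, written here only to mirror what `Independent` does): `Normal` with `(2, 3)` parameters has `batch_shape (2, 3)` and empty `event_shape`; `Independent(..., reinterpreted_batch_ndims=1)` moves the last batch dim into the event, and `AffineTransform` has `event_dim` 0, so the transformed distribution keeps the same shapes.

```python
def reinterpret(batch_shape, event_shape, ndims):
    # Independent moves the rightmost `ndims` batch dims into the event shape.
    return batch_shape[:-ndims], batch_shape[-ndims:] + event_shape

# Normal(loc, scale) with (2, 3) parameters: batch (2, 3), event ().
batch, event = reinterpret((2, 3), (), 1)
print(batch, event)  # (2,) (3,)
```

Since `kl_divergence` returns one value per batch element, the result for `diag_normal` (and hence for `trans_dist`) should have shape `[2]`.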
Expected behavior
In the code above, the first kl_divergence() call should return a tensor of shape [2] (the batch shape). Instead, it returns a scalar of shape [].
Environment
Collecting environment information...
PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.1
OS: Ubuntu 18.04.2 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.10.2
Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.1.168
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti
Nvidia driver version: 440.33.01
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
Versions of relevant libraries:
[pip3] numpy==1.18.1
[pip3] torch==1.4.0
[pip3] torchvision==0.5.0
[conda] Could not collect
The bug appears to be caused by the following code in pytorch/torch/distributions/kl.py (lines 438 to 439 at 471ddac); the commented-out line looks like the correct computation:

# extra_event_dim = len(p.event_shape) - len(p.base_dist.event_shape)
extra_event_dim = len(p.event_shape)
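The following torch-free sketch shows why the fix matters. The helper `kl_output_shape` is hypothetical, written here only to illustrate the shape arithmetic: the KL of the base distributions already has shape equal to the batch shape (Independent sums over its reinterpreted dims internally), and the kl.py code then sums away `extra_event_dim` further rightmost dims. With the current code `extra_event_dim` is 1, which wrongly sums over the batch dim; with the commented line it is 1 - 1 = 0.

```python
def kl_output_shape(batch_shape, p_event_ndims, base_event_ndims, use_fix):
    # KL of the base distributions has shape == batch_shape.
    extra = (p_event_ndims - base_event_ndims) if use_fix else p_event_ndims
    # Summing away `extra` rightmost dims drops that many trailing axes.
    return batch_shape[:len(batch_shape) - extra]

# Repro case: batch_shape [2]; p and base_dist both have 1 event dim.
print(kl_output_shape([2], 1, 1, use_fix=False))  # current code: []
print(kl_output_shape([2], 1, 1, use_fix=True))   # proposed fix: [2]
```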
cc @vincentqb @fritzo @neerajprad @alicanb @vishwakftw