
Categorical logits argument is treated as log probabilities #50378

@JohnReid

🐛 Bug

The `logits` argument to `torch.distributions.Categorical` is not treated as logits but as log probabilities.

To Reproduce

```python
import torch as t
from torch.distributions import Categorical

probs = t.tensor([.3, .3, .399, .001])
y_test = t.tensor([0, 1, 2, 1, 2, 2, 3])
# 1) probabilities passed directly
print(t.exp(Categorical(probs=probs).log_prob(y_test)))
# 2) log probabilities passed as `logits`
print(t.exp(Categorical(logits=t.log(probs)).log_prob(y_test)))
# 3) actual logits, log(p / (1 - p)), passed as `logits`
print(t.exp(Categorical(logits=t.logit(probs)).log_prob(y_test)))
```

produces

```
tensor([0.3000, 0.3000, 0.3990, 0.3000, 0.3990, 0.3990, 0.0010])
tensor([0.3000, 0.3000, 0.3990, 0.3000, 0.3990, 0.3990, 0.0010])
tensor([0.2816, 0.2816, 0.4362, 0.2816, 0.4362, 0.4362, 0.0007])
```
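The three rows can be reproduced without PyTorch. This dependency-free sketch assumes, as the `torch.distributions.Categorical` documentation describes, that whatever is passed as `logits` is treated as unnormalized log probabilities and turned into probabilities with a softmax. `softmax` and `logit` here are hypothetical helpers, not PyTorch functions.

```python
import math

probs = [0.3, 0.3, 0.399, 0.001]
y_test = [0, 1, 2, 1, 2, 2, 3]

def softmax(xs):
    # Subtract the max for numerical stability, then normalize to sum to 1.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def logit(p):
    # Log-odds: log(p / (1 - p)).
    return math.log(p / (1 - p))

# Assumption: Categorical(logits=x) uses softmax(x) as its class probabilities.
row1 = [probs[i] for i in y_test]                                # probs=probs
row2 = [softmax([math.log(p) for p in probs])[i] for i in y_test]  # logits=log(p)
row3 = [softmax([logit(p) for p in probs])[i] for i in y_test]     # logits=logit(p)

for row in (row1, row2, row3):
    print([round(v, 4) for v in row])
```

Under that assumption `row2` equals `row1` because the entries of `probs` already sum to 1, while `row3` is the normalized odds p / (1 - p), which rounds to the third output row (0.2816, 0.4362, 0.0007).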

Expected behavior

I expect the last line to match the first, and the second line to differ from the first. Log probabilities are not logits:

logit(p) = log(p / (1 - p)) = log(p) - log(1 - p) != log(p)
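Assuming `Categorical` maps its `logits` input to probabilities with a softmax, the two inputs from the reproduction work out as follows, using the same logit(p) = log(p) - log(1 - p) as above:

$$
\begin{aligned}
\operatorname{softmax}(\log p)_i &= \frac{e^{\log p_i}}{\sum_j e^{\log p_j}} = \frac{p_i}{\sum_j p_j} = p_i \quad \text{since } \textstyle\sum_j p_j = 1, \\
\operatorname{softmax}(\operatorname{logit} p)_i &= \frac{e^{\log p_i - \log(1 - p_i)}}{\sum_j e^{\log p_j - \log(1 - p_j)}} = \frac{p_i / (1 - p_i)}{\sum_j p_j / (1 - p_j)} \neq p_i \text{ in general.}
\end{aligned}
$$

For p = (0.3, 0.3, 0.399, 0.001) the second expression gives approximately 0.2816, 0.4362 and 0.0007, the values in the third output row.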

Environment

PyTorch version: 1.7.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 10.14.6 (x86_64)
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 3.19.2

Python version: 3.7 (64-bit runtime)
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] botorch==0.3.1
[pip3] gpytorch==1.2.0
[pip3] numpy==1.17.3
[pip3] torch==1.7.1
[pip3] torch-struct==0.4
[pip3] torchvision==0.8.2
[conda] Could not collect

cc @fritzo @neerajprad @alicanb @vishwakftw @nikitaved

Labels: module: distributions, triaged
