Skip to content

Commit e2041ce

Browse files
neerajpradfacebook-github-bot
authored andcommitted
Fix docstring to clarify logits usage for multiclass case (#51053)
Summary: Fixes #50378. Additionally, this has some minor fixes: - [x] Fix mean for half-cauchy to return `inf` instead of `nan`. - [x] Fix constraints/support for the relaxed categorical distribution. Pull Request resolved: #51053 Reviewed By: heitorschueroff Differential Revision: D26077966 Pulled By: neerajprad fbshipit-source-id: ca0213baa9bbdbc661aebbb901ab5e7fded38a5f
1 parent 221d7d9 commit e2041ce

6 files changed

Lines changed: 31 additions & 16 deletions

File tree

test/distributions/test_distributions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1511,7 +1511,7 @@ def test_cauchy(self):
15111511
def test_halfcauchy(self):
15121512
scale = torch.ones(5, 5, requires_grad=True)
15131513
scale_1d = torch.ones(1, requires_grad=True)
1514-
self.assertTrue(is_all_nan(HalfCauchy(scale_1d).mean))
1514+
self.assertTrue(torch.isinf(HalfCauchy(scale_1d).mean).all())
15151515
self.assertEqual(HalfCauchy(scale_1d).variance, inf)
15161516
self.assertEqual(HalfCauchy(scale).sample().size(), (5, 5))
15171517
self.assertEqual(HalfCauchy(scale).sample((7,)).size(), (7, 5, 5))

torch/distributions/categorical.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,19 @@ class Categorical(Distribution):
1616
1717
Samples are integers from :math:`\{0, \ldots, K-1\}` where `K` is ``probs.size(-1)``.
1818
19-
If :attr:`probs` is 1-dimensional with length-`K`, each element is the relative
20-
probability of sampling the class at that index.
19+
If `probs` is 1-dimensional with length-`K`, each element is the relative probability
20+
of sampling the class at that index.
2121
22-
If :attr:`probs` is N-dimensional, the first N-1 dimensions are treated as a batch of
22+
If `probs` is N-dimensional, the first N-1 dimensions are treated as a batch of
2323
relative probability vectors.
2424
25-
.. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum,
26-
and it will be normalized to sum to 1 along the last dimension.
25+
.. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
26+
and it will be normalized to sum to 1 along the last dimension. attr:`probs`
27+
will return this normalized value.
28+
The `logits` argument will be interpreted as unnormalized log probabilities
29+
and can therefore be any real number. It will likewise be normalized so that
30+
the resulting probabilities sum to 1 along the last dimension. attr:`logits`
31+
will return this normalized value.
2732
2833
See also: :func:`torch.multinomial`
2934
@@ -35,7 +40,7 @@ class Categorical(Distribution):
3540
3641
Args:
3742
probs (Tensor): event probabilities
38-
logits (Tensor): event log-odds
43+
logits (Tensor): event log probabilities (unnormalized)
3944
"""
4045
arg_constraints = {'probs': constraints.simplex,
4146
'logits': constraints.real_vector}

torch/distributions/half_cauchy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def scale(self):
4343

4444
@property
4545
def mean(self):
46-
return self.base_dist.mean
46+
return torch.full(self._extended_shape(), math.inf, dtype=self.scale.dtype, device=self.scale.device)
4747

4848
@property
4949
def variance(self):

torch/distributions/multinomial.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,13 @@ class Multinomial(Distribution):
1515
Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
1616
called (see example below)
1717
18-
.. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum,
19-
and it will be normalized to sum to 1.
18+
.. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
19+
and it will be normalized to sum to 1 along the last dimension. attr:`probs`
20+
will return this normalized value.
21+
The `logits` argument will be interpreted as unnormalized log probabilities
22+
and can therefore be any real number. It will likewise be normalized so that
23+
the resulting probabilities sum to 1 along the last dimension. attr:`logits`
24+
will return this normalized value.
2025
2126
- :meth:`sample` requires a single shared `total_count` for all
2227
parameters and samples.
@@ -35,7 +40,7 @@ class Multinomial(Distribution):
3540
Args:
3641
total_count (int): number of trials
3742
probs (Tensor): event probabilities
38-
logits (Tensor): event log probabilities
43+
logits (Tensor): event log probabilities (unnormalized)
3944
"""
4045
arg_constraints = {'probs': constraints.simplex,
4146
'logits': constraints.real_vector}

torch/distributions/one_hot_categorical.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,13 @@ class OneHotCategorical(Distribution):
1111
1212
Samples are one-hot coded vectors of size ``probs.size(-1)``.
1313
14-
.. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum,
15-
and it will be normalized to sum to 1.
14+
.. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
15+
and it will be normalized to sum to 1 along the last dimension. attr:`probs`
16+
will return this normalized value.
17+
The `logits` argument will be interpreted as unnormalized log probabilities
18+
and can therefore be any real number. It will likewise be normalized so that
19+
the resulting probabilities sum to 1 along the last dimension. attr:`logits`
20+
will return this normalized value.
1621
1722
See also: :func:`torch.distributions.Categorical` for specifications of
1823
:attr:`probs` and :attr:`logits`.
@@ -25,7 +30,7 @@ class OneHotCategorical(Distribution):
2530
2631
Args:
2732
probs (Tensor): event probabilities
28-
logits (Tensor): event log probabilities
33+
logits (Tensor): event log probabilities (unnormalized)
2934
"""
3035
arg_constraints = {'probs': constraints.simplex,
3136
'logits': constraints.real_vector}

torch/distributions/relaxed_categorical.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class ExpRelaxedCategorical(Distribution):
2121
Args:
2222
temperature (Tensor): relaxation temperature
2323
probs (Tensor): event probabilities
24-
logits (Tensor): the log probability of each event.
24+
logits (Tensor): unnormalized log probability for each event
2525
2626
[1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
2727
(Maddison et al, 2017)
@@ -101,7 +101,7 @@ class RelaxedOneHotCategorical(TransformedDistribution):
101101
Args:
102102
temperature (Tensor): relaxation temperature
103103
probs (Tensor): event probabilities
104-
logits (Tensor): the log probability of each event.
104+
logits (Tensor): unnormalized log probability for each event
105105
"""
106106
arg_constraints = {'probs': constraints.simplex,
107107
'logits': constraints.real_vector}

0 commit comments

Comments
 (0)