
Implement torch.standard_gamma and distributions.Gamma #3841

Merged

apaszke merged 13 commits into pytorch:master from probtorch:random-gamma on Dec 2, 2017

Conversation

@fritzo (Collaborator) commented Nov 22, 2017

Addresses #3813

This implements a torch.standard_gamma() random number generator and a distributions.Gamma distribution class. Note that the sampler is named torch.standard_gamma to avoid confusion with the Gamma function, whose logarithm is already implemented as torch.lgamma.
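To make the naming distinction concrete, a minimal sketch (the public spelling torch.standard_gamma is as proposed in this PR and may differ in later releases):

```python
import torch

x = torch.tensor([2.0, 3.0, 4.0])
print(torch.lgamma(x))  # log of the Gamma function: tensor([0.0000, 0.6931, 1.7918])
# torch.standard_gamma(x), by contrast, draws one random sample from Gamma(x_i, 1)
# for each element x_i, rather than evaluating a deterministic function.
```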

We follow scipy in generating standard Gamma variables Gamma(alpha, 1) rather than fully parameterized Gamma(alpha, beta) random variables, for two reasons: (1) the partial parameterization makes it easier to implement reparameterized gradients, and (2) the community is split between the scale parameter theta and the rate parameter beta = 1/theta. This PR uses the rate parameter beta in distributions.Gamma but remains agnostic in torch.standard_gamma(alpha).
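A sketch of the rate/scale correspondence: scipy.stats.gamma takes a scale theta, so a rate beta maps to scale 1/beta (the positional alpha-then-beta signature of distributions.Gamma is assumed here):

```python
import torch
import scipy.stats

alpha, beta = 2.0, 3.0  # beta is a rate, so the corresponding scipy scale is 1 / beta
x = 0.5

torch_lp = torch.distributions.Gamma(torch.tensor(alpha),
                                     torch.tensor(beta)).log_prob(torch.tensor(x))
scipy_lp = scipy.stats.gamma.logpdf(x, a=alpha, scale=1.0 / beta)
# The two log-densities agree up to floating-point error.
```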

Tested

  • Added deterministic tests of the shape and of the .log_prob() method
  • Added a randomized test of the .sample() method (also added one for Normal.sample()); a sketch of this kind of check appears after this list

test_distributions.py runs in under 1 second.
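The randomized .sample() check could look roughly like the following sketch, assuming a Kolmogorov–Smirnov test against scipy's reference distribution (the actual test in this PR may use a different statistic or threshold):

```python
import torch
import scipy.stats

torch.manual_seed(0)
alpha, beta = 2.0, 3.0
samples = torch.distributions.Gamma(torch.tensor(alpha),
                                    torch.tensor(beta)).sample((10000,))

# KS test against scipy's gamma; args are (shape, loc, scale) with scale = 1 / rate.
stat, pval = scipy.stats.kstest(samples.numpy(), 'gamma',
                                args=(alpha, 0.0, 1.0 / beta))
assert pval > 1e-4  # loose threshold so the randomized test stays non-flaky
```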

@pytorchbot (Collaborator) commented:

@fritzo, thanks for your PR! We identified @zdevito to be a potential reviewer.

@fritzo (Collaborator, Author) commented Nov 22, 2017

@colesbury Is there anything else I should do beyond writing tests?

@apaszke (Contributor) left a comment

LGTM

Comment thread: torch/distributions.py (outdated)

def sample_n(self, n):
    # cleanly expand float or Tensor or Variable parameters
    def expand(v):

This comment was marked as off-topic.

Comment thread: aten/src/ATen/Declarations.cwrap (outdated)
  - floating_point
backends:
  - CPU
  - CUDA

This comment was marked as off-topic.

@fritzo (Collaborator, Author) commented Nov 23, 2017

OK, I've added a lot more boilerplate to expose torch.random_gamma. There are still some numerical bugs.

@fritzo (Collaborator, Author) commented Nov 23, 2017

@apaszke @colesbury this is ready for another review. I believe it is correct now. Thanks for the help!

Comment thread: torch/nn/init.py (outdated)
    return tensor.normal_(mean, std)


def random_gamma(tensor, alpha=1, beta=1):

This comment was marked as off-topic.

Comment thread: torch/distributions.py (outdated)
    else:
        alpha = type(beta)(*beta.size()).fill_(alpha)
elif isinstance(beta, Number):
    beta = type(alpha)(*alpha.size()).fill_(beta)

This comment was marked as off-topic.

Comment thread: test/test_distributions.py (outdated)

self._check_log_prob(Gamma(alpha, beta), ref_log_prob)

# FIXME this fails due to bad numerics

This comment was marked as off-topic.

Comment thread: test/test_distributions.py (outdated)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

This comment was marked as off-topic.

Comment thread: test/test_distributions.py (outdated)
import math
import numpy as np
import scipy.stats
import scipy.special

This comment was marked as off-topic.


Comment thread: test/test_distributions.py (outdated)
def test_normal_sample(self):
    self._set_rng_seed()
    for mean in [-1.0, 0.0, 1.0]:
        for std in [0.1, 1.0, 10.0]:

This comment was marked as off-topic.

Comment thread: torch/distributions.py (outdated)
    else:
        alpha = type(beta)(*beta.size()).fill_(alpha)
elif isinstance(beta, Number):
    beta = type(alpha)(*alpha.size()).fill_(beta)

This comment was marked as off-topic.

@fritzo (Collaborator, Author) commented Nov 23, 2017

Thanks for the review, @apaszke! I've addressed all your comments.

@fritzo changed the title from "Implement torch.random_gamma and distributions.Gamma" to "Implement torch.standard_gamma and distributions.Gamma" on Nov 26, 2017
@fritzo (Collaborator, Author) commented Nov 26, 2017

I've simplified this PR to implement torch.standard_gamma(alpha) rather than torch.random_gamma(alpha, beta) so as to make it easier to implement reparameterized gradients in a subsequent PR. These two relate by the simple equation

random_gamma(alpha, beta) = standard_gamma(alpha) / beta
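A minimal sketch of this identity through the distributions.Gamma API this PR adds (the raw standard-Gamma draw is obtained here by fixing the rate to 1):

```python
import torch
from torch.distributions import Gamma

alpha = torch.tensor([2.0, 5.0])
beta = torch.tensor([0.5, 3.0])

standard = Gamma(alpha, torch.ones_like(beta)).sample()  # X ~ Gamma(alpha, 1)
scaled = standard / beta                 # X / beta ~ Gamma(alpha, beta)
direct = Gamma(alpha, beta).sample()     # same distribution, sampled directly
```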

@apaszke (Contributor) commented Nov 26, 2017

I'm wondering if standard_gamma is distinct enough to not be confused with the regular gamma function. Couldn't we expose it as an internal method in _C and realize the sampling only through the torch.distributions API?
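A minimal sketch of the layering being suggested, assuming a private sampler named torch._standard_gamma (a hypothetical spelling; the merged code may name it differently):

```python
import torch

class Gamma(object):
    """Sketch: users reach the sampler only through the distribution object."""

    def __init__(self, alpha, beta):
        self.alpha, self.beta = alpha, beta  # rate parameterization

    def sample(self):
        # The private sampler draws from Gamma(alpha, 1); dividing by the
        # rate rescales the draw to Gamma(alpha, beta).
        return torch._standard_gamma(self.alpha) / self.beta
```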

@fritzo (Collaborator, Author) commented Nov 27, 2017

Ok, I've addressed all review comments.

@apaszke (Contributor) commented Nov 28, 2017

@pytorchbot add to whitelist

@apaszke closed this on Nov 28, 2017
@apaszke reopened this on Nov 28, 2017
@apaszke (Contributor) commented Nov 28, 2017

I think this is good to go, but it might not be exposed in ATen. cc @zdevito @colesbury: what needs to be done to have it there? Just add to Declarations.yaml?

Comment thread: torch/distributions.py (outdated)
elif alpha_num and beta_num:
    alpha, beta = torch.Tensor([alpha]), torch.Tensor([beta])
else:
    alpha = alpha.expand_as(beta)

This comment was marked as off-topic.

@fritzo (Collaborator, Author) commented Nov 30, 2017

@apaszke How does the merge process work? Did I set back the merge process by adding a commit after the pytorchbot whitelisting? (I'm eager to get this merged because two other PRs are stacked on top: probtorch#26 and probtorch#28.)

@apaszke (Contributor) commented Dec 1, 2017

Yeah, I think it was ready to merge, but now the broadcasting doesn't look correct to me (correct me if I'm wrong). It's all good to land once this is fixed.

@fritzo (Collaborator, Author) commented Dec 1, 2017

OK, I've replaced the .expand_as() call with a check that the sizes are equal. This disallow-broadcasting behavior is already adopted by Normal and will also be used by Beta.
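The disallow-broadcasting check might look like this sketch (hypothetical helper name; the PR's actual code and error wording may differ):

```python
def _check_same_size(alpha, beta):
    # Instead of silently expanding one parameter to the other's shape,
    # require the two parameter tensors to have exactly the same size.
    if alpha.size() != beta.size():
        raise ValueError('alpha and beta must have the same size, got {} and {}'
                         .format(tuple(alpha.size()), tuple(beta.size())))
```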

@fritzo (Collaborator, Author) commented Dec 1, 2017

Any idea what's going on with the failed builds?

@apaszke merged commit 165d089 into pytorch:master on Dec 2, 2017
@fritzo deleted the random-gamma branch on December 13, 2017