Implement torch.standard_gamma and distributions.Gamma #3841
apaszke merged 13 commits into pytorch:master from
Conversation
@colesbury Is there anything else I should do beyond writing tests?
```python
def sample_n(self, n):
    # cleanly expand float or Tensor or Variable parameters
    def expand(v):
```
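The quoted hunk truncates before the body of `expand`. A minimal sketch of what such a helper plausibly does, assuming `Number` scalars are promoted and tensors get a leading sample dimension (the body here is an illustration, not the PR's exact code):

```python
from numbers import Number
import torch

def sample_n(self, n):
    # cleanly expand float or Tensor or Variable parameters
    def expand(v):
        if isinstance(v, Number):
            # promote a Python scalar to an (n, 1) tensor
            return torch.Tensor([v]).expand(n, 1)
        # otherwise prepend a sample dimension of size n
        return v.expand(n, *v.size())
    return torch.normal(expand(self.mean), expand(self.std))
```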
```yaml
- floating_point
backends:
  - CPU
  - CUDA
```
Ok, I've added lots more boilerplate to expose

@apaszke @colesbury this is ready for another review. I believe it is correct now. Thanks for the help!
```python
return tensor.normal_(mean, std)
```

```python
def random_gamma(tensor, alpha=1, beta=1):
```
```python
else:
    alpha = type(beta)(*beta.size()).fill_(alpha)
elif isinstance(beta, Number):
    beta = type(alpha)(*alpha.size()).fill_(beta)
```
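This hunk is the tail of a scalar-promotion chain. A self-contained sketch of the full pattern, under the hypothetical name `_broadcast_alpha_beta` (the name is not from the PR):

```python
from numbers import Number
import torch

def _broadcast_alpha_beta(alpha, beta):
    # promote Python scalars so alpha and beta end up as same-shaped tensors
    if isinstance(alpha, Number):
        if isinstance(beta, Number):
            alpha, beta = torch.Tensor([alpha]), torch.Tensor([beta])
        else:
            alpha = type(beta)(*beta.size()).fill_(alpha)
    elif isinstance(beta, Number):
        beta = type(alpha)(*alpha.size()).fill_(beta)
    return alpha, beta
```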
```python
self._check_log_prob(Gamma(alpha, beta), ref_log_prob)
```
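For context, a plausible scipy-based reference for the Gamma log-density (scipy's `gamma` is parameterized by shape and scale, so rate `beta` maps to `scale = 1/beta`; names here are illustrative, not the PR's test code):

```python
import math
import scipy.stats

def gamma_ref_log_prob(x, alpha, beta):
    # reference log-density of Gamma(alpha, rate=beta) at x
    return scipy.stats.gamma.logpdf(x, alpha, scale=1.0 / beta)

def gamma_log_prob(x, alpha, beta):
    # the same density written out explicitly
    return (alpha * math.log(beta) + (alpha - 1) * math.log(x)
            - beta * x - math.lgamma(alpha))
```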
```python
# FIXME this fails due to bad numerics
```
```python
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
```
```python
import math
import numpy as np
import scipy.stats
import scipy.special
```
```python
def test_normal_sample(self):
    self._set_rng_seed()
    for mean in [-1.0, 0.0, 1.0]:
        for std in [0.1, 1.0, 10.0]:
```
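A sketch of the kind of sample smoke test this loop might drive, comparing draws against scipy with a Kolmogorov-Smirnov test (the sample count and threshold are assumptions, not the PR's values):

```python
import scipy.stats
import torch

def check_normal_sample(mean, std, num_samples=10000):
    samples = torch.Tensor(num_samples).normal_(mean, std).numpy()
    stat, p_value = scipy.stats.kstest(samples, 'norm', args=(mean, std))
    assert p_value > 1e-4, 'KS test failed for mean=%s, std=%s' % (mean, std)
```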
Thanks for the review @apaszke! I've addressed all your comments.
I've simplified this PR to implement `random_gamma(alpha, beta) = standard_gamma(alpha) / beta`.
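That identity as a sketch, assuming `torch.standard_gamma` is exposed under the name in the PR title: a standard `Gamma(alpha, 1)` draw divided by the rate `beta` is a `Gamma(alpha, beta)` draw.

```python
import torch

def random_gamma(alpha, beta):
    # Gamma(alpha, beta) sample = Gamma(alpha, 1) sample / beta
    return torch.standard_gamma(alpha) / beta
```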
I'm wondering if

Ok, I've addressed all review comments.

@pytorchbot add to whitelist

I think this is good to go, but it might not be exposed in ATen. cc @zdevito @colesbury what needs to be done to have it there? Just add to
```python
elif alpha_num and beta_num:
    alpha, beta = torch.Tensor([alpha]), torch.Tensor([beta])
else:
    alpha = alpha.expand_as(beta)
```
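Note that `alpha.expand_as(beta)` only broadcasts in one direction: it fails when `alpha` and `beta` each have dimensions the other lacks, which is the concern raised in a later review comment. A symmetric alternative, using the `torch.broadcast_tensors` helper from later PyTorch releases (not available when this PR landed):

```python
import torch

alpha = torch.Tensor([[1.0], [2.0]])  # shape (2, 1)
beta = torch.Tensor([1.0, 2.0, 3.0])  # shape (3,)
alpha, beta = torch.broadcast_tensors(alpha, beta)  # both now (2, 3)
```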
@apaszke How does the merge process work? Did I set back the merge process by adding a commit after pytorchbot whitelisting? (I'm eager to get this merged because 2 other PRs are stacked on top: probtorch#26, probtorch#28)
Yeah, I think it was ready to merge, but now the broadcasting doesn't look correct to me (correct me if I'm wrong). It's all good to land once this is fixed.
Ok, I've replaced the

Any idea what's going on with the failed builds?
Addresses #3813
This implements a `torch.standard_gamma()` random number generator and a `distributions.Gamma` distribution that implements the Gamma distribution. Note that this is named `torch.standard_gamma` to avoid confusion with the Gamma function that is already implemented as `torch.lgamma`.

We follow scipy in generating standard Gamma variables `Gamma(alpha, 1)` rather than fully-parameterized `Gamma(alpha, beta)` random variables, for two reasons: (1) this partial parameterization makes it easier to implement reparameterized gradients, and (2) the community is split between the scale parameter `theta` and the rate parameter `beta = 1/theta`. This PR uses the `beta` parameter in `distributions.Gamma`, but remains agnostic in `torch.standard_gamma(alpha)`.

Tested:
- the `.log_prob()` method
- the `.sample()` method (also added a test for `Normal.sample()`)
- `test_distributions.py` runs in under 1 second
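A usage sketch of the API described above (the top-level `torch.standard_gamma` name follows this PR's title; later releases expose the sampler differently, and exact signatures may have changed):

```python
import torch
from torch.distributions import Gamma

alpha = torch.Tensor([2.0, 3.0])  # shape parameters
beta = torch.Tensor([1.0, 0.5])   # rate parameters

# standard Gamma(alpha, 1) draws, scaled by the rate to get Gamma(alpha, beta)
x = torch.standard_gamma(alpha) / beta

d = Gamma(alpha, beta)
print(d.sample())     # one draw per (alpha, beta) pair
print(d.log_prob(x))  # log density evaluated at x
```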