Implement reparameterized gradient for Gamma sampler #3978

Merged: apaszke merged 7 commits into pytorch:master on Dec 11, 2017
Conversation
Collaborator (Author)

CC: @apaszke and @martinjankowiak, who has reviewed some of the math.
This was referenced on Dec 2, 2017.

apaszke (Contributor) reviewed on Dec 3, 2017:

Looks good for the most part, but the changes in variable.py need to be reverted. Haven't reviewed the math.
```c
/** Computes a reparameterized gradient of a sample from a standard Gamma
    distribution wrt the shape parameter alpha. */
TH_API double THRandom_standard_gamma_grad(double x, double alpha);
```
```python
if not alpha.requires_grad:
    return Variable(torch._C._standard_gamma(alpha.data))
return _StandardGamma.apply(alpha)
```
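For readers following the diff, here is a rough sketch of what an autograd Function like _StandardGamma can look like, written against the modern static-method Function API; the private binding names and argument order are assumptions based on the surrounding diff, not verified:

```python
import torch
from torch.autograd import Function

class _StandardGamma(Function):
    # Rough sketch only; binding names/signatures below are assumptions.
    @staticmethod
    def forward(ctx, alpha):
        x = torch._standard_gamma(alpha)  # draw x ~ Gamma(alpha, 1)
        ctx.save_for_backward(alpha, x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        alpha, x = ctx.saved_tensors
        # Chain rule through the sampler: scale the incoming gradient
        # by the approximated reparameterized gradient dx/dalpha.
        return grad_output * torch._standard_gamma_grad(alpha, x)
```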
```python
def standard_gamma(self, grad=None):
    if grad is None:
        return Variable(torch.standard_gamma(self.data))
    return Variable(torch.standard_gamma(self.data, grad), requires_grad=self.requires_grad)
```
Collaborator (Author)

Thanks for reviewing, @apaszke, and sorry for the slow response. All comments now addressed.
Force-pushed from ac92297 to 09ef89c.
apaszke approved these changes on Dec 11, 2017.
Closes #3813

This implements a reparameterized gradient for distributions.Gamma. The gradient is implemented by directly approximating the reparameterized gradient function dx/dalpha following Knowles (2015). The approximation is accurate to within 0.5% relative error for a wide range of alphas.
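For context, the feature surfaces in current PyTorch roughly as follows (torch.distributions.Gamma and rsample() postdate this Variable-era PR, so treat this as an illustration rather than the interface added here):

```python
import torch

alpha = torch.tensor([1.5, 2.0], requires_grad=True)
beta = torch.tensor([1.0, 3.0])

# rsample() draws a reparameterized sample, so gradients flow
# from the sample back to the shape parameter alpha.
x = torch.distributions.Gamma(alpha, beta).rsample()
x.sum().backward()
print(alpha.grad)  # filled in via the dx/dalpha approximation
```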
Derivation

First consider the beta variable. If x ~ Gamma(alpha, beta) then x / beta ~ Gamma(alpha, 1). Since division is already implemented in PyTorch, we can thus reduce our problem to computing a reparameterized gradient of a standard gamma x ~ Gamma(alpha) = Gamma(alpha, 1) wrt alpha.
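A minimal numpy sketch of this reduction, keeping the parameterization used in the sentence above (where x / beta is standard, i.e. beta acts as a scale here):

```python
import numpy as np

def sample_gamma(alpha, beta, rng=np.random):
    # If y ~ Gamma(alpha, 1) then x = y * beta satisfies
    # x / beta ~ Gamma(alpha, 1), matching the reduction above;
    # only the standard sampler needs a gradient wrt alpha.
    y = rng.standard_gamma(alpha)
    return y * beta
```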
This PR implements a function standard_gamma_grad(x, alpha) that directly approximates the reparameterized gradient, defined (for any continuous univariate distribution) as

dx/dalpha = -(d/dalpha cdf(x; alpha)) / pdf(x; alpha)
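A short scipy sketch of this definition, approximating the alpha-derivative of the CDF by central differences (the helper name and step size are illustrative):

```python
from scipy.stats import gamma

def standard_gamma_grad_fd(x, alpha, eps=1e-5):
    # dx/dalpha = -(d/dalpha cdf(x; alpha)) / pdf(x; alpha)
    dcdf_dalpha = (gamma.cdf(x, alpha + eps) - gamma.cdf(x, alpha - eps)) / (2 * eps)
    return -dcdf_dalpha / gamma.pdf(x, alpha)
```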
This definition is used in the unit tests in tests/test_distributions.py, which compute d/dalpha cdf(x; alpha) via finite differences of the scipy.stats.gamma.cdf() function.

The approximation is split into three regions:
- For small x we use a power series approximation of cdf(x, alpha). Until digamma() is implemented in PyTorch, we use a finite difference of lgamma() (a minimal sketch of this workaround follows the list).
- For large alpha we use an asymptotic approximation.
- In the remaining region we use a rational approximation, where PQ(u, v) is a rational function of order 2 in u and 3 in v.

For the complete derivation, see this Jupyter Notebook.
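A minimal sketch of the lgamma() workaround from the first bullet (the step size is illustrative, not the constant used in the C code):

```python
import math

def digamma_approx(t, eps=1e-6):
    # digamma(t) = d/dt lgamma(t), approximated by a central
    # finite difference until a native digamma() is available.
    return (math.lgamma(t + eps) - math.lgamma(t - eps)) / (2 * eps)
```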