
Implement reparameterized gradient for Gamma sampler#3978

Merged
apaszke merged 7 commits into pytorch:master from probtorch:gamma-reparameterized
Dec 11, 2017

Conversation

@fritzo
Collaborator

@fritzo fritzo commented Dec 2, 2017

Closes #3813

This implements a reparameterized gradient for distributions.Gamma. The gradient is computed by directly approximating the reparameterized gradient function dx/dalpha, following Knowles (2015). The approximation is accurate to within 0.5% relative error over a wide range of alpha values.

Derivation

First, consider the beta parameter. If x ~ Gamma(alpha, beta), then x / beta ~ Gamma(alpha, 1). Since division is already differentiable in PyTorch, the problem reduces to computing a reparameterized gradient of a standard Gamma variable x ~ Gamma(alpha, 1) with respect to alpha.
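As a rough sketch (not the PR's code, and using Python's standard library rather than PyTorch), the reduction can be illustrated as follows; `sample_gamma` is a hypothetical helper name:

```python
import random

def sample_gamma(alpha, beta, rng):
    # Hypothetical sketch of the reduction: if y ~ Gamma(alpha, 1), then
    # x = beta * y satisfies x / beta ~ Gamma(alpha, 1). Only the standard
    # draw y needs a custom reparameterized gradient; the multiplication
    # by beta is already differentiable in PyTorch.
    y = rng.gammavariate(alpha, 1.0)  # standard Gamma(alpha, 1) draw
    return beta * y

rng = random.Random(0)
samples = [sample_gamma(2.0, 3.0, rng) for _ in range(20000)]
mean = sum(samples) / len(samples)  # should be near alpha * beta = 6
```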

This PR implements a function standard_gamma_grad(x, alpha) that directly approximates the reparameterized gradient defined (for any continuous univariate distribution) as

                d/dalpha cdf(x; alpha)     d/dalpha cdf(x; alpha)
dx / dalpha = - ---------------------- = - ----------------------
                  d/dx cdf(x; alpha)           pdf(x; alpha)

This definition is used by the unit tests in tests/test_distributions.py, which compute d/dalpha cdf(x; alpha) via a finite difference of scipy.stats.gamma.cdf().
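The actual tests use scipy; as a self-contained sketch of the same definition (hypothetical helper names, standard library only), one can build cdf(x; alpha) from the power series for the regularized lower incomplete gamma and form the quotient by central finite differences:

```python
import math

def gamma_cdf(alpha, x, max_terms=2000):
    # Regularized lower incomplete gamma P(alpha, x) via its power series:
    # P(alpha, x) = x^alpha e^{-x} / Gamma(alpha) * sum_n x^n / (alpha (alpha+1) ... (alpha+n)).
    if x <= 0.0:
        return 0.0
    term = 1.0 / alpha
    total = term
    for n in range(1, max_terms):
        term *= x / (alpha + n)
        total += term
        if term < total * 1e-16:
            break
    return math.exp(alpha * math.log(x) - x - math.lgamma(alpha)) * total

def gamma_pdf(alpha, x):
    return math.exp((alpha - 1.0) * math.log(x) - x - math.lgamma(alpha))

def standard_gamma_grad_fd(x, alpha, eps=1e-4):
    # dx/dalpha = - (d/dalpha cdf(x; alpha)) / pdf(x; alpha),
    # with the alpha-derivative taken by central finite difference.
    dcdf = (gamma_cdf(alpha + eps, x) - gamma_cdf(alpha - eps, x)) / (2.0 * eps)
    return -dcdf / gamma_pdf(alpha, x)

# For large alpha and x near alpha the gradient should approach sqrt(x / alpha).
print(standard_gamma_grad_fd(100.0, 100.0))
```

Note the series converges slowly when x is much larger than alpha (a continued-fraction form would be preferable there); this sketch only targets moderate x.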

The approximation is split into three regions:

  • For small x we use a power series approximation of cdf(x, alpha).
    Until digamma() is implemented in PyTorch, we use a finite difference of lgamma().
  • For large alpha we use the approximation
    standard_gamma_grad(x, alpha) = sqrt(x/alpha)
    
  • For intermediate x,alpha we use a rational function approximation
    standard_gamma_grad(x, alpha) = exp(PQ(log(x / alpha), log(alpha)))
    
    where PQ(u,v) is a rational function of order 2 in u and 3 in v.

For the complete derivation, see this Jupyter notebook.

@fritzo fritzo changed the title Implement reparameterized gradient for random Gamma sampler Implement reparameterized gradient for Gamma sampler Dec 2, 2017
@fritzo
Collaborator Author

fritzo commented Dec 2, 2017

CC @apaszke, and @martinjankowiak, who has reviewed some of the math.

Contributor

@apaszke apaszke left a comment


Looks good for the most part, but changes in variable.py need to be reverted. Haven’t reviewed the math.

Comment thread aten/src/TH/THRandom.h Outdated
/** Computes a reparameterized gradient of a sample from a standard Gamma
distribution wrt the shape parameter alpha.
*/
TH_API double THRandom_standard_gamma_grad(double x, double alpha);


Comment thread torch/distributions.py
    if not alpha.requires_grad:
        return Variable(torch._C._standard_gamma(alpha.data))
    return _StandardGamma.apply(alpha)


Comment thread torch/autograd/variable.py Outdated
def standard_gamma(self, grad=None):
    if grad is None:
        return Variable(torch.standard_gamma(self.data))
    return Variable(torch.standard_gamma(self.data, grad), requires_grad=self.requires_grad)


@fritzo
Collaborator Author

fritzo commented Dec 9, 2017

Thanks for reviewing @apaszke and sorry for the slow response. All comments now addressed.

@fritzo fritzo mentioned this pull request Dec 9, 2017
@fritzo fritzo force-pushed the gamma-reparameterized branch from ac92297 to 09ef89c Compare December 10, 2017 12:14
@apaszke apaszke merged commit 05ebd21 into pytorch:master Dec 11, 2017
@naesseth

naesseth commented Dec 11, 2017

@fritzo @apaszke If you are OK with a (small) bias, I believe using the shape augmentation trick from my paper and simply ignoring the score-function term would be much more accurate and efficient.

https://arxiv.org/abs/1610.05683

@fritzo fritzo deleted the gamma-reparameterized branch December 13, 2017 01:55
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026


Development

Successfully merging this pull request may close these issues.

Implement random_gamma() sampler (with gradients)

4 participants