
Implement Reparameterized version of Gamma #26

Closed
fritzo wants to merge 2 commits into upstream from gamma-reparameterized

Conversation

@fritzo

@fritzo fritzo commented Nov 26, 2017

Fixes #25

DO NOT MERGE. This PR will be moved to the pytorch org after pytorch#3841 is merged.

This implements a reparameterized gradient for distributions.Gamma. The gradient is computed by directly approximating the reparameterized gradient function dx/dalpha, following Knowles (2015). The approximation is accurate to within 1% relative error for a wide range of alphas.

Derivation

Note that if x ~ Gamma(alpha, beta) then x / beta ~ Gamma(alpha, 1). Since division is already implemented in PyTorch, we can thus reduce our problem to computing a reparameterized gradient of a standard gamma x ~ Gamma(alpha) = Gamma(alpha, 1) wrt alpha.
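This scaling identity is easy to check numerically with scipy (the snippet below is illustrative, not part of the PR). Note scipy parameterizes the gamma by scale = 1/rate, so rate beta corresponds to scale=1/beta:

```python
# Check the scaling property: if x ~ Gamma(alpha, beta) with rate beta,
# then beta * x ~ Gamma(alpha, 1), i.e. the CDFs satisfy
#   cdf_{Gamma(alpha, beta)}(x) == cdf_{Gamma(alpha, 1)}(beta * x).
from scipy import stats

alpha, beta, x = 2.5, 3.0, 0.7
lhs = stats.gamma.cdf(x, alpha, scale=1.0 / beta)  # rate beta == scale 1/beta
rhs = stats.gamma.cdf(beta * x, alpha)             # standard gamma at beta * x
```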

This PR implements a function standard_gamma_grad(x, alpha) that directly approximates the reparameterized gradient defined (for any continuous univariate distribution) as

                d/dalpha cdf(x; alpha)     d/dalpha cdf(x; alpha)
dx / dalpha = - ---------------------- = - ----------------------
                  d/dx cdf(x; alpha)           pdf(x; alpha)

This definition is used in the unit tests in tests/test_distributions.py, which compute d/dalpha cdf(x;alpha) via finite difference of the scipy.stats.gamma.cdf() function.
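The finite-difference reference gradient used by those tests can be sketched as follows (the function name is mine, not the PR's):

```python
# Reference value of the reparameterized gradient for the standard gamma,
#   dx/dalpha = -(d/dalpha cdf(x; alpha)) / pdf(x; alpha),
# with the alpha-derivative of the CDF taken by central finite difference.
from scipy import stats

def dx_dalpha_reference(x, alpha, eps=1e-5):
    dcdf = (stats.gamma.cdf(x, alpha + eps)
            - stats.gamma.cdf(x, alpha - eps)) / (2 * eps)
    return -dcdf / stats.gamma.pdf(x, alpha)
```

By the implicit function theorem, differentiating the inverse CDF x(alpha) = ppf(u, alpha) at a fixed uniform draw u must give the same number, which makes a convenient cross-check.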

The approximation is split into three regions:

  • For x < 0.001 we differentiate the power-law approximation from Knowles (2015)
    cdf(x; alpha)  \approx  x**alpha / (alpha * Gamma(alpha))
    standard_gamma_grad(x, alpha) = -x/alpha * (log(x) - 1/alpha - digamma(alpha))
    
    Until digamma() is implemented in PyTorch, we use a finite difference of lgamma().
  • For alpha > 30 we use the approximation
    standard_gamma_grad(x, alpha) = sqrt(x/alpha)
    
  • For intermediate x,alpha we use a rational function approximation
    standard_gamma_grad(x, alpha) = exp(PQ(log(x / alpha), log(alpha)))
    
    where PQ(u,v) is a rational function of order 2 in u and 3 in v. This was fit by least squares, minimizing squared relative error, on ~20000 samples drawn from
    alpha ~ log_uniform(1e-5, 1e2)
    x ~ Gamma(alpha)
    

For complete derivation, see this Jupyter Notebook.
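For concreteness, the first two branches above can be written out in Python (this is my sketch of the description, not the PR's C implementation; the middle branch is omitted because the fitted PQ coefficients appear only in the PR itself):

```python
# Sketch of the piecewise approximation described above. Only the small-x
# and large-alpha branches are reproduced; the intermediate region uses a
# fitted rational function whose coefficients are not restated here.
import math
from scipy.special import digamma

def standard_gamma_grad_sketch(x, alpha):
    if x < 1e-3:
        # Differentiating the power-law approximation
        # cdf(x; alpha) ~= x**alpha / (alpha * Gamma(alpha)):
        return -x / alpha * (math.log(x) - 1.0 / alpha - digamma(alpha))
    if alpha > 30.0:
        # Large-alpha limit of the reparameterized gradient.
        return math.sqrt(x / alpha)
    raise NotImplementedError("middle region: fitted rational function PQ")
```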

@fritzo fritzo added the WIP label Nov 26, 2017
@fritzo
Author

fritzo commented Nov 26, 2017

@apaszke Any advice on autograd plumbing? (I'll send this PR to pytorch/pytorch after pytorch#3841 is merged)


@apaszke apaszke left a comment


This mostly looks good, but needs a few tweaks

Comment thread torch/distributions.py
return -((value - self.mean) ** 2) / (2 * var) - log_std - math.log(math.sqrt(2 * math.pi))


class _StandardGamma(Function):

This is an old-style autograd function. You should write it as shown in the docs.

Author

I wasn't able to store the gradient via ctx.save_for_backward(grad). Is there a new-style way to save an intermediate that is neither an input nor an output?


You can't do that because that can only be done with inputs/outputs. Just do ctx.grad = grad.
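That pattern can be sketched with a new-style (static-method) Function in the current PyTorch API. Everything below is a stand-in, not the PR's code: the real sampler returns both the sample and its precomputed gradient; here the "sample" is just alpha and the gradient is faked with the large-alpha limit sqrt(x / alpha):

```python
import torch
from torch.autograd import Function

class StandardGammaSketch(Function):
    """Illustrative stand-in for the PR's _StandardGamma. The precomputed
    gradient is neither an input nor an output, so it cannot go through
    ctx.save_for_backward; a plain ctx attribute works instead."""

    @staticmethod
    def forward(ctx, alpha):
        # Stand-in sample: the real code calls the C sampler, which also
        # returns dx/dalpha. Here x = alpha, so sqrt(x / alpha) == 1.
        x = alpha.detach().clone()
        ctx.grad = (x / alpha.detach()).sqrt()  # not an input/output
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Chain rule: upstream gradient times the stashed dx/dalpha.
        return grad_output * ctx.grad
```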

Author

Done.

Comment thread aten/src/TH/THRandom.c Outdated

// This is identical to THRandom_standard_gamma but also stores the
// reparameterized gradient wrt alpha in grad_alpha.
double THRandom_standard_gamma_with_grad(THGenerator *_generator, double alpha,

You need to compute the grad during forward, because you'd need to replay all the control flow here otherwise, right?

Author

Correct. This is a complicated gradient due to the rejection sampler. To reproduce the computation, we would need to store about 10x more state.

default: THPDefaultGenerator->cdata
kwarg_only: True
- THTensor* alpha
- cname: standard_gamma_alpha_with_grad

It's better to expose this as an internal method instead of a new overload (think torch._C._standard_gamma_alpha_with_grad).

Author

Could you point me to an example internal method whose plumbing I can copy? I'm a little lost here.

def standard_gamma(self, grad=None):
    if grad is None:
        return Variable(torch.standard_gamma(self.data))
    return Variable(torch.standard_gamma(self.data, grad), requires_grad=self.requires_grad)

You should just use the Function subclass you implemented here.

Author

Thanks for the help! Should I embed that subclass in torch.autograd.variable, or should torch.distributions be the canonical interface for standard_gamma() for Variables?

@fritzo
Author

fritzo commented Nov 26, 2017

@alicanb Would you be up for reviewing the statistical aspects of this PR?

@alicanb
Collaborator

alicanb commented Nov 26, 2017 via email

Comment thread aten/src/TH/THRandom.c Outdated
const double accept_grad = d_grad * e + d * e_grad;
const double dv = d * v;
const double dv_grad = d_grad * v + d * v_grad // Pathwise part.
+ (dv - alpha) * accept_grad; // Acceptance part.
Author

@fritzo fritzo Nov 26, 2017

This is the trickiest part. This roughly follows Naesseth et al. (2017) but computes an exact reparameterized gradient rather than approximating (i.e. if you drew identical samples from an inverse CDF sampler Knowles (2015) and this cheaper sampler, the gradients would be identical). Naesseth et al. seem to be missing the analytical baseline alpha in (dv - alpha) which is easy to compute. Note that accept = log(acceptance ratio), so we're using the log trick in multiplying by accept_grad (and avoiding an expensive exp()).

Comment thread aten/src/TH/generic/THTensorRandom.c Outdated
real*const alpha_data = THTensor_(data)(alpha);
real*const saved_u_data = THTensor_(data)(saved_u);
real*const saved_x_data = THTensor_(data)(saved_x);
for(int64_t i = 0, numel = THTensor_(nElement)(alpha); i < numel; ++i) {

Any reason why you don't use OpenMP here?

Author

It's difficult to parallelize because THGenerator *gen is stateful. However this _fwd() step is much cheaper than the _bwd() step.

Comment thread torch/distributions.py Outdated
@staticmethod
def backward(ctx, grad_output):
    return grad_output * Variable(ctx.saved_grad)
    alpha, = ctx.saved_variables

This backward isn't really differentiable twice. Can you mark it @once_differentiable and use ctx.saved_tensors? grad_output will become a tensor too.
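The requested pattern might look like the following minimal stand-in (not the PR's code): @once_differentiable tells autograd that this backward is not itself differentiable, and grad_output then arrives as a plain tensor.

```python
import torch
from torch.autograd.function import once_differentiable

class OnceDiffSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, alpha):
        ctx.save_for_backward(alpha)  # inputs/outputs may be saved normally
        return alpha * 2.0            # stand-in computation

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        # Marked once-differentiable: autograd raises an error if anyone
        # tries to take a second derivative through this backward.
        alpha, = ctx.saved_tensors
        return grad_output * torch.full_like(alpha, 2.0)
```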

- arg: THTensor* output
output: True
- THTensor* alpha
- THTensor* saved_u

It would be nice if we could skip these outputs if we knew an op won't be differentiated. But it's fine as is, and we can fix that later.

Author

Already done :-) standard_gamma() is a simplified version of standard_gamma_fwd() where the unused outputs are discarded.

@alicanb
Collaborator

alicanb commented Nov 28, 2017

Do you know the difference between torch.set_rng_state and torch.manual_seed? I tried torch.manual_seed while testing normal but couldn't get it working, so I used set_rng_state instead.

@apaszke

apaszke commented Nov 29, 2017

They are pretty much equivalent.

torch.manual_seed(2)
s = torch.get_rng_state()
print(torch.randn(1))
# Same
torch.manual_seed(2)
print(torch.randn(1)) 
# Same
torch.set_rng_state(s)
print(torch.randn(1)) 

The benefit of manual_seed is that it also seeds the GPU, while set_rng_state only sets it for the CPU generator.

Comment thread test/test_distributions.py Outdated
@@ -163,6 +166,48 @@ def test_gamma_sample(self):
scipy.stats.gamma(alpha, scale=1 / beta),
Collaborator

@alicanb alicanb Nov 29, 2017

1 / beta throws a TypeError on Python 3; changing 1 to 1.0 fixes it.

Comment thread test/test_distributions.py Outdated
alphas = Variable(torch.Tensor([alpha]), requires_grad=True)
betas = Variable(torch.Tensor([beta]), requires_grad=True)
self._check_sampler_sampler(Gamma(alphas, betas),
scipy.stats.gamma(alpha, scale=1 / beta),
Collaborator

Same here

@fritzo fritzo force-pushed the gamma-reparameterized branch 2 times, most recently from 57a96b1 to 69df177 Compare November 29, 2017 20:06
@fritzo
Author

fritzo commented Nov 29, 2017

Ok, I've changed algorithms to now directly approximate the reparameterized gradient. This achieves simpler code, cheaper computation, and more accurate gradients (they are no longer stochastic for alpha < 1).

@fritzo
Author

fritzo commented Nov 29, 2017

@tbrx @jwvdm I've added a Jupyter notebook and some explanation in the PR description. Let me know if I can answer any other questions. Thanks for offering to review!

@jwvdm
Member

jwvdm commented Nov 29, 2017

Thanks @fritzo – I'll plan on taking some time to review on Fri.

@fritzo fritzo force-pushed the gamma-reparameterized branch from f74bfac to 9c7694f Compare December 2, 2017 00:48
@fritzo fritzo changed the base branch from random-gamma to upstream December 2, 2017 00:52
@fritzo fritzo force-pushed the gamma-reparameterized branch from d6c1d30 to 6021328 Compare December 2, 2017 00:55
@fritzo
Author

fritzo commented Dec 2, 2017

Moving to pytorch#3978

@fritzo fritzo closed this Dec 2, 2017