Implement digamma #3955
Conversation
fritzo left a comment
Thanks for implementing this!
def test_digamma(self):
    y = torch.Tensor([-10, 0])
    x = torch.Tensor([-0.1, 3, 999])
FWIW, I'm adding the ability to write efficient pointwise native functions in ATen directly.
  - name: lgamma(Tensor self)
-     self: not_implemented("lgamma")
+     self: grad * digamma(self)
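For context, this derivatives.yaml entry encodes the identity d/dx lgamma(x) = digamma(x), which is exactly what `grad * digamma(self)` computes. A minimal sketch of the behaviour it enables, written against current PyTorch where both ops exist (not this PR's actual test code):

```python
import torch

# The backward of lgamma is digamma, as encoded by the derivatives.yaml entry above.
x = torch.tensor([0.5, 1.0, 3.0], requires_grad=True)
torch.lgamma(x).sum().backward()

print(x.grad)                      # gradient produced by autograd
print(torch.digamma(x.detach()))   # should match x.grad up to floating-point error
```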
apaszke left a comment
Looks good. We should resolve the licensing issues before merging this.
  return THTensor_(digamma_one)(1 - x) - PI / tan(PI * x);
}

// Push x to be >= 10
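For reference, the quoted return statement is the standard reflection formula for the digamma function, and the "Push x to be >= 10" comment points at the standard recurrence used to shift the argument before an asymptotic expansion is applied. Both are textbook identities rather than anything specific to this PR:

```latex
% Reflection formula (used for negative arguments):
\psi(x) = \psi(1 - x) - \frac{\pi}{\tan(\pi x)}

% Recurrence (used to push the argument upward, e.g. until x \ge 10):
\psi(x + 1) = \psi(x) + \frac{1}{x}
```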
/*
 * Algorithm adapted from Cephes
tensor = tensor.unsqueeze(1)
self.assertEqual(tensor.var(0)[0], 0.03125)

def test_digamma(self):
@zou3519 can I help with this at all? Now that you've done the hard part, I'm happy to resolve merge conflicts and add tests. (We're looking forward to using this in Pyro, and we already have a wrapper to use …)
@fritzo I haven't reached out to the author about the licensing yet. I'll do that tomorrow and we'll see how that goes :)
I see, thanks for letting me know!
Any update on the license status?
Sent an email a few days ago to the author; haven't heard back yet.
If we still haven't heard from the author, I could throw together a little PR that exposes the makeshift finite-difference implementation of digamma that we're already using internally in a few places. This would at least unblock users of Gamma, Beta, and Dirichlet distributions.
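For concreteness, a central-difference approximation of digamma via `lgamma` looks roughly like the sketch below. The function name and step size are illustrative, not the actual internal helper referenced above:

```python
import torch

def digamma_fd(x, eps=1e-6):
    """Hypothetical finite-difference digamma: psi(x) = d/dx log(Gamma(x)),
    approximated with a central difference of lgamma."""
    return (torch.lgamma(x + eps) - torch.lgamma(x - eps)) / (2 * eps)

# Double precision keeps the cancellation error of the difference small.
x = torch.tensor([0.5, 1.0, 3.0], dtype=torch.float64)
print(digamma_fd(x))  # rough approximation; degrades near the poles at 0, -1, -2, ...
```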
I think merging an approximation is good for now.
Hmm, I looked into it, but it appears to be quite complex now that our functions using …
@fritzo sure, I'll take a look at it.
Licensing issues resolved, PR rebased. I was told that "You are welcome to modify and distribute the library under BSD license.", so I've added the original copyright notices into the code as comments.
That's great news, @zou3519!
Added test to check digamma float vs double. TestCuda.test_digamma checks the CUDA {float, double} implementation against the CPU {float, double} implementation.
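A rough sketch of the kind of comparison described above (illustrative only; it assumes a CUDA device is available and is not the actual TestCuda.test_digamma code):

```python
import torch

# Compare the CUDA digamma against the CPU digamma on the same inputs,
# staying away from the poles so both sides are well conditioned.
x = torch.rand(1000) * 10 + 0.1
cpu_out = torch.digamma(x)
cuda_out = torch.digamma(x.cuda()).cpu()
assert torch.allclose(cpu_out, cuda_out, atol=1e-5)
```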
Any news on this? It seems only the GPU build failed on Windows last November, and that only because it ran out of memory.
@wranai A simpler implementation of digamma and trigamma is already in master. Try …
* More precise digamma. Fixes #6190. This is a rebase of #3955 with some tweaks for better performance around poles. The code is ported over from cephes with permission. By itself, the cephes code returns inf for the poles. For better performance around the poles with float32, one intermediate step is always computed with double precision, regardless of dtype. This step does `PI / tan(PI * input)`. This is necessary because small (1e-6) rounding errors for the inputs to tan have strong effects on the output (i.e., the derivative of tan is very large at some points).
* Replace usages of finite-differences digamma with the newly implemented digamma
* Better behavior near and at poles
* ScalarConvert -> scalar_cast for readability
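To illustrate the sensitivity described in the commit message (an illustrative numeric check, not code from the PR): near a pole, a float32-sized rounding error in the argument of tan visibly changes `PI / tan(PI * input)`, which is why that intermediate step is kept in double precision:

```python
import math
import torch

x = -0.999999  # very close to the pole at -1

# Same value rounded to float32 vs kept in float64:
x32 = torch.tensor(x, dtype=torch.float32).item()  # carries a ~1e-8 rounding error
x64 = float(x)

print(math.pi / math.tan(math.pi * x32))  # differs by tens of thousands from the line below
print(math.pi / math.tan(math.pi * x64))  # ~1e6, computed from the exact double
```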
Implements `torch.digamma` and `tensor.digamma()` as per #678.
cc @fritzo
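Assuming this landed as described, usage looks like the sketch below, checked against the known value ψ(1) = −γ ≈ −0.5772 (illustrative, not the PR's test code):

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])
print(torch.digamma(x))  # functional form
print(x.digamma())       # tensor method form
# Expected (up to float error): psi(1) = -0.5772..., psi(2) = 0.4228..., psi(3) = 0.9228...
```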
Test Plan
New unit tests. The CPU unit tests cover some edge cases; the GPU unit test compares the output of the CPU and GPU implementations.