Improve numerical stability of torch.distributions.wishart.Wishart#72059
nonconvexopt wants to merge 26 commits into pytorch:master
neerajprad left a comment
Some tests are failing. Note that you should be able to run your test locally using:
pytest test/distributions/test_distributions.py -k {name of test}
You can build locally with python setup.py build. See this for details.
# Implemented Sampling using Bartlett decomposition
noise = self._dist_chi2.rsample(sample_shape).sqrt().diag_embed(dim1=-2, dim2=-1)
noise = _clamp_with_eps(
I suppose this is the important change in this PR. Since there are some unrelated changes as well, could you add the specific change for better numerical stability in the PR description so that it is available for future reference?
Thank you for the feedback. That would be great. I will summarize the key modifications in the PR description.
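For readers unfamiliar with the Bartlett decomposition mentioned in the quoted diff, here is a minimal pure-Python sketch of the idea, not the PR's code. It assumes an identity scale matrix, uses the standard library instead of torch, and clamps the square-rooted chi-square diagonal at the float32 machine epsilon. The names `bartlett_factor` and `wishart_sample` are hypothetical, and the exact place where the PR applies its clamp is an assumption, since the quoted line is truncated.

```python
import math
import random

FLOAT32_EPS = 2.0 ** -23  # machine epsilon of IEEE 754 binary32 (float32)


def bartlett_factor(df, ndim, rng, eps=FLOAT32_EPS):
    """Lower-triangular Bartlett factor A, assuming an identity scale matrix."""
    if df <= ndim - 1:
        raise ValueError("df must be greater than ndim - 1")
    a = [[0.0] * ndim for _ in range(ndim)]
    for i in range(ndim):
        # Diagonal: square root of a chi-square draw with (df - i) degrees
        # of freedom. Chi2(k) is Gamma(shape=k/2, scale=2).
        chi2 = rng.gammavariate((df - i) / 2.0, 2.0)
        # Clamp away from zero so later reciprocals stay bounded (assumed placement).
        a[i][i] = max(math.sqrt(chi2), eps)
        # Strictly lower triangle: independent standard normal draws.
        for j in range(i):
            a[i][j] = rng.gauss(0.0, 1.0)
    return a


def wishart_sample(df, ndim, rng):
    """One Wishart draw W = A A^T for an identity scale matrix."""
    a = bartlett_factor(df, ndim, rng)
    return [
        [sum(a[i][k] * a[j][k] for k in range(ndim)) for j in range(ndim)]
        for i in range(ndim)
    ]


rng = random.Random(0)
w = wishart_sample(df=2.5, ndim=3, rng=rng)
for row in w:
    print(["%.4g" % value for value in row])
```

With a general scale matrix, the factor A would additionally be multiplied on the left by the Cholesky factor of the scale matrix before forming the product.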
Note that Chi2 (inherits from Gamma) already clamps based on .tiny. See
pytorch/torch/distributions/gamma.py
Line 62 in a482aeb
Also, can we detach like in the example above so that it's not recorded in autograd?
I suppose this is adding some jitter to the diagonal elements, but is this high enough (I have seen 1e-6 or 1e-8 for float/double)?
Thank you for the detailed points. I understand that it is not a neat approach.
As you know, if we take the reciprocal of a float32 value clamped at tiny=1.17549e-38, the result can be as large as about 1e38. If we clamp at eps=1e-7 instead, the reciprocal is bounded by about 1e7. Thus, I thought we needed to clamp the value again to make the code more stable. A higher value might be better. I just used the information from torch.finfo to make the implementation seem more principled.
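The arithmetic behind this argument can be checked with a short standard-library sketch. The two constants are the float32 `tiny` and `eps` values written as exact powers of two; the helper `clamped_reciprocal` is a hypothetical name used only for illustration and is not part of the PR.

```python
# IEEE 754 binary32 (float32) constants, written as exact powers of two.
FLOAT32_TINY = 2.0 ** -126  # smallest positive normal value, about 1.17549e-38
FLOAT32_EPS = 2.0 ** -23    # machine epsilon, about 1.19209e-07


def clamped_reciprocal(x, floor):
    """Return 1 / max(x, floor), so the result never exceeds 1 / floor."""
    return 1.0 / max(x, floor)


near_zero = 1e-45  # stand-in for a diagonal entry that came out almost zero

# Clamping at tiny still allows a reciprocal of 2**126, about 8.5e37.
bound_tiny = clamped_reciprocal(near_zero, FLOAT32_TINY)
# Clamping at eps caps the reciprocal at 2**23 = 8388608.
bound_eps = clamped_reciprocal(near_zero, FLOAT32_EPS)

print(f"clamp at tiny: reciprocal up to {bound_tiny:.3e}")
print(f"clamp at eps:  reciprocal up to {bound_eps:.3e}")
```

The trade-off is that a larger floor distorts more samples: every value below the floor is replaced by it, so the choice between `tiny`, `eps`, or a fixed jitter such as 1e-6 balances bounded reciprocals against fidelity to the original draw.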
Also, can we detach like in the example above so that it's not recorded in autograd?
I will detach it if you think it is needed, but I don't understand the intuition. Wouldn't it affect .rsample()? I will check it out, too.
You are right, this will be used downstream so we shouldn't detach! Sorry about my hasty and misleading comment.
Thank you for the explanation. I think you found a potential issue in torch.distributions.gamma.Gamma. I wonder why there is a .detach there as well! Maybe we can fix it.
I will try to utilize my local environment. I appreciate your kindness.
@neerajprad has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
It is my pleasure. Thank you for the review.
…72059)
Summary:
Maintenance of #70377
Multiple modifications of the merged initial implementation of Wishart distribution.
cc neerajprad
Key modifications:
- `torch/distributions/wishart.py`: Clamp (clip) floating-point values so that reciprocals are computed in a numerically stable manner, using the `eps` value paired with each `torch.dtype`.
- `test/distributions/test_distributions.py`: Test the Wishart distribution implementation in numerically unstable zones, i.e. `df` values in `ndim - 1 < df < ndim`, where `ndim` is the size of one dimension of the Wishart parameter and sample matrix.
Pull Request resolved: #72059
Reviewed By: H-Huang
Differential Revision: D34245091
Pulled By: neerajprad
fbshipit-source-id: 1cd653c1d5c663346433e84fd0bbe2e590790908
Reverted this PR as it broke MacOS CI, see https://github.com/pytorch/pytorch/runs/5221143435?check_suite_focus=true
This pull request has been reverted by 0b117a3. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).
Here's the test failure for convenience:
@malfet: Are these tests run after merging? I see that the CI was green on this PR, and I want to better understand how to check for issues like this before I merge.
@neerajprad Due to limited capacity, MacOS tests are not run on PRs by default; one can opt in by adding
I will try to fix the issue.
…(#72059)
Summary:
Maintenance of pytorch/pytorch#70377
Multiple modifications of the merged initial implementation of Wishart distribution.
cc neerajprad
Key modifications:
- `torch/distributions/wishart.py`: Clamp (clip) floating-point values so that reciprocals are computed in a numerically stable manner, using the `eps` value paired with each `torch.dtype`.
- `test/distributions/test_distributions.py`: Test the Wishart distribution implementation in numerically unstable zones, i.e. `df` values in `ndim - 1 < df < ndim`, where `ndim` is the size of one dimension of the Wishart parameter and sample matrix.
Pull Request resolved: pytorch/pytorch#72059
Reviewed By: H-Huang
Differential Revision: D34245091
Pulled By: neerajprad
fbshipit-source-id: 1cd653c1d5c663346433e84fd0bbe2e590790908
(cherry picked from commit ef1da3ba465247f5777c3c40a90b96955c4281d0)
…72993)
Summary:
Maintenance of #70377
Multiple modifications of the merged initial implementation of Wishart distribution.
Key modifications:
* torch/distributions/wishart.py: Clamp (clip) floating-point values so that reciprocals are computed in a numerically stable manner, using the eps value paired with each torch.dtype.
* test/distributions/test_distributions.py: Test the Wishart distribution implementation in numerically unstable zones, i.e. df values in ndim - 1 < df < ndim, where ndim is the size of one dimension of the Wishart parameter and sample matrix.
Re-opened reverted PR #72059
cc neerajprad vadimkantorov
Pull Request resolved: #72993
Reviewed By: samdow
Differential Revision: D34853807
Pulled By: neerajprad
fbshipit-source-id: eb62dca19bf8a934fdf59b4ffc58587447fe8378
…72993)
Summary:
Maintenance of #70377
Multiple modifications of the merged initial implementation of Wishart distribution.
Key modifications:
* torch/distributions/wishart.py: Clamp (clip) floating-point values so that reciprocals are computed in a numerically stable manner, using the eps value paired with each torch.dtype.
* test/distributions/test_distributions.py: Test the Wishart distribution implementation in numerically unstable zones, i.e. df values in ndim - 1 < df < ndim, where ndim is the size of one dimension of the Wishart parameter and sample matrix.
Re-opened reverted PR #72059
cc neerajprad vadimkantorov
Pull Request resolved: #72993
Reviewed By: samdow
Differential Revision: D34853807
Pulled By: neerajprad
fbshipit-source-id: eb62dca19bf8a934fdf59b4ffc58587447fe8378
(cherry picked from commit 99240c0)
Maintenance of #70377
Multiple modifications of the merged initial implementation of Wishart distribution.
cc @neerajprad
Key modifications:
- `torch/distributions/wishart.py`: Clamp (clip) floating-point values so that reciprocals are computed in a numerically stable manner, using the `eps` value paired with each `torch.dtype`.
- `test/distributions/test_distributions.py`: Test the Wishart distribution implementation in numerically unstable zones, i.e. `df` values in `ndim - 1 < df < ndim`, where `ndim` is the size of one dimension of the Wishart parameter and sample matrix.
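One way to see why `ndim - 1 < df < ndim` is a hard zone is sketched here with the standard library only. It assumes the diagonal of the Bartlett factor uses chi-square draws with `df - i` degrees of freedom, as in the standard construction, so the last entry has fewer than one degree of freedom and its square root often lands below the float32 `eps`. The function name `last_diagonal_entries` and the chosen `df=2.1`, `ndim=3` are illustrative assumptions, not taken from the PR's tests.

```python
import math
import random

FLOAT32_EPS = 2.0 ** -23  # machine epsilon of IEEE 754 binary32 (float32)


def last_diagonal_entries(df, ndim, n_draws, rng):
    """Draw sqrt(chi2) with df - (ndim - 1) degrees of freedom, n_draws times."""
    dof = df - (ndim - 1)
    if dof <= 0:
        raise ValueError("df must be greater than ndim - 1")
    # Chi2(k) is Gamma(shape=k/2, scale=2).
    return [math.sqrt(rng.gammavariate(dof / 2.0, 2.0)) for _ in range(n_draws)]


rng = random.Random(0)
raw = last_diagonal_entries(df=2.1, ndim=3, n_draws=1000, rng=rng)

# Without a clamp, entries this small make 1 / x explode (or divide by zero).
n_below = sum(x < FLOAT32_EPS for x in raw)

# With the clamp, every reciprocal is bounded by 1 / eps = 2**23.
clamped = [max(x, FLOAT32_EPS) for x in raw]
largest_reciprocal = max(1.0 / x for x in clamped)

print(f"{n_below} of {len(raw)} raw draws fall below eps")
print(f"largest reciprocal after clamping: {largest_reciprocal:.3e}")
```

A test in this zone therefore exercises the clamp directly, which a test with `df >= ndim` would rarely do.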