[MPS] Adding lgamma, digamma, and polygamma implementations #106292

igm503 wants to merge 14 commits into pytorch:main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106292

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 8 Unrelated Failures

As of commit cc34d98 with merge base 703cdd7:

NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.
@kulinseth Any chance you can give this a look and advise on whether the test failures are a problem?

This issue seems unrelated to the PR. @igm503, can you please rebase the PR?
048b88a to 3938014
@igm503 the assertion is coming from the not-implemented test. Can you check whether the lgamma tests are in that category class in test_mps?

@kulinseth I've fixed the assertion error by swapping another not-yet-implemented op for lgamma in the not_implemented test.
@kulinseth So, at least as I'm typing this, the test errors are now those that I mentioned in the pull request body: in some cases, they're precision issues, but in other cases, I think the cpu implementation is incorrect.
The tests now pass on the macos 13 builds. @kulinseth However, since there are precision issues with test_output_grad_match_polygamma_polygamma_n_0_cpu_float32 on macos 12 as well, where should I put that exception? I scanned the different XFAILLISTs, and I don't see a clear place for it. Of course, I could put it in the pre-13 XFAIL list, but that would make it seem like it's fixed for >13, which it isn't.
…unction had been made static elsewhere
…or test_error_on_not_implemented
@kulinseth I went ahead and added the failing tests to the MACOS_BEFORE_13_3_XFAILLIST as well. Let me know if there's a more appropriate place to put them.
@pytorchbot merge -i
Merge started
Your change will be merged while ignoring the following 6 checks:
- pull / linux-focal-py3.8-clang10 / test (default, 2, 3, linux.2xlarge)
- pull / linux-jammy-py3.9-clang12-asan / test (default, 6, 6, linux.4xlarge)
- pull / linux-jammy-py3.8-gcc11 / test (default, 2, 3, linux.2xlarge)
- pull / linux-focal-py3.11-clang10 / test (default, 3, 3, linux.2xlarge)
- pull / linux-focal-cuda12.1-py3.10-gcc9-sm86 / test (default, 4, 5, linux.g5.4xlarge.nvidia.gpu, unstable)
- pull / linux-focal-cuda12.1-py3.10-gcc9 / test (default, 1, 5, linux.4xlarge.nvidia.gpu)

Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Merge failed
Reason: 1 job has failed; the first few are: trunk / macos-12-py3-arm64 / test (default, 3, 3, macos-m1-12)
Details for Dev Infra team: raised by workflow job
@pytorchbot merge -i
Fixes issue mentioned in #77764
e.g. #77764 (comment)
Adds MPS support for the following ops:
- lgamma
- digamma
- polygamma
The lgamma function does not yet have an MPS backend implementation. I've added one using a custom Metal kernel (following John D. Cook's C++ implementation of the log gamma function: https://www.johndcook.com/blog/cpp_gamma/). For the backward pass op, I've added a digamma kernel that follows the cpu+cuda digamma implementation, and for the backward pass of the digamma op, I've added a polygamma + trigamma kernel following, again, the cpu+cuda implementations.
NOTE:
The cpu implementation of the polygamma function incorrectly (as far as I can tell) outputs a finite number for order = 1 and x in the negative integers. The mps implementation correctly outputs infinity. (see #106692)
The polygamma tests currently don't pass because of the error in the cpu+cuda kernels, but also because there are smallish discrepancies near the negative integers between the cpu+cuda and the mps polygamma and trigamma kernels. I'm not sure exactly why this is, but let me know if the discrepancies are too big.