Add lowering for updated _scaled_mm (fixing submodules) #130422
yangsiyu007 wants to merge 16 commits into pytorch:main
Conversation
🔗 Helpful Links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/130422
Note: links to docs will display an error until the docs builds have been completed. ✅ You can merge normally (3 unrelated failures). As of commit 588f69a with merge base 4c2bcf9. BROKEN TRUNK: the following jobs failed but were also failing on the merge base; 👉 rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Matching the setting from the user via
```python
@instantiate_parametrized_tests
class TestFP8Lowering(TestCase):
    @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
```
cc @drisspg , thoughts on how far away we are from also supporting the nuz float8 flavors here?
test/inductor/test_fp8.py (Outdated)
```python
                w_inverse_scale,
                bias,
            )
            torch.testing.assert_close(y_eager, y_compiled, rtol=5e-1, atol=5e-1)
```
does this match how other matmul codegen cases in inductor are tested?
rtol=5e-1, atol=5e-1 is used in e.g. test_valid_cast(). As a reference, the Triton matmul tutorial tests fp8 with atol=0.125, rtol=0 for M, N, K = 512. Let me know if you have concerns!
the float8 test in the triton tutorial is comparing PyTorch bf16 matmul to triton float8 matmul. The comparison here is between eager mode float8 matmul and compiled float8 matmul - should that be tested with a tighter tolerance?
You're right, let me check the actual eager vs compiled difference for some shapes to see what's a good strict tolerance.
@vkuzo updated the tolerance for each after checking the actual differences for the shapes tested. I initially read 5e-1 as 1e-5 in my head which was not helpful...
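For picking tolerances empirically, a small helper along these lines can report the worst-case absolute and relative error between eager and compiled outputs. This is only an illustrative sketch; the helper name and the sample values are hypothetical, not taken from the actual tests.

```python
def max_errors(reference, candidate):
    """Return (max absolute error, max relative error) between two
    flat lists of floats, skipping relative error where the reference
    is exactly zero."""
    abs_errs = [abs(r - c) for r, c in zip(reference, candidate)]
    rel_errs = [abs(r - c) / abs(r) for r, c in zip(reference, candidate) if r != 0.0]
    return max(abs_errs), max(rel_errs) if rel_errs else 0.0

# Illustrative values standing in for eager vs compiled fp8 matmul outputs.
eager = [1.00, -0.0432, 2.50]
compiled = [1.01, -0.0234, 2.48]
max_abs, max_rel = max_errors(eager, compiled)
```

Running this over the shapes under test gives concrete numbers from which to choose a strict rtol/atol pair, rather than guessing.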
ipiszy left a comment:
Thanks @yangsiyu007!

> Should the type of the accumulator ACC_TYPE always be in float32? If not, where is this type set (output layout?)?

I'd suggest using fp32 as the default. There are currently no use cases for fp16 accumulation.

> Should use_fast_accum default to False (as it does in aten/src/ATen/native/native_functions.yaml)?

Probably it should be False. As long as there is a way to set use_fast_accum it should be okay. For inference, in most cases we just use fast_accum.
```python
        return dict(
            GROUP_M=8,
            EVEN_K=even_k_symbolic,
```
Could you help confirm how this will be used? Also, what's the alignment requirement for fp8 gemm from Triton and cuBLAS? Could you add some test cases, e.g. odd K, K divisible by 2 / 4 / 8 / 16? If it's not supported we should throw an error somewhere.
EVEN_K is used in the kernel template to load matrix blocks without masks when K is evenly divisible by the block size. It is used in mm's lowering as well.
[EDITED] Alignment requirement:
Tensorwise Cublas:
- K divisible by 16
- N divisible by 8
Rowwise Cutlass (same):
- K divisible by 16
- N divisible by 8
Triton (tensorwise and rowwise):
- K >= 16
- (no restrictions on N and M)
Perf for Triton for odd K, K divisible by 16 / 8 / 4 / 2: https://docs.google.com/spreadsheets/d/1KnBFXH-4aUbUXWWVyEKNEIToz5q5cjZ6bipTs7ObUXs/edit?gid=2086012542#gid=2086012542 (internal):
- Odd K compared to K divisible by 16: 6x slower.
- Perf from worst to best: odd; divisible by 4; divisible by 2 / 8 (about the same); divisible by 16.
For the compiled case, requirement for K and N is checked in _meta_registrations already. I'll add a test case to test the error message. Added test cases for rowwise/tensorwise for when M is (1, 3, 33, 257, 1024), and K, N are 16 or a larger multiple of 32.
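For reference, the shape constraints described above can be expressed as a small guard like the following. This is a hypothetical sketch mirroring the divisibility checks mentioned for _meta_registrations, not the actual code; the function name is made up.

```python
def check_scaled_mm_shapes(m, k, n):
    """Illustrative shape guard for tensor-wise/row-wise fp8 gemm:
    K must be divisible by 16 and N by 8 (cuBLAS/CUTLASS backends);
    Triton itself only needs K >= 16. Raises ValueError on violation."""
    if k % 16 != 0:
        raise ValueError(f"Expected K to be divisible by 16 but got K={k}")
    if n % 8 != 0:
        raise ValueError(f"Expected N to be divisible by 8 but got N={n}")
    return True

check_scaled_mm_shapes(33, 1024, 2048)     # OK: K and N aligned, any M
try:
    check_scaled_mm_shapes(2, 1000, 2048)  # K=1000 is not divisible by 16
except ValueError as e:
    caught = str(e)
```

A test case for the error message would exercise the same path through torch.compile.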
"Tensorwise cublas": is the alignment requirement on N real?
"Rowwise Cutlass: no requirement": I'm surprised about this. I don't really believe that the CUTLASS fp8 gemm kernel doesn't have requirements on K. Are we sure about this?
Sorry I meant no additional requirements from Cutlass. Edited my comment above with more info (commenting out current 16-divisibility checks).
That's great! Asking as an inductor n00b and not for this PR: what would it take to also support epilogue fusion for float8 gemms?
@vkuzo I should clarify that the fusion between the preceding op and the amax and cast is not considered "prologue fusion" for the scaled_mm op, which would mean the preceding op is fused into the same kernel as scaled_mm (in P1477224245, there are still 2 kernels). AFAIK prologue fusion is not usually done for matmul ops, because it results in redundant computation of the preceding op each time a row/column is read by the matmul op. Epilogue fusion is already done by Inductor when one of the Triton kernels (configurations of the template) is selected by max-autotune. Here's an example of sigmoid + scaled_mm + relu: P1489769109. The relu is fused to the end of the scaled_mm kernel.
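To illustrate the idea in plain Python (a conceptual sketch only, unrelated to the actual Triton codegen): epilogue fusion means the pointwise op is applied to each output element as the matmul produces it, instead of in a second pass over the output.

```python
def matmul_with_epilogue(a, b, epilogue=lambda x: x):
    """Naive matmul over lists of lists; `epilogue` (e.g. relu) is
    applied to each accumulated element before it is written out,
    mimicking a fused epilogue instead of a separate kernel."""
    m, k, n = len(a), len(b), len(b[0])
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                acc += a[i][p] * b[p][j]
            out[i][j] = epilogue(acc)  # fused: no second pass over `out`
    return out

relu = lambda x: max(x, 0.0)
a = [[1.0, -2.0], [3.0, 4.0]]
b = [[1.0, 0.0], [0.0, 1.0]]  # identity, so the result is relu(a)
fused = matmul_with_epilogue(a, b, relu)
```

The fused version writes each output element exactly once; a separate relu kernel would re-read and re-write the whole output buffer.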
test/inductor/test_fp8.py (Outdated)
```python
            )
            self.assertEqual(y_eager.dtype, dtype)
            self.assertEqual(y_compiled.dtype, dtype)
            torch.testing.assert_close(y_eager, y_compiled, rtol=5e-2, atol=0)
```
the tolerance is higher than I would have expected. If we compare, for example, the cuBLAS kernel (via torch._scaled_mm) versus the triton fp8 tutorial's matmul, is the tolerance similar?
I see one issue is that I include input quantization inside the compiled function, making the difference greater than if only the matmul is compiled. Updated the UTs to only include matmul in the compiled function, and all rtol to 0.01.
I also realized that I made the mistake of using rand() (Uniform[0, 1)) instead of randn() (Normal(0, 1)) to generate input matrices. Using randn(), some result elements are close to 0, so the relative difference can be quite large (for M, K, N = 2, 1024, 2048, one element is -0.0432 in the compiled case and -0.0234 via the cuBLAS kernel, 80% relative difference). So I've increased the atol to 0.05 or 0.07 to allow for this (see below on justification).
If we compare, for example, the cuBLAS kernel (via torch._scaled_mm) versus the triton fp8 tutorial's matmul, is the tolerance similar?
The Triton template used for the compiled case follows that of the mm kernel (with scaling and bias added), which is expressed slightly differently than the tutorial's matmul, but the approach is exactly the same. When I compared the cuBLAS kernel, the tutorial's Triton kernel, and the templated Triton kernel, I did see a larger relative difference (for elements close to 0) between the templated Triton and cuBLAS for some shapes, e.g. (3, 1024, 2048), (2, 1024, 2048). This turned out to be because certain configs (block M size etc.) led to different results (expected, as float addition is not associative - see my post in the internal Triton group), and once we used the same config for the tutorial's Triton kernel, both Triton kernels had the same relative difference to cuBLAS. The difference is very small compared to how imprecise they are relative to the bfloat16 result. Repro script and output: P1496984171.
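The effect of near-zero elements on relative tolerance can be seen with the two values quoted above (-0.0432 vs -0.0234); an atol term absorbs the blow-up. A sketch using the standard allclose-style criterion |a - b| <= atol + rtol * |b| (the helper name here is illustrative):

```python
def allclose_elem(a, b, rtol, atol):
    """Single-element version of the usual allclose criterion."""
    return abs(a - b) <= atol + rtol * abs(b)

# The two near-zero elements quoted in the discussion above.
a, b = -0.0432, -0.0234
rel_diff = abs(a - b) / abs(b)                       # large despite tiny abs diff
tight = allclose_elem(a, b, rtol=0.01, atol=0.0)     # rtol alone rejects it
loose = allclose_elem(a, b, rtol=0.01, atol=0.05)    # atol=0.05 absorbs it
```

This is why the UTs pair rtol=0.01 with atol=0.05 or 0.07 instead of using a pure relative tolerance.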
… result) and catch NoValidChoicesError
@pytorchbot merge

Merge started: your change will be merged once all checks pass (ETA 0-4 hours).

Merge failed. Reason: New commits were pushed while merging. Please rerun the merge command.
@pytorchbot merge

Merge started: your change will be merged once all checks pass (ETA 0-4 hours).

Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
Needed to sync with origin and run `with-proxy lintrunner -m origin/main` to reproduce.
@pytorchbot merge -f "Failing tests TestModuleCUDA.test_cpu_gpu_parity_nn_CTCLoss_cuda_float32 and TestModuleCPU.test_non_contiguous_tensors_nn_KLDivLoss_cpu_float32 are failing on main already."

Merge started: your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).
@pytorchbot revert -c ghfirst -m "Breaks internal tests. See D60568116"

@pytorchbot successfully started a revert job. Check the current status here.

Reverting PR 130422 failed.
Summary: pytorch#130422 caused the test `test.inductor.test_aot_inductor.AOTInductorTestABICompatibleCuda.test_fp8_abi_compatible_cuda` to fail (unclear why it was not run on GitHub) with `torch/csrc/inductor/aoti_torch/c/shim.h:390:34: note: candidate function not viable: requires 9 arguments, but 6 were provided`. We suspect that the kernel produced by the lowering function, which is no longer a fallback choice, has a schema issue at codegen. FP8 is not used through AOTI currently and it is difficult to revert the PR (BE week), so we'll skip the test temporarily while making the new lowering compatible with AOTI.

Testing: the failed test on the internal diff is now skipped.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

Pull Request resolved: pytorch#132453
Reviewed By: henrylhtsang
Differential Revision: D60618355
Pulled By: yangsiyu007
Approved by: https://github.com/henrylhtsang
Add the Inductor lowering for torch._scaled_mm, whose API was last updated in #128683. The lowering does:

The Triton kernel template is based on htyu/FBGEMM@3ad9031 (D56337896) by @choutim, without using SPLIT_K, and on that of mm (torch/_inductor/kernel/mm.py).

Testing:
- Logging shows max-autotune tuning (AUTOTUNE scaled_mm) for both tensor-wise and row-wise scaling when called with the two scaling types. Row-wise scaling allows operator fusion between a preceding pointwise/reduction op and the amax/cast.
- UT: python test/inductor/test_fp8.py -- TestFP8Lowering

Benchmarking:
- Eager/compiled tensor-wise/row-wise scaling for various shapes: https://docs.google.com/spreadsheets/d/1VfWEVuyrwoWysfbS0_u2VHJ-PsdWkF1qIsiD60AzTes/edit?gid=2113587669#gid=2113587669
- Eager/compiled tensor-wise/row-wise scaling with a pointwise/reduction preceding op for various shapes: https://docs.google.com/spreadsheets/d/1Nv07NrdffQIoDeMjo9E0V-E-EYrEN0WysO_bn1bc6ns/edit?gid=1715488446#gid=1715488446

Questions for reviewers:
- Should the type of the accumulator ACC_TYPE always be in float32? If not, where is this type set (output layout?)?

Todo:
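For readers unfamiliar with the two scaling modes, here is a plain-Python sketch of what the lowering computes (illustrative only; the real kernel does this inside the Triton template on fp8 inputs with an fp32 accumulator). Tensor-wise scaling multiplies the whole product by scalar inverse scales, while row-wise scaling applies a per-row scale of A and a per-column scale of B.

```python
def scaled_matmul(a, b, a_inv_scale, b_inv_scale):
    """Naive scaled matmul over lists of lists. Scales may be floats
    (tensor-wise) or per-row / per-column lists (row-wise)."""
    m, k, n = len(a), len(b), len(b[0])
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = sum(a[i][p] * b[p][j] for p in range(k))  # fp32-style accumulate
            sa = a_inv_scale[i] if isinstance(a_inv_scale, list) else a_inv_scale
            sb = b_inv_scale[j] if isinstance(b_inv_scale, list) else b_inv_scale
            out[i][j] = acc * sa * sb
    return out

a = [[2.0, 0.0], [0.0, 2.0]]
b = [[1.0, 1.0], [1.0, 1.0]]
tensorwise = scaled_matmul(a, b, 0.5, 0.5)              # scalar inverse scales
rowwise = scaled_matmul(a, b, [1.0, 0.5], [1.0, 0.5])   # per-row / per-column
```

In the row-wise case each output element (i, j) is scaled by a_inv_scale[i] * b_inv_scale[j], which is what makes per-row amax/cast fusable with a preceding op.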
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang