Add lowering for updated _scaled_mm (fixing submodules)#130422

Closed
yangsiyu007 wants to merge 16 commits into pytorch:main from yangsiyu007:scaled-mm-lowering

Conversation

yangsiyu007 (Contributor) commented Jul 10, 2024

Add the Inductor lowering for torch._scaled_mm, whose API was last updated in #128683.

The lowering:

  • For tensor-wise scaling, auto-tunes between the default ATen kernel (cuBLAS) and Triton kernel configurations.
  • For row-wise scaling, auto-tunes between the default ATen kernel (the CUTLASS kernel added in FP8 rowwise scaling #125204) and Triton kernel configurations.

The Triton kernel template is based on htyu/FBGEMM@3ad9031 (D56337896) by @choutim (without SPLIT_K) and on the mm template in torch/_inductor/kernel/mm.py.
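Conceptually, the choice max-autotune makes between these backends is "benchmark every candidate, keep the fastest". A toy plain-Python sketch of that selection step (the `autotune` helper and the candidate names here are illustrative, not Inductor's actual API):

```python
import timeit

def autotune(candidates, *args, repeats=20):
    """Time each candidate callable on the same inputs and return the fastest.
    Inductor's max-autotune does this on-device with compiled kernels; this
    toy version just uses wall-clock timing."""
    timings = {
        name: min(timeit.repeat(lambda: fn(*args), number=1, repeat=repeats))
        for name, fn in candidates.items()
    }
    best = min(timings, key=timings.get)
    return best, timings

# Toy stand-ins for the ATen kernel and one Triton template configuration.
def aten_mm(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def triton_cfg0(a, b):
    return aten_mm(a, b)  # same math; on GPU only the runtime would differ

a, b = [[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]
best, timings = autotune({"aten": aten_mm, "triton_cfg0": triton_cfg0}, a, b)
print("selected:", best)
```

In the real lowering the candidate list is the ATen fallback plus a set of Triton template configs (block sizes, num_warps, etc.), and the winner is cached per shape.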

Testing:

  • Logging shows max-autotune tuning (AUTOTUNE scaled_mm) when _scaled_mm is compiled with either tensor-wise or row-wise scaling.

  • Row-wise scaling allows operator fusion between a preceding pointwise/reduction op and the amax/cast ops:

    • output code Evaluating m=256, n=256, k=256, fusion_case='pointwise', scaling_mode='row'
      • P1477224245 - 2 kernels
    • output code Evaluating m=2048, n=256, k=2048, fusion_case='reduction', scaling_mode='row'
      • P1477227340 - 2 kernels
  • UT python test/inductor/test_fp8.py -- TestFP8Lowering

Benchmarking

Eager/compiled tensor-wise/row-wise scaling for various shapes:
https://docs.google.com/spreadsheets/d/1VfWEVuyrwoWysfbS0_u2VHJ-PsdWkF1qIsiD60AzTes/edit?gid=2113587669#gid=2113587669

  • Some of the “compiled” cases are slightly slower than “eager”. This is because max-autotune selected the ATen kernel in the compiled case, so the same kernel runs in both; the remaining discrepancy appears to be measurement variance.

Eager/compiled tensor-wise/row-wise scaling with pointwise/reduction preceding op for various shapes:
https://docs.google.com/spreadsheets/d/1Nv07NrdffQIoDeMjo9E0V-E-EYrEN0WysO_bn1bc6ns/edit?gid=1715488446#gid=1715488446

Questions for reviewers:

  • Should the accumulator type ACC_TYPE always be float32? If not, where is this type set (in the output layout?)?

Todo:

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

pytorch-bot bot commented Jul 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130422

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit 588f69a with merge base 4c2bcf9:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@yangsiyu007 yangsiyu007 marked this pull request as ready for review July 12, 2024 16:32
vkuzo (Contributor) commented Jul 15, 2024

Should use_fast_accum be default False (as it is in aten/src/ATen/native/native_functions.yaml)? Probably should be False

Matching the setting from the user via torch._scaled_mm would be good.


@instantiate_parametrized_tests
class TestFP8Lowering(TestCase):
@unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
Contributor commented:

cc @drisspg, thoughts on how far away we are from also supporting the nuz float8 flavors here?

w_inverse_scale,
bias,
)
torch.testing.assert_close(y_eager, y_compiled, rtol=5e-1, atol=5e-1)
Contributor commented:

does this match how other matmul codegen cases in inductor are tested?

yangsiyu007 (Contributor Author) commented Jul 18, 2024:

rtol=5e-1, atol=5e-1 is used in e.g. test_valid_cast(). For reference, the Triton matmul tutorial tests fp8 with atol=0.125, rtol=0 for M, N, K = 512. Let me know if you have concerns!

vkuzo (Contributor) commented Jul 19, 2024:

the float8 test in the triton tutorial is comparing PyTorch bf16 matmul to triton float8 matmul. The comparison here is between eager mode float8 matmul and compiled float8 matmul - should that be tested with a tighter tolerance?

yangsiyu007 (Contributor Author) commented:

You're right, let me check the actual eager vs compiled difference for some shapes to see what's a good strict tolerance.

yangsiyu007 (Contributor Author) commented:

@vkuzo updated the tolerance for each after checking the actual differences for the shapes tested. I initially read 5e-1 as 1e-5 in my head which was not helpful...
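For context on what these rtol/atol pairs mean: torch.testing.assert_close accepts each element when |actual − expected| ≤ atol + rtol · |expected|. A minimal plain-Python emulation of that rule, with illustrative values (not taken from the tests):

```python
def within_tolerance(actual, expected, rtol, atol):
    # torch.testing.assert_close accepts elementwise when
    # |actual - expected| <= atol + rtol * |expected|
    return abs(actual - expected) <= atol + rtol * abs(expected)

# The original loose tolerance accepts values that differ wildly at fp8 magnitudes:
assert within_tolerance(0.50, 0.25, rtol=5e-1, atol=5e-1)
# A tight rtol with atol=0 is very strict for values near zero...
assert not within_tolerance(-0.043, -0.023, rtol=0.01, atol=0)
# ...so a small atol is what absorbs the absolute wobble around zero:
assert within_tolerance(-0.043, -0.023, rtol=0.01, atol=0.05)
```

This is why tightening rtol alone is not enough once outputs can land near zero.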

ipiszy (Contributor) left a comment:

Thanks @yangsiyu007 !

Should the type of the accumulator ACC_TYPE always be in float32? If not, where is this type set (output layout?)?

I'd suggest using fp32 as the default. Currently there are no use cases for fp16 accumulation.

Should use_fast_accum be default False (as it is in aten/src/ATen/native/native_functions.yaml)? Probably should be False

As long as there is a way to set use_fast_accum it should be okay. For inference, in most cases we just use fast_accum.


return dict(
GROUP_M=8,
EVEN_K=even_k_symbolic,
Contributor commented:

Could you confirm how this will be used? Also, what are the alignment requirements for fp8 gemm from Triton and cuBLAS? Could you add some test cases, e.g. odd K, K divisible by 2 / 4 / 8 / 16? If it's not supported we should throw an error somewhere.

yangsiyu007 (Contributor Author) commented Jul 18, 2024:

EVEN_K is used in the kernel template to load matrix blocks without masks when K is even. It is used in the mm lowering as well.
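A plain-Python sketch of why this matters: when K divides evenly into BLOCK_K-sized tiles, no tile load needs a bounds mask in the matmul inner loop (conceptual only; the actual condition lives in the Triton template):

```python
def masked_k_tiles(K, BLOCK_K):
    """Count K-dimension tiles that need a bounds mask. When K % BLOCK_K == 0
    (the EVEN_K case), every tile load is full-width and masking is skipped,
    which is cheaper inside the matmul inner loop."""
    masked = 0
    for k0 in range(0, K, BLOCK_K):
        if k0 + BLOCK_K > K:  # partial tail tile: lanes past K must be masked
            masked += 1
    return masked

assert masked_k_tiles(256, 32) == 0  # EVEN_K: no masked loads
assert masked_k_tiles(250, 32) == 1  # uneven K: the last tile needs a mask
```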

[EDITED] Alignment requirement:

Tensorwise Cublas:
- K divisible by 16
- N divisible by 8

Rowwise Cutlass (same):
- K divisible by 16
- N divisible by 8

Triton (tensorwise and rowwise):
- K >= 16
- (no restrictions on N and M)

Perf for Triton for odd K, K divisible by 16 / 8 / 4 / 2: https://docs.google.com/spreadsheets/d/1KnBFXH-4aUbUXWWVyEKNEIToz5q5cjZ6bipTs7ObUXs/edit?gid=2086012542#gid=2086012542 (internal):

  • Odd K compared to K divisible by 16: 6x slower
  • Perf worse to best: odd, divisible by 4, divisible by 2 / 8 (about the same), divisible by 16.

For the compiled case, the requirements on K and N are already checked in _meta_registrations. I'll add a test case for the error message. Added test cases for rowwise/tensorwise where M is in (1, 3, 33, 257, 1024) and K, N are 16 or a larger multiple of 32.
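A minimal sketch of the shape check described above (the helper name is hypothetical; in PyTorch the actual check lives in _meta_registrations):

```python
def check_scaled_mm_alignment(K, N):
    """Hypothetical helper mirroring the alignment rules listed above:
    the ATen fp8 kernels require K divisible by 16 and N divisible by 8."""
    if K % 16 != 0:
        raise ValueError(f"K={K} must be divisible by 16 for the ATen fp8 kernels")
    if N % 8 != 0:
        raise ValueError(f"N={N} must be divisible by 8 for the ATen fp8 kernels")

check_scaled_mm_alignment(2048, 256)      # shapes from the benchmark above: OK
try:
    check_scaled_mm_alignment(2049, 256)  # odd K is rejected
except ValueError as e:
    print(e)
```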

Contributor commented:

"Tensorwise cublas": is the alignment requirement on N real?

"Rowwise Cutlass: no requirement": I'm surprised about this. I don't really believe that the CUTLASS fp8 gemm kernel doesn't have requirements on K. Are we sure about this?

yangsiyu007 (Contributor Author) commented:

Sorry, I meant no additional requirements from CUTLASS. I've edited my comment above with more info (after commenting out the current 16-divisibility checks).

vkuzo (Contributor) commented Jul 19, 2024

Row-wise scaling allows operator fusion between preceding pointwise/reduction op and amax/cast:

That's great! Asking as an inductor n00b and not for this PR: what would it take to also support epilogue fusion for float8 gemms?

yangsiyu007 (Contributor Author) commented Jul 22, 2024

@vkuzo I should clarify that the fusion between the preceding op and the amax and cast is not considered "prologue fusion" for the scaled_mm op, which would mean that the preceding op is fused into the same kernel as scaled_mm (in P1477224245, there are still 2 kernels). AFAIK prologue fusion is not usually done for matmul ops, because it results in redundant computation of the preceding op each time a row/column is read by the matmul op.

Epilogue fusion is already done by Inductor when one of the Triton kernels (configurations of the template) is selected by max-autotune. Here's an example of sigmoid + scaled_mm + relu: P1489769109. The relu is fused to the end of the scaled_mm kernel (tmp1 = triton_helpers.maximum(tmp0, acc)), so there are still only 2 kernels in all. @ipiszy please correct me if I'm wrong.
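The fused-vs-unfused distinction can be illustrated with a toy plain-Python model (conceptual only; real fusion happens in the generated Triton code):

```python
def relu(x):
    return max(x, 0.0)  # same math as the triton_helpers.maximum epilogue

# Unfused: two "kernels" -- the matmul writes its output to memory,
# then a second pass re-reads it to apply relu.
def unfused(a, b):
    out = [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]
    return [[relu(v) for v in row] for row in out]

# Epilogue-fused: relu is applied to each accumulator before the store,
# so the matmul result never makes a round trip through memory.
def fused(a, b):
    return [[relu(sum(x * y for x, y in zip(row, col))) for col in zip(*b)] for row in a]

a, b = [[1.0, -2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]
assert fused(a, b) == unfused(a, b)
```

The saved memory round trip is exactly what epilogue fusion buys on the GPU.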

)
self.assertEqual(y_eager.dtype, dtype)
self.assertEqual(y_compiled.dtype, dtype)
torch.testing.assert_close(y_eager, y_compiled, rtol=5e-2, atol=0)
Contributor commented:

the tolerance is higher than I would have expected. If we compare, for example, the cuBLAS kernel (via torch._scaled_mm) versus the triton fp8 tutorial's matmul, is the tolerance similar?

yangsiyu007 (Contributor Author) commented:

I see one issue: I included input quantization inside the compiled function, which makes the difference greater than if only the matmul were compiled. I updated the UTs to include only the matmul in the compiled function, and set all rtol to 0.01.

I also realized that I had made the mistake of using rand() (Uniform[0, 1)) instead of randn() (Normal(0, 1)) to generate the input matrices. With randn(), some result elements are close to 0, so the relative difference can be quite large (for M, K, N = 2, 1024, 2048, one element is -0.0432 in the compiled case and -0.0234 via the cuBLAS kernel, roughly an 80% relative difference). So I've increased the atol to 0.05 or 0.07 to allow for this (see below for justification).

If we compare, for example, the cuBLAS kernel (via torch._scaled_mm) versus the triton fp8 tutorial's matmul, is the tolerance similar?

The Triton template used for the compiled case follows that of the mm kernel (with scaling and bias added); it is expressed slightly differently from the tutorial's matmul, but the approach is exactly the same. When I compared the cuBLAS kernel, the tutorial's Triton kernel, and the templated Triton kernel, I did see a larger relative difference (for elements close to 0) between the templated Triton kernel and cuBLAS for some shapes, e.g. (3, 1024, 2048) and (2, 1024, 2048). This turned out to be because certain configs (block M size, etc.) led to different results (expected, since float addition is not associative - see my post in the internal Triton group); once the tutorial's Triton kernel used the same config, both Triton kernels had the same relative difference to cuBLAS. That difference is very small compared to how far both are from the bfloat16 result. Repro script and output: P1496984171.
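To make the near-zero effect concrete, here is the arithmetic on the two values quoted above (plain Python, just reproducing the numbers from this thread):

```python
# One output element for M, K, N = 2, 1024, 2048, as quoted above:
compiled, cublas = -0.0432, -0.0234

abs_diff = abs(compiled - cublas)  # ~0.02: tiny in absolute terms
rel_diff = abs_diff / abs(cublas)  # large relative to a near-zero reference

# This is why the tests pair a tight rtol (0.01) with atol = 0.05 or 0.07:
# atol absorbs the absolute wobble that rtol cannot handle near zero.
assert abs_diff < 0.05
assert rel_diff > 0.8
print(f"abs diff {abs_diff:.4f}, rel diff {rel_diff:.1%}")
```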

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Jul 29, 2024
yangsiyu007 (Contributor Author) commented:

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot (Collaborator) commented:

Merge failed

Reason: New commits were pushed while merging. Please rerun the merge command.


yangsiyu007 (Contributor Author) commented:

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


pytorchmergebot (Collaborator) commented:

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud


Failing merge rule: Core Maintainers

yangsiyu007 (Contributor Author) commented:

@pytorchbot merge -f "Failing tests TestModuleCUDA.test_cpu_gpu_parity_nn_CTCLoss_cuda_float32 and TestModuleCPU.test_non_contiguous_tensors_nn_KLDivLoss_cpu_float32 are failing on main already."

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.


ZainRizvi (Contributor) commented:

@pytorchbot revert -c ghfirst -m "Breaks internal tests. See D60568116"

pytorchmergebot (Collaborator) commented:

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot (Collaborator) commented:

Reverting PR 130422 failed

Reason: Command git -C /home/runner/work/pytorch/pytorch revert --no-edit 882d80fd924548711e01650232c841403db280c4 returned non-zero exit code 1

CONFLICT (modify/delete): torch/_inductor/kernel/mm_scaled.py deleted in parent of 882d80fd92 (Add lowering for updated _scaled_mm (fixing submodules) (#130422)) and modified in HEAD.  Version HEAD of torch/_inductor/kernel/mm_scaled.py left in tree.
Auto-merging torch/_inductor/lowering.py
Auto-merging torch/_inductor/utils.py
error: could not revert 882d80fd92... Add lowering for updated _scaled_mm (fixing submodules) (#130422)
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git revert --continue".
hint: You can instead skip this commit with "git revert --skip".
hint: To abort and get back to the state before "git revert",
hint: run "git revert --abort".
hint: Disable this message with "git config advice.mergeConflict false"

yangsiyu007 added a commit to yangsiyu007/pytorch that referenced this pull request Aug 1, 2024
Summary:
pytorch#130422 caused the test `test.inductor.test_aot_inductor.AOTInductorTestABICompatibleCuda. test_fp8_abi_compatible_cuda` to fail (unclear why it was not run in GitHub) with `torch/csrc/inductor/aoti_torch/c/shim.h:390:34: note: candidate function not viable: requires 9 arguments, but 6 were provided`. We suspect that the kernel produced by the lowering function, which is no longer a fallback choice, has a schema issue at codegen. Fp8 is not used through AOTI currently and it is difficult to revert the PR (BE week), so we'll skip the test temporarily while making the new lowering compatible with AOTI.

Testing: the failed test on internal diff is now skipped.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

Pull Request resolved: pytorch#132453

Reviewed By: henrylhtsang

Differential Revision: D60618355

Pulled By: yangsiyu007
pytorchmergebot pushed a commit that referenced this pull request Aug 2, 2024
Pull Request resolved: #132453
Approved by: https://github.com/henrylhtsang
