[Quantization][Perf] add triton w8a8 int8 gemm kernel #10502
ZelinMa557 wants to merge 8 commits into sgl-project:main
Conversation
Summary of Changes
Hello @ZelinMa557, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly enhances the performance of int8 quantization within the system by integrating a new, highly optimized Triton kernel for weight-8-bit, activation-8-bit (w8a8) General Matrix Multiply (GEMM) operations. This specialized kernel is designed to leverage the capabilities of modern NVIDIA GPUs (sm89 architecture) and provides substantial speed improvements, particularly for larger batch sizes, making quantized model inference more efficient.
Highlights
- New Triton Kernel: Introduced a Triton-based w8a8 int8 GEMM kernel for improved performance in quantized operations.
- Performance Boost: Achieves up to 50% speedup over the existing sgl-kernel for matrix sizes where M (batch size) is greater than or equal to 1024.
- Hardware Specific Optimization: The new kernel is specifically enabled for NVIDIA GPUs with compute capability sm89 (e.g., RTX 4090) to leverage their architectural advantages.
- Accuracy Verified: Includes new unit tests to ensure the accuracy of the implemented int8 GEMM kernel against a native PyTorch reference.
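For readers less familiar with w8a8 quantization, the computation such a kernel performs is roughly the following PyTorch-level reference. The function name and the per-token activation / per-channel weight scaling granularity are assumptions for illustration, not the PR's exact code:

```python
import torch


def w8a8_int8_gemm_ref(a_int8, b_int8, a_scale, b_scale, bias=None,
                       out_dtype=torch.bfloat16):
    """Reference w8a8 GEMM: int8 activations x int8 weights, then dequantize.

    a_int8: [M, K] int8, a_scale: [M] float32 (assumed per-token scales)
    b_int8: [K, N] int8, b_scale: [N] float32 (assumed per-channel scales)
    """
    # Emulate the int8 x int8 -> int32 accumulation in float32; close enough
    # to serve as an accuracy reference for typical shapes.
    acc = torch.matmul(a_int8.to(torch.float32), b_int8.to(torch.float32))
    # Dequantize the accumulator with the activation and weight scales.
    out = acc * a_scale[:, None] * b_scale[None, :]
    if bias is not None:
        out = out + bias.to(torch.float32)
    return out.to(out_dtype)
```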
Code Review
This pull request introduces a new Triton kernel for w8a8 int8 GEMM, aimed at improving performance on sm89 architectures. The benchmarks provided demonstrate a significant speedup for larger matrix sizes, which is a great addition. The implementation is well-structured, and the inclusion of a new unit test ensures the correctness of the new kernel. My review includes a few suggestions to enhance code clarity, maintainability, and the robustness of the tests.
```python
if C.dtype.element_ty == tl.bfloat16:
    c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16:
    c = accumulator.to(tl.float16)
```
```python
        triton.cdiv(M, META["BLOCK_SIZE_M"]) *
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    bias_ptr = bias if bias is not None else B
```
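The second hunk above is the host-side launch path: a 1D grid sized over the output tiles, plus a dummy pointer for the optional bias. A minimal sketch of that pattern follows; the kernel name `_w8a8_int8_gemm_kernel` and the `HAS_BIAS` constexpr flag are hypothetical, not necessarily what this PR uses:

```python
import torch
import triton


def int8_gemm(a, b, a_scale, b_scale, out_dtype=torch.bfloat16, bias=None):
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=out_dtype)

    # One Triton program per (BLOCK_SIZE_M x BLOCK_SIZE_N) output tile,
    # flattened into a 1D launch grid.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )

    # Triton kernel arguments cannot be None, so pass any valid pointer when
    # there is no bias and let a constexpr flag tell the kernel to skip it.
    bias_ptr = bias if bias is not None else b
    _w8a8_int8_gemm_kernel[grid](  # hypothetical kernel name
        a, b, c, a_scale, b_scale, bias_ptr,
        M, N, K,
        HAS_BIAS=bias is not None,  # hypothetical constexpr flag
    )
    return c
```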
```python
self.assertTrue(
    torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
    / torch.mean(torch.abs(ref_out.to(torch.float32)))
    < 0.05
)
```
The current assertion checks the mean relative error, which might not catch large discrepancies in a small number of elements. Using torch.allclose provides a more robust, element-wise comparison and is generally preferred for tensor comparisons in tests. This will ensure that all elements in the output tensor are within an acceptable tolerance.
```python
self.assertTrue(torch.allclose(out, ref_out, rtol=0.05, atol=1e-1))
```

Signed-off-by: ZelinMa557 <3388706467@qq.com>

fix lint
HydraQYH
left a comment
Great job! Could you please reply to my comment? If it is convenient for you, could you provide a more detailed performance analysis report (e.g., ncu) to confirm the source of the performance improvement?
Hi, thank you for your patient code review! @HydraQYH I will attach some ncu profile results tomorrow.
Signed-off-by: ZelinMa557 <3388706467@qq.com>
This is the ncu perf result; ncu indicates that the L2 access pattern of the CUTLASS kernel might be sub-optimal. I attached the ncu perf file here. Also, I have refactored the test code; could you take a look? @HydraQYH
@ZelinMa557 I analyzed the ncu report and found that the CUTLASS-based kernel used inefficient IMMA instructions.
Hi, I saw the performance report in the new PR; the performance boost of the fixed CUTLASS kernel is higher than that of the Triton one, so I think there is no need to benchmark the Triton kernel again. @HydraQYH However, maybe we can keep this Triton kernel for other platforms, such as HIP?
@ZelinMa557 This is a good idea, but you need to adapt your code so that it is used only on the HIP platform.
HydraQYH
left a comment
Adapting to the HIP platform requires addressing the following comments.
Thanks, I have updated the code. Do we need to re-tune the configs of this Triton kernel on HIP devices?
@HaiShaw @saienduri could you help take a look? Thanks.
Signed-off-by: ZelinMa557 <3388706467@qq.com>
@HaiShaw Hi, can you take a look at this PR?



Motivation
At first I added this Triton kernel for NVIDIA sm89 GPUs because the performance of the CUTLASS kernel was poor; @HydraQYH then helped improve the CUTLASS kernel's performance.
I think this kernel can still be kept for the HIP platform, since sglang does not support w8a8 int8 on HIP yet.
Modifications
Add a Triton w8a8 int8 GEMM kernel, and execute it only on the HIP platform (see the dispatch sketch below).
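A rough sketch of the kind of platform gating this implies; the dispatch point and the callee names `triton_int8_gemm` / `cutlass_int8_scaled_mm` are illustrative, not the PR's actual code:

```python
import torch


def _is_hip() -> bool:
    # torch.version.hip is set on ROCm builds of PyTorch and is None elsewhere.
    return torch.version.hip is not None


def int8_scaled_mm_dispatch(a, b, a_scale, b_scale, out_dtype, bias=None):
    if _is_hip():
        # hypothetical name for the Triton kernel wrapper added in this PR
        return triton_int8_gemm(a, b, a_scale, b_scale, out_dtype, bias)
    # hypothetical name for the existing CUTLASS-backed sgl-kernel path
    return cutlass_int8_scaled_mm(a, b, a_scale, b_scale, out_dtype, bias)
```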
Accuracy Tests
I added a unit test in `test/srt/quant/test_int8_kernel.py`.
Benchmarking and Profiling
Benchmark with the following script:
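The original script is not reproduced above; a minimal sketch of this kind of comparison, with assumed shapes and stand-in function names (`triton_int8_gemm` for this PR's kernel, `cutlass_int8_scaled_mm` for the existing sgl-kernel path), might look like:

```python
import torch
import triton


def bench_one(M, N, K, out_dtype=torch.bfloat16):
    # Random int8 inputs and float32 scales on the GPU.
    a = torch.randint(-128, 128, (M, K), dtype=torch.int8, device="cuda")
    b = torch.randint(-128, 128, (K, N), dtype=torch.int8, device="cuda")
    a_scale = torch.rand(M, dtype=torch.float32, device="cuda")
    b_scale = torch.rand(N, dtype=torch.float32, device="cuda")

    # triton.testing.do_bench returns the measured time in milliseconds.
    t_triton = triton.testing.do_bench(
        lambda: triton_int8_gemm(a, b, a_scale, b_scale, out_dtype)
    )
    t_cutlass = triton.testing.do_bench(
        lambda: cutlass_int8_scaled_mm(a, b, a_scale, b_scale, out_dtype)
    )
    print(f"M={M:5d} N={N} K={K}: triton {t_triton:.3f} ms, cutlass {t_cutlass:.3f} ms")


for m in (16, 64, 256, 1024, 4096):
    bench_one(m, N=4096, K=4096)
```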
Benchmark result:
Checklist