faster per_token_group_quant_8bit kernel with stride support #8290
strgrb wants to merge 9 commits into sgl-project:main
Conversation
Summary of Changes
Hello @strgrb, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a significant rewrite of the per_token_group_quant_8bit CUDA kernel, leveraging the cute library for improved performance and flexibility. The primary goals are to enhance quantization capabilities by supporting float scales with ue8m0 and enabling stride-aware operations for future fusions. This new kernel aims to provide a more robust and efficient quantization solution, particularly beneficial for Hopper GPUs.
Highlights
- **New Kernel Implementation:** I've replaced the existing `per_token_group_quant_8bit` kernel with a new, faster implementation written using the `cute` CUDA library. This rewrite aims to improve efficiency and flexibility.
- **Enhanced Functionality:** This PR adds support for `float` scale data types even when `ue8m0` quantization is enabled, which is crucial for testing on Hopper GPUs. Additionally, I've introduced stride support within the kernel, paving the way for future kernel fusion optimizations.
- **Performance Improvements:** Benchmarking on Hopper GPUs shows performance gains, particularly for `float8_e4m3fn` output types, while maintaining competitive performance for `int8` outputs. This indicates a more optimized path for certain quantization scenarios.
- **Robustness and Correctness:** I've addressed a potential issue in the original kernel's scale calculation where `hidden_dim // group_size` was not divisible by 4, leading to an incorrect `row_size`. The new implementation aims for a more robust and correct scale determination.
- **Dynamic Dispatch:** A C++ dispatcher has been implemented to dynamically select the appropriate kernel based on input, output, and scale data types (Float, Half, BFloat16 for input; Float8_e4m3fn, Char for output; Float, Int for scale) and supported group sizes (32, 64, 128).
Code Review
The pull request introduces a new, faster per_token_group_quant_8bit kernel. I've identified a critical bug in the reduction loop, opportunities for code cleanup, and inconsistencies in naming and documentation. Addressing these will improve the kernel's correctness and maintainability.
@strgrb You're right, this cute implementation is really cool! Are you planning to help us fix the bug in the original kernel where `hidden_dim // group_size` is not divisible by 4?
@strgrb By the way, where does this cute implementation get its main speedup compared to the original, and do you have any NCU analysis results or similar? Thanks.
In my kernel, I avoid reloading the input, but I see L2 cache hits, so that is not the problem, since this is a compute-intensive kernel. By checking the SASS code in the NCU profile, I finally found the cause: the original kernel does not use the fp8 intrinsic for the conversion. The following is SASS code from my kernel. I guess the reason may be that the original kernel uses
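For context, a minimal sketch of what "using the fp8 intrinsic" means here. This is an assumed illustration based on CUDA's `<cuda_fp8.h>` header (available since CUDA 11.8), not code from this PR; the intrinsic converts two floats to a packed pair of e4m3 values, which should lower to a single hardware convert instruction rather than an emulation sequence.

```cuda
#include <cuda_fp8.h>

// Illustrative device helper (hypothetical name): convert two floats to a
// packed fp8x2 (e4m3) value using the hardware convert intrinsic.
__device__ void quant_pair(float lo, float hi, __nv_fp8x2_e4m3* out) {
  float2 v = make_float2(lo, hi);
  // __nv_cvt_float2_to_fp8x2 packs both values into 16 bits; with
  // __NV_SATFINITE, out-of-range inputs saturate to the fp8 max finite value.
  out->__x = __nv_cvt_float2_to_fp8x2(v, __NV_SATFINITE, __NV_E4M3);
}
```

If the original kernel instead converts through a scalar cast path, the compiler may emit a longer instruction sequence per element, which would show up exactly as the SASS difference described above.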
I have some other plans this week; I'll take a look next week.
@strgrb Hi, I want to ask a simple question because I have never read the original kernel code. What will happen when
@HydraQYH According to

Motivation
I want some new features in per_token_group_quant_8bit for future optimizations:
So I decided to rewrite it with cute, and got some benefits.
The following benchmark was run on an H20:
In this case the results differ, but I believe my implementation is correct:
If `hidden_dim // group_size` is not divisible by 4, the `row_size` calculated here is incorrect.

sglang/sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu
Line 56 in 8abd3e7
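To illustrate the kind of rounding error being described: the exact expression from line 56 is not shown here, so the arithmetic below is a hypothetical reconstruction. If the scale row size is computed in vectorized chunks of 4, floor division silently drops the tail groups whenever the number of groups per row is not a multiple of 4.

```python
# Hypothetical reconstruction of the rounding issue: computing the scale
# row size in chunks of 4 loses the remainder groups.
def row_size_vectorized(hidden_dim: int, group_size: int, vec: int = 4) -> int:
    groups = hidden_dim // group_size
    return (groups // vec) * vec  # floor-divides away any tail groups

def row_size_correct(hidden_dim: int, group_size: int) -> int:
    return hidden_dim // group_size

# 7168 // 128 = 56 groups: both agree. 1152 // 128 = 9 groups: they differ.
print(row_size_vectorized(1152, 128), row_size_correct(1152, 128))
```

Under this reading, any hidden dimension whose group count is a multiple of 4 masks the bug, which would explain why it survived testing on common shapes.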
Modifications
A new kernel written with cute.
Checklist