
faster per_token_group_quant_8bit kernel with stride support #8290

Closed
strgrb wants to merge 9 commits into sgl-project:main from strgrb:quant

Conversation

@strgrb
Collaborator

@strgrb strgrb commented Jul 23, 2025

Motivation

I want some new features in per_token_group_quant_8bit for future optimizations:

  • support scale.dtype=float even when ue8m0 is on, to test ue8m0 quant on Hopper GPUs
  • support strides for future fusion (see the sketch below)

So I decided to rewrite it with cute, and got some benefits.
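A minimal sketch of the kind of strided call this is meant to enable (my illustration, assuming the Python wrapper forwards non-contiguous input views to the kernel):

import torch
from sgl_kernel.gemm import sgl_per_token_group_quant_8bit

# Quantize a strided view of a larger (fused) buffer without an extra .contiguous() copy.
big = torch.randn([4, 256], device='cuda', dtype=torch.bfloat16)
x = big[:, :128]  # non-contiguous view, row stride = 256 elements
q = torch.empty([4, 128], device='cuda', dtype=torch.float8_e4m3fn)
s = torch.empty([4, 1], device='cuda', dtype=torch.float32)
sgl_per_token_group_quant_8bit(x, q, s, 128, 1e-10, -448, 448, False)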

The following benchmark was run on an H20:

per-token-group-quant-8bit-performance:
   batch_size  seq_len  group_size            dst_dtype  SGL Kernel  new kernel
0           1     1024         128           torch.int8   36.800001    36.063999
1           1     1024         128  torch.float8_e4m3fn   37.983999    32.992002
2           1     4096         128           torch.int8  117.183998   117.311999
3           1     4096         128  torch.float8_e4m3fn  128.191993   113.407999
4           1     8192         128           torch.int8  223.488003   223.391995
5           1     8192         128  torch.float8_e4m3fn  244.448006   218.367994
6           1    16384         128           torch.int8  435.167998   435.615987
7           1    16384         128  torch.float8_e4m3fn  477.600008   426.815987

In the following case the results differ, but I think my implementation is correct:

import torch
from sgl_kernel.gemm import sgl_per_token_group_quant_8bit as per_token_group_quant_with_prologue
from sgl_kernel import sgl_per_token_group_quant_fp8

a = torch.randn([3, 128], device='cuda', dtype=torch.bfloat16)
group_size = 128

def allocate_qs(input, q_dtype, aligned=False, scale_ue8m0=False):
  q = torch.empty_like(input, device='cuda', dtype=q_dtype)
  m, n = input.shape
  if scale_ue8m0:
    # Pad both dims to multiples of 4 and pack 4 ue8m0 scales into one int32 along the hidden dim.
    am = (m + 3) // 4 * 4
    an = (n // group_size + 3) // 4 * 4
    s = torch.zeros([an // 4, am], device='cuda', dtype=torch.int).T[:m, :]
  elif aligned:
    # Pad the token dim to a multiple of 4 so each scale column starts 16 B aligned.
    am = (m + 3) // 4 * 4
    s = torch.empty([n // group_size, am], device='cuda', dtype=torch.float32).T[:m, :]
  else:
    s = torch.empty([m, n // group_size], device='cuda', dtype=torch.float32)
  return q, s

def my_quant(q_dtype, aligned, scale_ue8m0):
  q, s = allocate_qs(a, q_dtype, aligned, scale_ue8m0)
  per_token_group_quant_with_prologue(a, q, s, group_size, 1e-10, -448, 448, scale_ue8m0)
  return q, s

def sgl_quant(q_dtype, aligned, scale_ue8m0):
  q, s = allocate_qs(a, q_dtype, aligned, scale_ue8m0)
  sgl_per_token_group_quant_fp8(a, q, s, group_size, 1e-10, -448, 448, scale_ue8m0=scale_ue8m0)
  return q, s

q_dtype = torch.int8
aligned = True
scale_ue8m0 = True
q1, s1 = my_quant(q_dtype, aligned, scale_ue8m0)
q2, s2 = sgl_quant(q_dtype, aligned, scale_ue8m0)
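For reference, a quick check to see whether the two results agree (with this shape, n // group_size == 1, which is not divisible by 4):

print(torch.equal(q1, q2))  # quantized values
print(torch.equal(s1, s2))  # packed ue8m0 scales (int32)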

If hidden_dim // group_size is not divisible by 4, the row_size calculated here is incorrect:

const int scale_num_rows_element = scale_num_rows * num_elems_per_pack;
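A rough illustration of the mismatch (the names follow the quoted line; the concrete numbers, and the assumption that scale_num_rows is the packed, padded count, are mine):

# Assumed example: hidden_dim = 128, group_size = 128, ue8m0 scales packed 4-per-int32.
real_cols = 128 // 128                    # 1 scale group per token
num_elems_per_pack = 4
scale_num_rows = (real_cols + 3) // 4     # packed and padded count = 1
scale_num_rows_element = scale_num_rows * num_elems_per_pack  # 4, but only 1 column is real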

Modifications

A new kernel written with cute.

Checklist

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Summary of Changes

Hello @strgrb, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant rewrite of the per_token_group_quant_8bit CUDA kernel, leveraging the cute library for improved performance and flexibility. The primary goals are to enhance quantization capabilities by supporting float scales with ue8m0 and enabling stride-aware operations for future fusions. This new kernel aims to provide a more robust and efficient quantization solution, particularly beneficial for Hopper GPUs.

Highlights

  • New Kernel Implementation: I've replaced the existing per_token_group_quant_8bit kernel with a new, faster implementation written using the cute CUDA library. This rewrite aims to improve efficiency and flexibility.
  • Enhanced Functionality: This PR adds support for float scale data types even when ue8m0 quantization is enabled, which is crucial for testing on Hopper GPUs. Additionally, I've introduced stride support within the kernel, paving the way for future kernel fusion optimizations.
  • Performance Improvements: Benchmarking on Hopper GPUs shows performance gains, particularly for float8_e4m3fn output types, while maintaining competitive performance for int8 outputs. This indicates a more optimized path for certain quantization scenarios.
  • Robustness and Correctness: I've addressed a potential issue in the original kernel's scale calculation where hidden_dim // group_size was not divisible by 4, leading to incorrect row_size. The new implementation aims for a more robust and correct scale determination.
  • Dynamic Dispatch: A C++ dispatcher has been implemented to dynamically select the appropriate kernel based on input, output, and scale data types (Float, Half, BFloat16 for input; Float8_e4m3fn, Char for output; Float, Int for scale) and supported group sizes (32, 64, 128).

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

The pull request introduces a new, faster per_token_group_quant_8bit kernel. I've identified a critical bug in the reduction loop, opportunities for code cleanup, and inconsistencies in naming and documentation. Addressing these will improve the kernel's correctness and maintainability.

@BBuf
Collaborator

BBuf commented Jul 23, 2025

@strgrb You're right, this cute implementation is really cool! Are you planning to help us fix the bug in the original kernel where hidden_dim // group_size isn't divisible by 4?

@BBuf
Collaborator

BBuf commented Jul 23, 2025

@strgrb By the way, where does this cute implementation get its main speedup compared to the original, and do you have any NCU analysis results or similar? Thanks.

@strgrb
Collaborator Author

strgrb commented Jul 24, 2025

@strgrb By the way, where does this cute implementation get its main speedup compared to the original, and do you have any NCU analysis results or similar? Thanks.

In my kernel, I avoid reloading the input, but I see L2 cache hits, so that is not the problem, since it's a compute-intensive kernel. By checking the SASS code in the NCU profile, I finally found the cause: the original kernel does not use the fp8 intrinsic for the conversion. The following is SASS code from my kernel.
[screenshot: SASS code from the new kernel]

I guess the reason may be that the original kernel uses c10::Float8_e4m3, not __nv_fp8_e4m3?

@strgrb
Collaborator Author

strgrb commented Jul 24, 2025

@strgrb You're right, this cute implementation is really cool! Are you planning to help us fix the bug in the original kernel where hidden_dim // group_size isn't divisible by 4?

I have some other plans this week; I'll take a look next week.

@HydraQYH
Collaborator

HydraQYH commented Jul 24, 2025

@strgrb You're right, this cute implementation is really cool! Are you planning to help us fix the bug in the original kernel where hidden_dim // group_size isn't divisible by 4?

I have some other plans this week; I'll take a look next week.

@strgrb Hi, I want to ask a simple question because I have never read the original kernel code. What will happen when hidden_dim // group_size isn't divisible by 4?
I'm guessing it's because of some additional alignment requirement, but I'm not sure.

@strgrb
Collaborator Author

strgrb commented Jul 24, 2025

@strgrb You're right, this cute implementation is really cool! Are you planning to help us fix the bug in the original kernel where hidden_dim // group_size isn't divisible by 4?

I have some other plans this week; I'll take a look next week.

@strgrb Hi, I want to ask a simple question because I have never read the original kernel code. What will happen when hidden_dim // group_size isn't divisible by 4? I'm guessing it's because of some additional alignment requirement, but I'm not sure.

@HydraQYH According to fp8_kernel.py, both the token_count dim and the hidden dim are aligned to 16 B if ue8m0 is on. Since the token_count dim is contiguous, it should be aligned, and the ue8m0 scale is packed into int32 along the hidden dim, so the hidden dim should also be aligned. The unpacked col_idx should be calculated from the real hidden_dim // group_size.
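A small sketch of the padding and packing described above (numbers assumed; this mirrors the allocate_qs helper in the reproduction script):

m, hidden_dim, group_size = 3, 128, 128
aligned_m = (m + 3) // 4 * 4              # token_count dim padded to 4 elements = 16 B
real_cols = hidden_dim // group_size      # 1 real scale column per token
aligned_cols = (real_cols + 3) // 4 * 4   # hidden dim padded to a multiple of 4 before packing
packed_cols = aligned_cols // 4           # 4 ue8m0 exponents packed into one int32
# The unpacked col_idx should come from real_cols, not from packed_cols * num_elems_per_pack.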

