
Support hidden_dim % 4 == 0 in per_token_quant_fp8 #12883

Merged
ispobock merged 3 commits into main from support_hidden_dim_mod_4_equals_zero_in_per_token_quant on Nov 10, 2025

Conversation

@BBuf (Collaborator) commented Nov 8, 2025

Summary

This PR extends sgl_per_token_quant_fp8 to support hidden dimensions divisible by 4 (previously, divisibility by 8 was required).

Changes

  • Relaxed the TORCH_CHECK from hidden_dim % 8 == 0 to hidden_dim % 4 == 0
  • Added a kVecSize=4 code path for both kernel variants (see the sketch below)
  • Updated the test cases to include hidden_dim=1076
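
For illustration, a minimal sketch of the call this PR enables, assuming the Python wrapper keeps the (input, output_q, output_s) signature visible in the traceback below; the top-level import path and the scale-tensor layout are assumptions, not the exact sgl-kernel API:

import torch
from sgl_kernel import sgl_per_token_quant_fp8  # assumed top-level export

# hidden_dim = 1076 = 4 * 269: divisible by 4 but not by 8. Before this PR,
# the call below raised "Hidden dimension must be divisible by 8".
x = torch.randn(16, 1076, dtype=torch.float16, device="cuda")
q = torch.empty_like(x, dtype=torch.float8_e4m3fn)          # quantized values
s = torch.empty(16, 1, dtype=torch.float32, device="cuda")  # one scale per token (assumed layout)
sgl_per_token_quant_fp8(x, q, s)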

Motivation

Fixes a deployment issue with the qwen3-30b-a3b-moe-fp8 model, which has layers with hidden_dim=1076 (1076 % 4 == 0, but 1076 % 8 != 0).

The model was generated by the following script and served with TP4:

import argparse
import os
from transformers import AutoTokenizer, Qwen3VLMoeForConditionalGeneration
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier


def quant(model_path, offload_folder=None, output_path=None):
    if offload_folder is None:
        offload_folder = "./offload"
    os.makedirs(offload_folder, exist_ok=True)
    model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
        model_path, 
        device_map="auto",
        torch_dtype="auto",
        offload_folder=offload_folder,
        offload_state_dict=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Configure the simple PTQ quantization
    recipe = QuantizationModifier(targets="Linear",
                                  scheme="FP8_DYNAMIC",
                                  ignore=["lm_head", "re:.*mlp.gate$"])
    # Apply the quantization algorithm.
    oneshot(model=model, recipe=recipe)
    # Save the model.
    # SAVE_DIR = model_path.split("/")[1] + "-FP8-Dynamic"
    if output_path is None:
        SAVE_DIR = model_path + "-FP8-Dynamic"
    else:
        SAVE_DIR = output_path + "-FP8-Dynamic"
    model.save_pretrained(SAVE_DIR)
    tokenizer.save_pretrained(SAVE_DIR)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog='fp8_quant', description='fp8 quant')
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--offload_folder", type=str, default="./offload")
    args = parser.parse_args()
    print(f"run fp8 quant: {args.model_path}")
    print(f"offload folder: {args.offload_folder}")
    print(f"output_path folder: {args.output_path}")
    quant(args.model_path, args.offload_folder, args.output_path)
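
(Invocation is along the lines of python fp8_quant.py --model_path <MODEL_DIR> --output_path <OUTPUT_DIR>; the script filename here is hypothetical, taken from the argparse prog name.)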

When serving with sglang TP4, the bug happened:

"/mnt/data/data/bbuf/sglang/python/sglang/srt/layers/linear.py", line 1380, in forward
    output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
  File "/mnt/data/data/bbuf/sglang/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py", line 641, in apply
    return scheme.apply_weights(layer, x, bias=bias)
  File "/mnt/data/data/bbuf/sglang/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py", line 164, in apply_weights
    return apply_fp8_linear(
  File "/mnt/data/data/bbuf/sglang/python/sglang/srt/layers/quantization/fp8_utils.py", line 597, in apply_fp8_linear
    qinput, x_scale = scaled_fp8_quant(
  File "/mnt/data/data/bbuf/sglang/python/sglang/srt/layers/quantization/fp8_kernel.py", line 1448, in scaled_fp8_quant
    sgl_per_token_quant_fp8(input, output, scale)
  File "/usr/local/lib/python3.10/dist-packages/sgl_kernel/gemm.py", line 161, in sgl_per_token_quant_fp8
    torch.ops.sgl_kernel.sgl_per_token_quant_fp8.default(input, output_q, output_s)
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 829, in __call__
    return self._op(*args, **kwargs)
RuntimeError: Hidden dimension must be divisible by 8, but got 1076 

@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @BBuf, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the flexibility of the sgl_per_token_quant_fp8 quantization function by extending its compatibility to a broader range of hidden dimensions. This change is crucial for supporting models, such as qwen3-30b-a3b-moe-fp8, that utilize layer configurations with hidden dimensions divisible by 4 but not by 8, thereby resolving a specific deployment issue and improving the utility of the quantization scheme.

Highlights

  • Expanded Hidden Dimension Support: The sgl_per_token_quant_fp8 function now supports hidden dimensions that are divisible by 4, relaxing the previous requirement of divisibility by 8 (a reference sketch of the per-token quantization math follows this list).
  • New Kernel Code Path: A new kVecSize=4 code path has been added to both kernel variants (per_token_quant_fp8_kernel and per_token_quant_fp8_small_batch_kernel) to handle cases where hidden_dim % 4 == 0 but hidden_dim % 8 != 0.
  • Updated Test Coverage: Test cases have been updated to include hidden_dim=1076, specifically addressing the scenario that motivated this change.
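
For reference, per-token dynamic FP8 quantization derives one scale per row from that row's absolute maximum. Below is a minimal PyTorch sketch of the math the kernel vectorizes, assuming the float8_e4m3fn format with maximum finite value 448; this is a reference computation under those assumptions, not the kernel source:

import torch

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def per_token_quant_fp8_ref(x: torch.Tensor):
    # One scale per token (row), guarded so all-zero rows don't divide by zero.
    amax = x.abs().amax(dim=-1, keepdim=True).float()
    scale = (amax / FP8_E4M3_MAX).clamp(min=1e-12)
    q = (x.float() / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return q, scale

x = torch.randn(16, 1076, dtype=torch.float16)  # 1076 % 4 == 0, but 1076 % 8 != 0
q, s = per_token_quant_fp8_ref(x)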

@BBuf requested a review from yuan-luo on November 8, 2025 at 13:08
@gemini-code-assist (Bot) left a comment


Code Review

This pull request correctly extends sgl_per_token_quant_fp8 to support hidden dimensions divisible by 4, which is a necessary change for compatibility with certain models. The logic is sound, and the test cases have been updated appropriately to cover the new functionality. My review includes suggestions to refactor the kernel dispatch logic in sgl-kernel/csrc/gemm/per_token_quant_fp8.cu. These changes would reduce code duplication and improve the overall maintainability of the file.

Two review comment threads on sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
@BBuf added the run-ci label on Nov 8, 2025
@ispobock merged commit 05559a4 into main on Nov 10, 2025
97 of 110 checks passed
@ispobock deleted the support_hidden_dim_mod_4_equals_zero_in_per_token_quant branch on November 10, 2025 at 09:13
ocss884 pushed a commit to ocss884/sglang that referenced this pull request Nov 10, 2025