
Fix for fp8 quantization failure of qwen 2.5 VL 7B model.#7448

Open
PanJason wants to merge 2 commits into sgl-project:main from PanJason:qwen2_5_vl_fp8_fix

Conversation


@PanJason (Contributor) commented Jun 22, 2025

Motivation

Fixes bug #6828, where FP8 quantization of Qwen 2.5 VL 7B was reported to fail in the visual model part.

Modifications

As described in the commit messages, two modifications are made:

  1. Make per_token_quant_fp8.cu work with a hidden dim that is not 16-aligned. The non-aligned hidden dim is split into 3 parts (a prologue, an aligned main part, and a tail), since a row may start at an address that is not 128-bit aligned.
  2. Make fp8_gemm_kernel work with non-16-aligned K and N, supposing the matmul is [M, K] x [K, N]. Extending N is unlikely to cause a performance regression, since we can simply allocate a larger but aligned output matrix and return a view of it. The unaligned K is the problematic part (hidden dim 3420 in the case of the Qwen 2.5 VL 7B vision model). In this PR, the solution is to allocate two new matrices whose K is aligned and perform two additional copies from the original mat_a and mat_b (a minimal sketch of this follows the list below). I am not sure whether this is appropriate, as it slows down inference and doubles the memory for these operands. Two other options are available, as stated in the commit message:
    a. Initialize the weight matrix directly to have a 16-aligned (3424 in this case) hidden dim with zero padding.
    b. Disable quantization for the vision model part of qwen 2.5 VL.
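
For illustration, the runtime K-padding chosen here might look roughly like the following (a minimal sketch using the torch C++ API; the variable names are assumptions, not the PR's actual code):

    // Sketch: round K up to the next multiple of 16 so the kernel sees aligned operands.
    int64_t k_aligned = (K + 15) / 16 * 16;
    // Two extra allocations and two extra copies per call (the stated drawback).
    auto a_padded = torch::zeros({M, k_aligned}, mat_a.options());
    auto b_padded = torch::zeros({k_aligned, N}, mat_b.options());
    a_padded.narrow(/*dim=*/1, /*start=*/0, /*length=*/K).copy_(mat_a);
    b_padded.narrow(/*dim=*/0, /*start=*/0, /*length=*/K).copy_(mat_b);
    // The zero rows/columns contribute 0 to every dot product, so running the
    // aligned GEMM on (a_padded, b_padded) reproduces the original [M, N] result.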

Feedback is welcome. Thanks!

Checklist

PanJason added 2 commits June 22, 2025 21:57
Chop the unaligned hidden dim into 3 parts:
1. Prologue part: starts unaligned, ends aligned
2. Main loop: starts aligned, ends aligned
3. Tail part: starts aligned, ends unaligned
For unaligned N, the kernel generates the output tensor with an aligned
dim, which ideally will not incur much of a perf penalty.

For unaligned K, the kernel has to allocate temporary new tensors
for both A and B and perform two additional copies. There are 2
alternatives here:
1. In weight initialization, initialize the weight matrix to have a
hidden dim which is 16-aligned, with 0 padding, for Qwen 2.5 VL 7B
2. Disable quantization for the vision part

Feedback is welcome on which of the three to use.
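
As a rough illustration of the chop, the prologue size can be derived from the row pointer's byte offset relative to 16 bytes (a hedged sketch; T, VEC_SIZE, and the variable names mirror the kernel's conventions but are assumptions, not its exact code):

    // Sketch: split hidden_dim into prologue / main loop / tail based on alignment.
    constexpr uint32_t ALIGN_BYTES = 16;  // 128-bit vector width
    uintptr_t addr = reinterpret_cast<uintptr_t>(token_input_base);
    uint32_t mis_bytes = addr % ALIGN_BYTES;
    // Element addresses are multiples of sizeof(T), so this division is exact.
    uint32_t prologue_elems = (mis_bytes == 0) ? 0 : (ALIGN_BYTES - mis_bytes) / sizeof(T);
    uint32_t remaining = hidden_dim - prologue_elems;  // assumes hidden_dim >= prologue_elems
    uint32_t num_full_vecs = remaining / VEC_SIZE;     // aligned 128-bit loads
    uint32_t num_tail_elems = remaining % VEC_SIZE;    // scalar tail, ends unaligned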

@gemini-code-assist (Bot) left a comment


Summary of Changes

Hello @PanJason, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request provides a critical fix for FP8 quantization failures that occur when processing models with non-16-aligned tensor dimensions, such as the Qwen 2.5 VL 7B model. The changes ensure that both the per-token quantization and the core FP8 matrix multiplication kernels can correctly handle unaligned data, improving the robustness and compatibility of the FP8 quantization pipeline across various model architectures.

Highlights

  • FP8 Quantization Alignment Fix: Implemented robust handling for non-16-aligned dimensions in both per-token FP8 quantization and FP8 GEMM operations. This specifically addresses issues encountered with models like Qwen 2.5 VL 7B, which have unaligned hidden dimensions (e.g., 3420).
  • Per-Token Quantization Enhancement: Modified the per_token_quant_fp8_kernel to correctly process input tensors with hidden dimensions that are not 16-byte aligned. This involves splitting the processing into prologue, aligned main part, and tail sections for both finding the maximum value and performing the quantization, removing previous strict alignment requirements.
  • FP8 GEMM Flexibility: Updated the fp8_gemm_kernel to support non-16-aligned K (inner product dimension) and N (output column dimension) in matrix multiplications. This is achieved by dynamically padding the input and output tensors to the required 16-byte alignment before computation and then returning a view of the original, unpadded result.
  • Test Coverage Expansion: Expanded unit tests for both FP8 GEMM and per-token quantization to include non-aligned dimensions (specifically 3420), ensuring the new alignment handling logic is thoroughly validated.


@gemini-code-assist (Bot) left a comment


Code Review

Thank you for this contribution! This PR addresses an important issue with FP8 quantization for models with non-16-aligned dimensions.

The changes in per_token_quant_fp8.cu to handle unaligned hidden_dim are a good direction, but I've found a critical issue with the alignment logic for vectorized memory access in the quantization step. Please see my detailed comment on that.

Regarding fp8_gemm_kernel.cu, the approach of padding at runtime by creating new tensors and copying data is a valid fix to get things working. As you noted in the PR description, this does introduce performance and memory overhead. If this function is in a performance-critical path, it would be much more efficient to perform this padding once at model loading time rather than on every inference call. The current implementation is a good trade-off for a quick fix, but for production use, the load-time padding approach should be considered.

I've also left a few other minor suggestions for improving code quality and maintainability. The test case additions are great and directly target the problem being solved.

Let me know if you have any questions about my feedback!

Comment on lines +119 to +121
assert(
    reinterpret_cast<size_t>(token_input_for_aligned_output) % ALIGNMENT_BYTES_OUTPUT == 0 &&
    "Input address must be 16-byte aligned.");

critical

This assert correctly identifies a potential logic flaw. The quantization step (Step 2) is structured based on the alignment of the output pointer (token_output_base), but the main loop performs vectorized loads from the input pointer (token_input_for_aligned_output).

The alignment of token_input_for_aligned_output is not guaranteed because prologue_output_elements_count is calculated based on the output pointer's alignment. If the input and output pointers have different alignment offsets (which is possible with an unaligned hidden_dim), the vectorized loads from the input will be unaligned. This can cause significant performance degradation or even crashes on some hardware.

A safer approach is to structure the quantization step's loops based on the input pointer's alignment, similar to Step 1. For writes to the output, you can then handle potential unalignment, for example, by using scalar writes within an unrolled loop if a vectorized write is not possible.

Example of a safer write within the vectorized loop (assuming the loop is structured based on input alignment):

// In the main vectorized loop for quantization
// ... (vectorized read and computation)
FP8_TYPE* current_output_ptr = token_output_base + prologue_elements_count + i * VEC_SIZE;
// Use scalar writes if output is not guaranteed to be aligned for a vectorized store.
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
    current_output_ptr[j] = output_arr[j];
}

Please fix the underlying logic to ensure all vectorized memory accesses are properly aligned.

Comment on lines +1133 to +1140
size_t element_size_out = 0;
if (out_dtype == torch::kHalf) {
  element_size_out = sizeof(cutlass::half_t);
} else if (out_dtype == torch::kBFloat16) {
  element_size_out = sizeof(cutlass::bfloat16_t);
} else {
  TORCH_CHECK(false, "Unsupported output dtype");
}

medium

This block can be simplified by using torch::elementSize(out_dtype). This is more concise and less error-prone if more types are supported in the future. The check for supported out_dtype is already performed on line 1088, so the else branch with TORCH_CHECK is redundant.

  size_t element_size_out = torch::elementSize(out_dtype);

Comment on lines +143 to 156
// Scalar loop for the final (unaligned) output tail elements
if (num_tail_elems_output > 0) {
  const int32_t tail_output_start_idx = prologue_output_elements_count + num_full_vecs_output * VEC_SIZE;
  for (int32_t i = tail_output_start_idx + tid; i < hidden_dim; i += block_dim) {
    float val = fmaxf(fminf(static_cast<float>(token_input_base[i]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX);
#ifndef USE_ROCM
    token_output_base[i] = static_cast<FP8_TYPE>(val);
#else
    token_output_base[i] = c10::Float8_e4m3fnuz(
        __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
        c10::Float8_e4m3fnuz::from_bits());
#endif
  }
}

medium

There is some code duplication between the scalar processing loops. For instance, the logic in this tail loop is very similar to the prologue loop (lines 99-109). The same applies to the find_max step.

To improve maintainability, consider refactoring this by creating a __device__ helper function that performs the quantization for a single element. This function could then be called from all scalar loops, reducing code duplication.
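
For instance, a shared helper along these lines could serve all the scalar loops (a hypothetical sketch assembled from the snippet above; FP8_TYPE, FP8_E4M3_MAX, and the ROCm branch follow the kernel's existing code):

    // Hypothetical __device__ helper consolidating the scalar quantization logic.
    template <typename T>
    __device__ __forceinline__ FP8_TYPE quantize_one(T in, float scale_inv) {
      float val = fmaxf(fminf(static_cast<float>(in) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX);
    #ifndef USE_ROCM
      return static_cast<FP8_TYPE>(val);
    #else
      return c10::Float8_e4m3fnuz(
          __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
          c10::Float8_e4m3fnuz::from_bits());
    #endif
    }
    // Each scalar loop body then reduces to:
    //   token_output_base[i] = quantize_one(token_input_base[i], scale_inv);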

@Alcanderian (Collaborator) commented:

Hi @PanJason, we prefer solution a (initialize the weight matrix directly to have a 16-aligned hidden dim, 3424 in this case, with zero padding) for performance reasons.
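
For context, the load-time variant (solution a) might look roughly like this one-time step during weight loading (a hypothetical sketch in the torch C++ API; the weight variable and pad amount are assumptions, and the matching activation dim would need consistent padding so shapes still agree):

    // Hypothetical one-time padding at load: 3420 -> 3424 for Qwen 2.5 VL 7B.
    int64_t k = weight.size(-1);
    int64_t k_aligned = (k + 15) / 16 * 16;
    if (k_aligned != k) {
      // Pad the last dim with zeros once; zero weights add nothing to the output,
      // and every later GEMM then sees an aligned K with no per-call copies.
      weight = torch::constant_pad_nd(weight, {0, k_aligned - k}, /*value=*/0);
    }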

@PanJason (Contributor, Author) commented:

Hi @Alcanderian, thanks for your feedback. I will prepare the next version in a few days.
