feat: enable and update all-reduce fused quantization #1164

Merged
yzh119 merged 5 commits into flashinfer-ai:main from yyihuang:ar-quant
Jul 2, 2025

Conversation

@yyihuang (Collaborator) commented Jun 22, 2025

📌 Description

We enabled and updated fused quantization in all-reduce and MoE all-reduce.

🔍 Related Issues

Depends on #1142

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

@gemini-code-assist (Bot, Contributor) left a comment

Summary of Changes

Hello @yyihuang, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on enabling and updating fused quantization, specifically FP4, within the all-reduce and MoE all-reduce communication primitives. I've introduced new memory layouts and core CUDA device functions to handle FP4 quantization and scale factor generation, with the full FP4 data conversion now active in the MoE all-reduce path, and the infrastructure laid out for the general all-reduce path.

Highlights

  • FP4 Quantization Infrastructure: I've enabled the foundational framework for FP4 quantization within trtllm_allreduce_fusion.cuh. This includes defining new memory layouts for FP4 scale factors (FP4QuantizationSFLayout) and implementing CUDA device functions (get_sf_out_offset_128x4, cvt_quant_to_fp4_get_sf_out_offset, cvt_warp_fp16_to_fp4) for calculating scale factor offsets and performing the core FP16/BF16 to FP4 conversion. While the scale factor offset calculation path is active, the final FP4 data conversion and storage in this specific file remains commented out for now.
  • Full FP4 Quantization in MoE All-Reduce: I've fully enabled FP4 quantization in trtllm_moe_allreduce_fusion.cuh. This means that both the scale factor offset calculation and the actual FP16/BF16 to FP4 data conversion and storage are now active for MoE all-reduce operations.
  • Code Generalization and Refactoring: I've refactored and generalized several CUDA device functions by introducing new template parameters (e.g., SF_VEC_SIZE) to make them more flexible and reusable across different quantization scenarios.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. A summary of the supported commands is below.

Feature              | Command             | Description
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder at the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request enables and updates fused quantization in all-reduce and MoE all-reduce operations. The changes introduce new enums and device functions for handling FP4 quantization scale factors and conversions. The review identified critical compilation issues due to missing symbol definitions and type mismatches, as well as a high-severity correctness issue in the calculation of K-tiles for linear layouts. Addressing these issues will improve the robustness and correctness of the new quantization features.

Comment on lines +187 to +247
  for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    localMax = cuda_max(localMax, cuda_abs(vec.elts[i]));
  }

  constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD;
  // Get the absolute maximum among all 16 values (two threads for 16, four threads for 32).
  localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
  if constexpr (CVT_NUM_THREADS_PER_SF == 4) {
    localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax);
  }
  // Get the final absolute maximum values.
  float vecMax = float(cuda_max(localMax.x, localMax.y));

  // Get the SF (max value of the vector / max value of e2m1).
  // maximum value of e2m1 = 6.0.
  // TODO: use half as compute data type.
  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
  // 8 bits representation of the SF.
  uint8_t fp8SFVal;
  // Write the SF to global memory (STG.8).
  if constexpr (UE8M0_SF) {
    __nv_fp8_e8m0 tmp;
    tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
    SFValue = static_cast<float>(tmp);
    fp8SFVal = tmp.__x;
  } else {
    // Here SFValue is always positive, so E4M3 is the same as UE4M3.
    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
    fp8SFVal = tmp.__x;
    SFValue = static_cast<float>(tmp);
  }
  // Get the output scale.
  // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * reciprocal(SFScaleVal))
  float outputScale =
      SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal))
                   : 0.0f;

  if (SFout) {
    // Write the SF to global memory (STG.8).
    *SFout = fp8SFVal;
  }

  // Convert the input to float.
  float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];

#pragma unroll
  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    if constexpr (std::is_same_v<Type, half>) {
      fp2Vals[i] = __half22float2(vec.elts[i]);
    } else {
      fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
    }
    fp2Vals[i].x *= outputScale;
    fp2Vals[i].y *= outputScale;
  }

  // Convert to e2m1 values.
  uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);

  // Write the e2m1 values to global memory.
  return e2m1Vec;
critical

The cvt_warp_fp16_to_fp4 function uses CVT_FP4_ELTS_PER_THREAD, PackedVec, and fp32_vec_to_e2m1 which are not defined or included in this file. This will lead to compilation errors. These symbols appear to be defined in trtllm_moe_allreduce_fusion.cuh or vec_dtypes.cuh. Please ensure all necessary headers are included and symbols are in scope.

Comment on lines +137 to +141
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF, int SF_VEC_SIZE>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional<int> batchIdx, int rowIdx,
                                                       int colIdx, std::optional<int> numRows,
                                                       int numCols, SFType* SFout,
                                                       FP4QuantizationSFLayout layout) {

high

The SFType template parameter for cvt_quant_to_fp4_get_sf_out_offset is uint32_t at the call site (line 528), but the function's return type is uint8_t* and the value written to *SFout (in cvt_warp_fp16_to_fp4) is uint8_t. This type mismatch can lead to incorrect memory access or undefined behavior. Consider changing SFType to uint8_t or void to align with the actual usage of SFout as a byte pointer.

template <int CVT_FP4_NUM_THREADS_PER_SF, int SF_VEC_SIZE>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional<int> batchIdx, int rowIdx,
                                                       int colIdx, std::optional<int> numRows,
                                                       int numCols, uint8_t* SFout,
                                                       FP4QuantizationSFLayout layout) {

  // Linear row-major layout, no padding required.
  int32_t KTileIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;

  int32_t numKTiles = numCols / SF_VEC_SIZE;

high

For the LINEAR layout, numKTiles is calculated using integer division (numCols / SF_VEC_SIZE). If numCols is not a multiple of SF_VEC_SIZE, this will truncate the result, leading to an incorrect number of K-tiles and an incorrect SFOffset. This should be a ceiling division to ensure all columns are accounted for.

      int32_t numKTiles = (numCols + SF_VEC_SIZE - 1) / SF_VEC_SIZE;

   int32_t KTileIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;

-  int32_t numKTiles = numCols / details::CVT_FP4_SF_VEC_SIZE;
+  int32_t numKTiles = numCols / SF_VEC_SIZE;

high

Similar to trtllm_allreduce_fusion.cuh, the numKTiles calculation for the LINEAR layout uses integer division (numCols / SF_VEC_SIZE). This can lead to an incorrect SFOffset if numCols is not perfectly divisible by SF_VEC_SIZE. Please use ceiling division to correctly account for all columns.

      int32_t numKTiles = (numCols + SF_VEC_SIZE - 1) / SF_VEC_SIZE;

@yyihuang yyihuang marked this pull request as ready for review July 1, 2025 23:53
@yyihuang yyihuang requested a review from yzh119 July 1, 2025 23:54
@yzh119 yzh119 merged commit ef197c0 into flashinfer-ai:main Jul 2, 2025
2 checks passed