feat: enable and update all-reduce fused quantization#1164
feat: enable and update all-reduce fused quantization#1164yzh119 merged 5 commits intoflashinfer-ai:mainfrom
Conversation
There was a problem hiding this comment.
Summary of Changes
Hello @yyihuang, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request focuses on enabling and updating fused quantization, specifically FP4, within the all-reduce and MoE all-reduce communication primitives. I've introduced new memory layouts and core CUDA device functions to handle FP4 quantization and scale factor generation, with the full FP4 data conversion now active in the MoE all-reduce path, and the infrastructure laid out for the general all-reduce path.
Highlights
- FP4 Quantization Infrastructure: I've enabled the foundational framework for FP4 quantization within
trtllm_allreduce_fusion.cuh. This includes defining new memory layouts for FP4 scale factors (FP4QuantizationSFLayout) and implementing CUDA device functions (get_sf_out_offset_128x4,cvt_quant_to_fp4_get_sf_out_offset,cvt_warp_fp16_to_fp4) for calculating scale factor offsets and performing the core FP16/BF16 to FP4 conversion. While the scale factor offset calculation path is active, the final FP4 data conversion and storage in this specific file remains commented out for now. - Full FP4 Quantization in MoE All-Reduce: I've fully enabled FP4 quantization in
trtllm_moe_allreduce_fusion.cuh. This means that both the scale factor offset calculation and the actual FP16/BF16 to FP4 data conversion and storage are now active for MoE all-reduce operations. - Code Generalization and Refactoring: I've refactored and generalized several CUDA device functions by introducing new template parameters (e.g.,
SF_VEC_SIZE) to make them more flexible and reusable across different quantization scenarios.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review |
Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary |
Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help |
Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
-
Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
There was a problem hiding this comment.
Code Review
This pull request enables and updates fused quantization in all-reduce and MoE all-reduce operations. The changes introduce new enums and device functions for handling FP4 quantization scale factors and conversions. The review identified critical compilation issues due to missing symbol definitions and type mismatches, as well as a high-severity correctness issue in the calculation of K-tiles for linear layouts. Addressing these issues will improve the robustness and correctness of the new quantization features.
| for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { | ||
| localMax = cuda_max(localMax, cuda_abs(vec.elts[i])); | ||
| } | ||
|
|
||
| constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD; | ||
| // Get the absolute maximum among all 16 values (two threads for 16, four threads for 32). | ||
| localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); | ||
| if constexpr (CVT_NUM_THREADS_PER_SF == 4) { | ||
| localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax); | ||
| } | ||
| // Get the final absolute maximum values. | ||
| float vecMax = float(cuda_max(localMax.x, localMax.y)); | ||
|
|
||
| // Get the SF (max value of the vector / max value of e2m1). | ||
| // maximum value of e2m1 = 6.0. | ||
| // TODO: use half as compute data type. | ||
| float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); | ||
| // 8 bits representation of the SF. | ||
| uint8_t fp8SFVal; | ||
| // Write the SF to global memory (STG.8). | ||
| if constexpr (UE8M0_SF) { | ||
| __nv_fp8_e8m0 tmp; | ||
| tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); | ||
| SFValue = static_cast<float>(tmp); | ||
| fp8SFVal = tmp.__x; | ||
| } else { | ||
| // Here SFValue is always positive, so E4M3 is the same as UE4M3. | ||
| __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); | ||
| fp8SFVal = tmp.__x; | ||
| SFValue = static_cast<float>(tmp); | ||
| } | ||
| // Get the output scale. | ||
| // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * reciprocal(SFScaleVal)) | ||
| float outputScale = | ||
| SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) | ||
| : 0.0f; | ||
|
|
||
| if (SFout) { | ||
| // Write the SF to global memory (STG.8). | ||
| *SFout = fp8SFVal; | ||
| } | ||
|
|
||
| // Convert the input to float. | ||
| float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; | ||
|
|
||
| #pragma unroll | ||
| for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { | ||
| if constexpr (std::is_same_v<Type, half>) { | ||
| fp2Vals[i] = __half22float2(vec.elts[i]); | ||
| } else { | ||
| fp2Vals[i] = __bfloat1622float2(vec.elts[i]); | ||
| } | ||
| fp2Vals[i].x *= outputScale; | ||
| fp2Vals[i].y *= outputScale; | ||
| } | ||
|
|
||
| // Convert to e2m1 values. | ||
| uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); | ||
|
|
||
| // Write the e2m1 values to global memory. | ||
| return e2m1Vec; |
There was a problem hiding this comment.
The cvt_warp_fp16_to_fp4 function uses CVT_FP4_ELTS_PER_THREAD, PackedVec, and fp32_vec_to_e2m1 which are not defined or included in this file. This will lead to compilation errors. These symbols appear to be defined in trtllm_moe_allreduce_fusion.cuh or vec_dtypes.cuh. Please ensure all necessary headers are included and symbols are in scope.
| template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF, int SF_VEC_SIZE> | ||
| __device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional<int> batchIdx, int rowIdx, | ||
| int colIdx, std::optional<int> numRows, | ||
| int numCols, SFType* SFout, | ||
| FP4QuantizationSFLayout layout) { |
There was a problem hiding this comment.
The SFType template parameter for cvt_quant_to_fp4_get_sf_out_offset is uint32_t at the call site (line 528), but the function's return type is uint8_t* and the value written to *SFout (in cvt_warp_fp16_to_fp4) is uint8_t. This type mismatch can lead to incorrect memory access or undefined behavior. Consider changing SFType to uint8_t or void to align with the actual usage of SFout as a byte pointer.
template <int CVT_FP4_NUM_THREADS_PER_SF, int SF_VEC_SIZE>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional<int> batchIdx, int rowIdx,
int colIdx, std::optional<int> numRows,
int numCols, uint8_t* SFout,
FP4QuantizationSFLayout layout) {
| // Linear row-major layout, no padding required. | ||
| int32_t KTileIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; | ||
|
|
||
| int32_t numKTiles = numCols / SF_VEC_SIZE; |
There was a problem hiding this comment.
For the LINEAR layout, numKTiles is calculated using integer division (numCols / SF_VEC_SIZE). If numCols is not a multiple of SF_VEC_SIZE, this will truncate the result, leading to an incorrect number of K-tiles and an incorrect SFOffset. This should be a ceiling division to ensure all columns are accounted for.
int32_t numKTiles = (numCols + SF_VEC_SIZE - 1) / SF_VEC_SIZE;
| int32_t KTileIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; | ||
|
|
||
| int32_t numKTiles = numCols / details::CVT_FP4_SF_VEC_SIZE; | ||
| int32_t numKTiles = numCols / SF_VEC_SIZE; |
There was a problem hiding this comment.
Similar to trtllm_allreduce_fusion.cuh, the numKTiles calculation for the LINEAR layout uses integer division (numCols / SF_VEC_SIZE). This can lead to an incorrect SFOffset if numCols is not perfectly divisible by SF_VEC_SIZE. Please use ceiling division to correctly account for all columns.
int32_t numKTiles = (numCols + SF_VEC_SIZE - 1) / SF_VEC_SIZE;
📌 Description
We enabled and updated the fused quantization in all-reduce/moe all-reduce.
🔍 Related Issues
Depends on #1142
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes