tuner: Trtllm-gen Fp4 MoE Autotuner #1475
Conversation
Summary of Changes
Hello @IwakuraRein, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces an autotuning mechanism for FP4 Mixture-of-Experts (MoE) kernels generated by TensorRT-LLM (trtllm-gen). It enables the system to dynamically select optimal kernel configurations for improved performance. The changes involve extending the C++ kernel interface to accept specific configurations, implementing a Python-based MoERunner to manage and cache valid configurations, and integrating this runner with the existing autotuner framework. Additionally, the tensor creation logic for autotuning is enhanced to use more realistic random data, which can lead to more accurate tuning results.
Highlights
- **Autotuning for FP4 MoE Kernels**: A comprehensive autotuning framework has been introduced for trtllm-gen FP4 Mixture-of-Experts (MoE) kernels, allowing the system to dynamically find and apply optimal kernel configurations for enhanced performance.
- **Dynamic Kernel Configuration**: The underlying C++ `trtllm_fp4_block_scale_moe` function now accepts a `config_index`, enabling the dynamic selection of specific, pre-optimized kernel configurations at runtime.
- **`MoERunner` Implementation**: A new `MoERunner` class has been implemented to manage and cache valid kernel configurations. This includes logic for calculating the optimal `tile_tokens_dim` and querying available tactics from the C++ backend, streamlining the tuning process.
- **Enhanced Autotuner Input Data**: The `_create_tensor_like` utility in the autotuner has been updated to generate random data for integer and floating-point types, providing more realistic input tensors during tuning and potentially leading to more accurate performance optimizations.
- **Integration with `AutoTuner`**: The Python frontend functions (`trtllm_fp4_block_scale_moe_op`, `trtllm_fp4_block_scale_moe`, `trtllm_fp4_block_scale_routed_moe`) have been modified to seamlessly integrate with the `AutoTuner` and `MoERunner`, ensuring that the best-performing kernel configuration is automatically selected and used.
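The runner/tuner split described in the highlights can be pictured with a small sketch. Everything below is illustrative: `MoERunnerSketch`, `autotune`, and the heuristics are assumptions for exposition, not the actual flashinfer API.

```python
class MoERunnerSketch:
    """Illustrative stand-in for the MoERunner described above: it caches the
    valid kernel configs per token count and lets a tuner pick the fastest."""

    def __init__(self):
        self._valid_configs = {}  # num_tokens -> list of candidate config indices

    @staticmethod
    def tile_tokens_dim(num_tokens: int) -> int:
        # Pick the smallest power-of-two tile in [8, 64] covering num_tokens
        # (a plausible heuristic; the real calculation may differ).
        tile = 8
        while tile < min(num_tokens, 64):
            tile *= 2
        return tile

    def get_valid_configs(self, num_tokens: int) -> list:
        if num_tokens not in self._valid_configs:
            # In the real code this queries the C++ backend for valid tactics;
            # here we fake a fixed candidate list.
            self._valid_configs[num_tokens] = list(range(4))
        return self._valid_configs[num_tokens]


def autotune(runner: MoERunnerSketch, num_tokens: int, benchmark) -> int:
    """benchmark(config_index) -> latency; returns the fastest config index."""
    return min(runner.get_valid_configs(num_tokens), key=benchmark)
```

The selected index would then be forwarded as the `config_index` argument of the C++ kernel entry point.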
Code Review
This pull request introduces an autotuner for the Trtllm-gen FP4 MoE kernels, which is a valuable addition for performance optimization. The changes span both the C++ backend and the Python-level autotuner infrastructure. The overall approach is sound, but there are a few issues that need to be addressed, including a critical bug that could cause a runtime error, some minor code correctness issues, and leftover debugging code. Addressing these points will improve the robustness and quality of the implementation.
The variables `hidden_states_scale`, `gemm1_weights_scale`, and `gemm2_weights_scale` are only defined within `if` blocks. If the conditions are false, these variables will not be defined, leading to a `NameError` when they are used in the `moe_op.trtllm_fp4_block_scale_moe` call on line 1094. You should initialize them to `None` before the conditional blocks to ensure they are always defined.
```python
extra_input_idx = 0
hidden_states_scale = None
if trtllm_gen_dtype_has_scale(self.dtype_act):
    hidden_states_scale = extra_inputs[extra_input_idx]
    extra_input_idx += 1
gemm1_weights_scale = None
gemm2_weights_scale = None
if trtllm_gen_dtype_has_scale(self.dtype_weights):
    gemm1_weights_scale = extra_inputs[extra_input_idx]
    gemm2_weights_scale = extra_inputs[extra_input_idx + 1]
    extra_input_idx += 2
```
cc @amirkl94
```cpp
inline btg::Dtype get_dtype(int64_t const dtype) {
  switch (dtype) {
    case 0:
```
This currently finds the type by `Dtype`'s Uid bits.
Would it be more maintainable to simply pass strings (so that in code review, each case more obviously matches the type name) and document at the binding interface that the available options are the `Dtype` values in `DtypeDecl.h`?
The current approach has the advantage of being checked early via `class DtypeTrtllmGen(IntEnum)`, so no objections; just raising options.
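The early-checking advantage can be demonstrated with a toy sketch (the two-member enum below is illustrative, not the real `DtypeTrtllmGen`):

```python
from enum import IntEnum


class ToyDtype(IntEnum):
    # Toy stand-in for DtypeTrtllmGen: only these integer codes are valid.
    Bfloat16 = 0
    E4m3 = 5


def to_dtype(code: int) -> ToyDtype:
    # IntEnum rejects unknown codes at the Python boundary with a ValueError,
    # before the value ever reaches the C++ switch statement.
    return ToyDtype(code)
```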
@aleozlx Thanks for the suggestion! What about defining a new enum in `trtllm_fused_moe_kernel_launcher.cu`, using macros to map it to `btg::Dtype`, and then exposing it to Python?
Sure, that improves readability. Or we can use the `class Dtype(IntEnum)` I proposed in the other thread and remove another conversion. Hopefully `__new__` serves to hide implementation details in a way that users won't rely on the value (compared to a separate function). Although value changes can cause breakage through the C++ interface, so we just have to align on something early.
```python
# NOTE(siyuan): Need to keep this in sync with the counterpart defined in
# include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h
class DtypeTrtllmGen(IntEnum):
    Bfloat16 = (0,)
```
Or we can borrow the bit formation from `TLLM_ENCODE_DTYPE`, allowing deletion of the `get_dtype()` conversion, e.g.
```python
class Dtype(IntEnum):
    def __new__(cls, block_format_bit, signed_bit, integer_bit, num_bits, uid):
        value = (block_format_bit << 24) | (signed_bit << 20) | (integer_bit << 16) | (num_bits << 8) | uid
        obj = int.__new__(cls, value)
        obj._value_ = value
        return obj

    # keep the values in sync with include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h
    Bfloat16 = (0, 1, 0, 16, 0)
    Bool = (0, 0, 1, 1, 1)
    E2m1 = (1, 1, 0, 4, 2)
    E2m3 = (1, 1, 0, 6, 3)
    E3m2 = (1, 1, 0, 6, 4)
    E4m3 = (0, 1, 0, 8, 5)
    E5m2 = (0, 1, 0, 8, 6)
    Fp16 = (0, 1, 0, 16, 7)
    Fp32 = (0, 1, 0, 32, 8)
    Int8 = (0, 1, 1, 8, 9)
    Int32 = (0, 1, 1, 32, 10)
    Int64 = (0, 1, 1, 64, 11)
    MxE2m1 = (1, 1, 0, 4, 12)
    MxE4m3 = (1, 1, 0, 8, 13)
    UE8m0 = (0, 0, 0, 8, 14)
    UInt8 = (0, 0, 1, 8, 15)
    UInt16 = (0, 0, 1, 16, 16)
    UInt32 = (0, 0, 1, 32, 17)
    UInt64 = (0, 0, 1, 64, 18)
    UInt128 = (0, 0, 1, 128, 19)
    Void = (0, 1, 0, 0, 20)
```

```python
    routing_method_type: int = 0,
    do_finalize: bool = True,
    enable_pdl: Optional[bool] = None,
    tune_max_num_tokens: int = 1024,
```
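As a quick sanity check of the bit layout sketched in that proposal, the component fields can be decoded back out of the packed value. The two-member enum below is purely for illustration; the shift positions are taken from the snippet above.

```python
from enum import IntEnum


class MiniDtype(IntEnum):
    # Same packing as the proposal: (block_format, signed, integer, num_bits, uid).
    def __new__(cls, block_format_bit, signed_bit, integer_bit, num_bits, uid):
        value = (block_format_bit << 24) | (signed_bit << 20) | (integer_bit << 16) | (num_bits << 8) | uid
        obj = int.__new__(cls, value)
        obj._value_ = value
        return obj

    Bfloat16 = (0, 1, 0, 16, 0)
    E2m1 = (1, 1, 0, 4, 2)


def num_bits(d: MiniDtype) -> int:
    # bits [8, 16) hold the element width
    return (d.value >> 8) & 0xFF


def uid(d: MiniDtype) -> int:
    # the low byte holds the unique id, which is what get_dtype() switches on
    return d.value & 0xFF
```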
Shall we always append new parameters? That way, users who pass the optional arguments positionally won't break.
Or I guess, in hindsight, we should have put the optionals after a `*` to prevent them from being used positionally, which is easier to break.
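For illustration, the keyword-only marker works like this (a toy function, not the actual flashinfer signature):

```python
def moe_op(hidden_states, weights, *, routing_method_type=0,
           do_finalize=True, tune_max_num_tokens=1024):
    # Everything after the bare `*` must be passed by keyword, so reordering
    # or appending these parameters cannot break positional callers.
    return (routing_method_type, do_finalize, tune_max_num_tokens)
```

Calling `moe_op(x, w, 1)` raises a `TypeError`, which is exactly the breakage-proofing the comment is after.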
```python
    dtype = DtypeTrtllmGen.Bfloat16
elif x.dtype == torch.float8_e4m3fn:
    dtype = DtypeTrtllmGen.E4m3 if scale is None else DtypeTrtllmGen.MxE4m3
elif x.dtype == torch.uint8:
```
Just a minor note: we should also take care of `torch.float4_e2m1x2` for torch 2.8+.
Thanks. But I didn't see `float4_e2m1fn_x2` used anywhere in flashinfer? Will we add it altogether in the future?
Not yet; we need to prepare for it as frameworks upgrade to torch 2.8. It could be done in later PRs.
There are some conflicts with the main branch after #1396 got merged; would you mind rebasing?
```python
assert hidden_states.shape[0] == num_tokens, (
    "hidden_states's first dimension must be batch size."
)
assert hidden_states_scale is None or (
```
This assertion makes vLLM gpt-oss fail. In that case the scale is `tensor([1., 1., 1., ..., 1., 1., 1.], device='cuda:0', dtype=torch.float8_e4m3fn)`.
@weireweire I created this PR to fix this. Currently vLLM's flashinfer tag is 0.2.12; I can mark this PR as ready and increment the flashinfer tag to 0.2.13.
```python
# TODO(siyuan): support fp8
moe_op.trtllm_fp4_block_scale_moe(
    routing_logits.to(torch.bfloat16),
```
This line actually breaks DeepSeek v3 routing. I have left suggested changes in #1494.
📌 Description
- `AutoTuner`, `OptimizationProfile`, and `DynamicTensorSpec`: `DynamicTensorSpec` can take multiple input tensors, and `tensor_initializers` in `DynamicTensorSpec` defines the initialization method for dynamic tensors. Previously they were all zero-initialized, which caused IMA in trtllm-gen's routing kernels.
- `DtypeTrtllmGen` in `flashinfer/fused_moe/core.py`.
- `hidden_states_scale` in trtllm-gen fp4 moe: it doesn't need to be 1D.

TODOs
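The zero-initialization problem mentioned in the description is easy to picture: routing inputs such as expert indices must land in a valid range, so the tuner's dummy tensors need type-aware random initialization. A rough sketch follows; `init_dummy` and its `kind` parameter are illustrative stand-ins (the real `tensor_initializers` operates on torch tensors).

```python
import random


def init_dummy(shape, kind, num_experts=8):
    """Type-aware random init for autotuning inputs (illustrative only).

    kind: "index" for integer routing ids, "float" for activation values.
    All-zero index tensors would route every token identically (or feed
    degenerate values into routing kernels, causing illegal memory access),
    so we randomize within a valid range instead.
    """
    n = 1
    for d in shape:
        n *= d
    if kind == "index":
        return [random.randrange(num_experts) for _ in range(n)]
    return [random.uniform(-1.0, 1.0) for _ in range(n)]
```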
Performance
B200, clock speed locked at 1500 MHz, 1000 warmups, 1000 iterations, `mxfp4 x mxfp8`.
For `nvfp4 x nvfp4` and `mxfp4 x bf16`, there is no significant perf gain.

🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

Reviewer Notes