[DRAFT] Support dispatch low latency #10120
Conversation
Summary of Changes
Hello @kaixih, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request implements a new low-latency dispatch mechanism for Mixture-of-Experts (MoE) layers by integrating FlashInfer's CuteDSL kernels. The changes enhance the DeepEP configuration to support Bfloat16 dispatch and incorporate global scaling factors for improved handling of quantized operations. This work aims to optimize the performance of MoE computations within the system.
Highlights
- New MoE Dispatch Backend: Introduces FLASHINFER_CUTEDSL as a new backend for Mixture-of-Experts (MoE) dispatch, leveraging FlashInfer's CuteDSL kernels for potentially lower latency.
- Bfloat16 Dispatch Support: Adds an environment variable SGLANG_DEEPEP_BF16_DISPATCH to enable Bfloat16 precision for dispatch operations within the DeepEP configuration.
- Quantization Scale Integration: Modifies the dispatch process to pass input global scales, crucial for accurate quantized operations, particularly with FP4 weights.
- New Test Suite: Includes a new test file (test_flashinfer_cutedsl_scaled_mm.py) to validate the correctness of the FlashInfer CuteDSL grouped GEMM and MoE masked operations.
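The global-scale integration in the third highlight can be illustrated with a small sketch. The function name and the two-level NVFP4 scaling convention below (FP4 E2M1 max of 6.0, FP8 E4M3 max of 448) are assumptions for illustration and are not taken from this PR; the actual kernels may compute scales differently:

```python
import numpy as np

FP4_MAX = 6.0    # largest magnitude representable in FP4 E2M1
FP8_MAX = 448.0  # largest magnitude representable in FP8 E4M3

def compute_input_global_scale(x: np.ndarray) -> np.ndarray:
    """Per-expert global scale for NVFP4-style two-level quantization.

    x: [num_experts, m, k] activations. The global scale maps each
    expert's amax into the combined FP8 (block scale) x FP4 range so
    that both block scales and quantized values stay representable.
    """
    amax = np.abs(x).reshape(x.shape[0], -1).max(axis=1)  # (num_experts,)
    return (FP4_MAX * FP8_MAX) / np.maximum(amax, 1e-12)  # (num_experts,)

x = np.random.randn(4, 8, 32).astype(np.float32)  # 4 experts
scale = compute_input_global_scale(x)
print(scale.shape)  # (4,)
```

This also shows why the dispatch path needs a `(num_experts,)`-shaped scale tensor: the scale is computed per expert, not per tensor.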
Code Review
This pull request adds support for a new low-latency dispatch mechanism for Mixture-of-Experts (MoE) layers, utilizing FlashInfer's CuteDSL kernels. This is exposed via a new flashinfer_cutedsl backend for the MoE runner. The changes include adding the new backend option, implementing the corresponding MoE computation path with FP4 quantization, and adding a new test file for the functionality. My review focuses on the correctness and consistency of these changes. I've found a critical issue in the new test file where an argument is passed with an incorrect type, which would lead to a runtime error. I've also identified some minor inconsistencies in docstrings and assertion messages that should be fixed for better maintainability.
```python
out = flashinfer_cutedsl_moe_masked(
    hidden_states_3d.to(hidden_states.device),
    input_global_scale,
    w1_fp4.permute(2, 0, 1),
    w1_blockscale,
    w1_alpha,
    w2_fp4.permute(2, 0, 1),
    a2_global_scale,
    w2_blockscale,
    w2_alpha,
    masked_m.to(hidden_states.device),
)
```
The flashinfer_cutedsl_moe_masked function expects hidden_states to be a tuple of two tensors (quantized data and scales), but hidden_states_3d, which is a single tensor, is passed. This will cause a runtime error. The hidden_states_3d tensor should be quantized before being passed to the function, for example by using scaled_fp4_grouped_quant.
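The shape contract behind that fix can be sketched without FlashInfer. The helper below is a hypothetical stand-in for scaled_fp4_grouped_quant that only mimics the output layout (`[num_experts, m, k // 2]` packed uint8 data plus `[num_experts, m, k // 16]` block scales); the real kernel's rounding, swizzling, and e4m3 scale encoding will differ:

```python
import numpy as np

BLOCK = 16  # elements covered by one FP4 block scale

def mock_scaled_fp4_grouped_quant(x: np.ndarray):
    """Shape-only mock: turn [num_experts, m, k] activations into the
    (packed_fp4, blockscales) tuple the masked MoE kernel expects."""
    e, m, k = x.shape
    assert k % BLOCK == 0
    # One scale per 16-element block (stored as e4m3 in the real kernel).
    blocks = x.reshape(e, m, k // BLOCK, BLOCK)
    scales = np.abs(blocks).max(axis=-1) / 6.0            # [e, m, k // 16]
    safe = np.maximum(scales, 1e-12)[..., None]
    q = np.clip(np.rint(blocks / safe), -6, 6).astype(np.int8)
    # Pack two 4-bit values per uint8: [e, m, k] -> [e, m, k // 2].
    q = (q & 0xF).reshape(e, m, k // 2, 2).astype(np.uint8)
    packed = q[..., 0] | (q[..., 1] << 4)
    return packed, scales.astype(np.float32)

x = np.random.randn(2, 4, 32).astype(np.float32)
packed, scales = mock_scaled_fp4_grouped_quant(x)
print(packed.shape, scales.shape)  # (2, 4, 16) (2, 4, 2)
```

Passing the resulting two-element tuple, rather than the raw bf16 tensor, is what the reviewed call site is missing.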
```python
Args:
    hidden_states (tuple[torch.Tensor, torch.Tensor]): [num_experts, m, k // 2], uint8, [num_experts, m, k // 16], float8_e4m3fn
    input_global_scale (torch.Tensor): (l,)
    w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
    w1_blockscale (torch.Tensor): blockscale factors, e4m3,
    w1_alpha (torch.Tensor): (l,)
    w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8
    a2_global_scale (torch.Tensor): (l,)
    w2_blockscale (torch.Tensor): blockscale factors, e4m3,
    w2_alpha (torch.Tensor): (l,)
    masked_m (torch.Tensor): Masked dimension indices
```
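A quick way to read those shapes: with `l = num_experts` experts, hidden size `k`, and intermediate size `n`, the FP4 operands pack two 4-bit values per uint8 byte (hence `k // 2` and `n // 2`) and carry one block scale per 16 elements (hence `k // 16`). The sketch below just materializes dummy buffers with those shapes; the dtypes are approximations (numpy has no fp4 or e4m3), and the reading of `masked_m` as valid-rows-per-expert is an assumption:

```python
import numpy as np

l, m, n, k = 8, 16, 256, 512  # experts, tokens/expert, intermediate, hidden

hidden_states = (
    np.zeros((l, m, k // 2), dtype=np.uint8),     # packed fp4 activations
    np.zeros((l, m, k // 16), dtype=np.float32),  # e4m3 block scales (approx.)
)
w1 = np.zeros((l, 2 * n, k // 2), dtype=np.uint8)  # gate+up projection, fp4
w2 = np.zeros((l, k, n // 2), dtype=np.uint8)      # down projection, fp4
input_global_scale = np.ones(l, dtype=np.float32)  # one scale per expert
masked_m = np.full(l, m, dtype=np.int32)           # valid rows per expert (assumed)

# Packing invariants implied by the docstring:
assert hidden_states[0].shape[-1] * 2 == k   # two fp4 values per byte
assert hidden_states[1].shape[-1] * 16 == k  # one block scale per 16 elements
assert w1.shape[1] == 2 * n and w2.shape[1] == k
```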
```python
assert input_global_scale.shape == (
    num_experts,
), f"input_global_scale must be (l,), got {input_global_scale.shape}"
assert w1_alpha.shape == (
    num_experts,
), f"w1_alpha must be (l,), got {w1_alpha.shape}"
assert a2_global_scale.shape == (
    num_experts,
), f"a2_global_scale must be (l,), got {a2_global_scale.shape}"
assert w2_alpha.shape == (
    num_experts,
), f"w2_alpha must be (l,), got {w2_alpha.shape}"
```
The assertion messages use (l,) to describe the expected shape, but the variable l is not defined in this scope. The code actually checks against num_experts. For consistency and clarity, please use num_experts in the assertion messages.
Suggested change:

```python
assert input_global_scale.shape == (
    num_experts,
), f"input_global_scale must be (num_experts,), got {input_global_scale.shape}"
assert w1_alpha.shape == (
    num_experts,
), f"w1_alpha must be (num_experts,), got {w1_alpha.shape}"
assert a2_global_scale.shape == (
    num_experts,
), f"a2_global_scale must be (num_experts,), got {a2_global_scale.shape}"
assert w2_alpha.shape == (
    num_experts,
), f"w2_alpha must be (num_experts,), got {w2_alpha.shape}"
```
FYI my future dev work will be in: feat/deepep_ll_nvfp4 #10263
well gsm8k for this branch is zero :/ |
Closing, since the new work is in #10263 and has already been merged.
WIP
cc. @kushanam @wenscarl @fzyzcjy