[lora][moe] Decoupled LoRA MoE backend with Marlin support #21858
Conversation
Code Review
This pull request introduces LoRA support for Mixture-of-Experts (MoE) models by implementing a hook-based injection mechanism. It adds a new Marlin-based MoE runner for LoRA, updates existing runners to support LoRA hooks, and introduces virtual expert computation for improved efficiency. I have identified critical issues regarding dimension calculations in the CUDA buffer allocation logic that could lead to runtime errors, as well as opportunities to improve type safety by replacing 'Any' with 'LoRAHooks' in method signatures.
_, N, _ = qinfo.w13_qweight.shape
hidden_dim = qinfo.w2_qweight.shape[1]
There appears to be an issue with how dimensions are calculated for the Marlin backend, which could lead to incorrect CUDA buffer allocations and potential runtime errors.
N is being assigned the packed dimension 2*N_packed from qinfo.w13_qweight.shape. However, it is used to allocate intermediate_cache1, which requires the full gate_up_dim; it should be multiplied by 16 to get the unpacked dimension. hidden_dim is being assigned qinfo.w2_qweight.shape[1], which is the packed intermediate dimension (N_packed), not the hidden dimension.
To fix this, you should use the unpacked dimensions for allocation.
- _, N, _ = qinfo.w13_qweight.shape
- hidden_dim = qinfo.w2_qweight.shape[1]
+ N = qinfo.w13_qweight.shape[1] * 16
+ hidden_dim = qinfo.w13_qweight.shape[2]
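To make the packed-versus-unpacked relationship concrete, here is a small standalone sketch. The packing factor of 16 and the [E, packed_gate_up, hidden] layout for w13_qweight follow the assumptions stated in this comment; the tensor sizes and variable names are purely illustrative, not the PR's actual allocation code.

import torch

PACK_FACTOR = 16                                  # assumed Marlin packing factor
E, gate_up_packed, hidden = 8, 128, 4096          # illustrative w13_qweight shape
w13_qweight = torch.zeros(E, gate_up_packed, hidden, dtype=torch.int32)

N = w13_qweight.shape[1] * PACK_FACTOR            # unpacked gate_up_dim (2048 here)
hidden_dim = w13_qweight.shape[2]                 # true hidden dimension (4096 here)

# intermediate_cache1 must be sized with the unpacked gate_up_dim; using the
# packed value would make the buffer 16x too small for downstream kernels.
num_tokens, top_k = 4, 2
intermediate_cache1 = torch.empty(num_tokens * top_k, N, dtype=torch.bfloat16)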
M, K = hidden_states.shape
E = quant_info.w13_qweight.shape[0]
N = quant_info.w2_qweight.shape[1] * 16
The calculation for N (the intermediate dimension) appears incorrect. quant_info.w2_qweight has shape [E, N_packed, K], where E is the number of experts, N_packed is the packed intermediate dimension, and K is the hidden dimension, so shape[1] yields the packed value and shape[2] yields the hidden dimension; neither is the unpacked intermediate size by itself. The intermediate dimension N should instead be derived from w13_qweight's second dimension, which holds both gate and up projections, halved and then unpacked.
- N = quant_info.w2_qweight.shape[1] * 16
+ N = quant_info.w13_qweight.shape[1] // 2 * 16
runner_input: RunnerInput,
quant_info: MoeQuantInfo,
running_state: dict,
hooks: Optional[Any] = None,
For better type safety and maintainability, consider using the specific LoRAHooks type instead of Any for the hooks parameter. This makes the expected type explicit and improves code clarity.
You would need to add a forward reference import within the TYPE_CHECKING block:

if TYPE_CHECKING:
    from sglang.srt.lora.lora_moe_runners import LoRAHooks

And then change the signature to:

def run(
    self,
    runner_input: RunnerInput,
    quant_info: MoeQuantInfo,
    running_state: dict,
    hooks: Optional["LoRAHooks"] = None,
) -> RunnerOutput:

- hooks: Optional[Any] = None,
+ hooks: Optional["LoRAHooks"] = None,
runner_input: DeepGemmRunnerInput,
quant_info: DeepGemmMoeQuantInfo,
running_state: dict,
hooks: Optional[Any] = None,
For better type safety and maintainability, you can replace Any with a more specific type hint for the hooks parameter. Using "LoRAHooks" as a forward reference would make the code clearer about the expected object type.
This would involve adding an import inside the TYPE_CHECKING block:
if TYPE_CHECKING:
    from sglang.srt.lora.lora_moe_runners import LoRAHooks

And updating the method signature accordingly:

- hooks: Optional[Any] = None,
+ hooks: Optional["LoRAHooks"] = None,
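For reference, a minimal self-contained sketch of this forward-reference pattern; the import path follows the suggestion above, while the runner class and field types are illustrative placeholders rather than the actual sglang definitions:

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
    # Only evaluated by type checkers, so there is no runtime import (and no
    # runtime circular-import risk) between the MoE runner and the LoRA module.
    from sglang.srt.lora.lora_moe_runners import LoRAHooks


class DeepGemmRunnerCore:  # illustrative stand-in for the real runner class
    def run(
        self,
        runner_input: Any,
        quant_info: Any,
        running_state: dict,
        hooks: Optional["LoRAHooks"] = None,  # string annotation, resolved lazily
    ) -> Any:
        ...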
runner_input: TritonRunnerInput,
quant_info: TritonMoeQuantInfo,
running_state: dict,
hooks: Optional[Any] = None,
To improve type safety, consider replacing Any with the specific "LoRAHooks" type hint for the hooks parameter. This will make the expected interface for hooks explicit.
You'll need to add from sglang.srt.lora.lora_moe_runners import LoRAHooks inside the TYPE_CHECKING block at the top of the file.
- hooks: Optional[Any] = None,
+ hooks: Optional["LoRAHooks"] = None,
runner_input: TritonKernelsRunnerInput,
quant_info: TritonKernelsQuantInfo,
running_state: dict,
hooks: Optional[Any] = None,
For better type safety and code clarity, it's recommended to use the specific "LoRAHooks" type hint instead of Any for the hooks parameter, even if it's not used in this particular implementation. This maintains API consistency and improves readability.
You would need to add from sglang.srt.lora.lora_moe_runners import LoRAHooks to the TYPE_CHECKING block.
- hooks: Optional[Any] = None,
+ hooks: Optional["LoRAHooks"] = None,
Force-pushed 8bec729 to bc9a320
Force-pushed 484c581 to 4291c5a
Refactor LoRA MoE runner from per-backend subclass (TritonRunnerCoreWithLoRA) to a generic hook-based injection pattern, decoupling LoRA logic from the base MoE backend. Add Marlin int4/int8 MoE backend for LoRA.
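For readers unfamiliar with the pattern, here is a minimal sketch of hook-based injection. The pre_run_hook/post_run_hook names mirror the PR description, but the dataclass fields, signatures, and the base expert computation are simplified placeholders, not the actual sglang implementation.

from dataclasses import dataclass
from typing import Callable, Optional

import torch


@dataclass
class RunnerInput:
    hidden_states: torch.Tensor
    # Optional LoRA callbacks injected by the LoRA layer. The base MoE runner
    # only invokes them, so it stays unaware of LoRA internals.
    pre_run_hook: Optional[Callable[[torch.Tensor], torch.Tensor]] = None
    post_run_hook: Optional[Callable[[torch.Tensor], torch.Tensor]] = None


class MoeRunner:
    def run(self, runner_input: RunnerInput) -> torch.Tensor:
        x = runner_input.hidden_states
        if runner_input.pre_run_hook is not None:
            x = runner_input.pre_run_hook(x)
        out = self._base_expert_compute(x)  # backend-specific (Triton, Marlin, ...)
        if runner_input.post_run_hook is not None:
            out = runner_input.post_run_hook(out)
        return out

    def _base_expert_compute(self, x: torch.Tensor) -> torch.Tensor:
        return x  # placeholder for the quantized expert GEMMs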
Force-pushed 4291c5a to 220d2d5
/tag-run-ci-label
Force-pushed 709045c to a3ddc91
/tag-run-ci-label
/rerun-failed-ci
/rerun-failed-ci
/rerun-failed-ci
/rerun-failed-ci
/rerun-failed-ci
Test Suite: stage-b-test-4-gpu-b200
/rerun-failed-ci
/rerun-failed-ci
/tag-run-ci-label
/rerun-failed-ci
/rerun-stage stage-c-test-4-gpu-b200
/rerun-stage stage-c-test-4-gpu-H100
✅ Triggered
❌ Stage
NVIDIA stages:
AMD stages:
Other stages will be added soon. For now, use
/rerun-stage stage-c-test-8-gpu-h200
✅ Triggered
/rerun-failed-ci
assert self.runner_core is not None

def _maybe_build_lora_hooks(_runner_input: Any) -> LoRAHooks:
@klshuster @yushengsu-thu Do not define a function in the forward / run critical path. Clean this up!
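One way to address this, sketched below under the assumption that the helper is currently defined inside the layer's forward/run method: hoist it to module scope so it is created once rather than on every call. The builder call and class structure here are hypothetical placeholders, not the PR's actual code.

from typing import Any, Optional


def _maybe_build_lora_hooks(runner_input: Any, lora_backend: Optional[Any]):
    """Build LoRA hooks for this batch, or return None when LoRA is inactive."""
    if lora_backend is None:
        return None
    return lora_backend.build_hooks(runner_input)  # hypothetical builder method


class MoELayerWithLoRA:  # illustrative wrapper, not the PR's class
    def __init__(self, runner_core: Any, lora_backend: Optional[Any] = None):
        self.runner_core = runner_core
        self.lora_backend = lora_backend

    def forward(self, runner_input: Any):
        assert self.runner_core is not None
        # The helper is now defined once at module scope instead of being
        # re-created inside the hot path on every forward call.
        hooks = _maybe_build_lora_hooks(runner_input, self.lora_backend)
        return self.runner_core.run(runner_input, hooks=hooks)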
Motivation
LoRA adapters applied to MoE layers currently couple LoRA injection logic into each backend-specific runner subclass (e.g., TritonRunnerCoreWithLoRA), making it difficult to add new backends. This PR decouples LoRA injection from the MoE backends via a generic hook-based pattern and adds a Marlin int4/int8 MoE backend for LoRA.

Modifications

Decoupled LoRA/MoE backends:
- Refactored lora/lora_moe_runners.py from a class-based TritonRunnerCoreWithLoRA (which replaced the runner) to a hooks-based architecture (LoRAHooks, build_lora_hooks). Hooks are injected into the MoE runner's RunnerInput, decoupling LoRA from the base MoE backend.
- MoeRunner now accepts a lora_enabled flag and pre_run_hook/post_run_hook in RunnerInput.
- The compressed_tensors quantization scheme is updated to expose both get_triton_quant_info and get_marlin_quant_info, with backend selection via get_moe_runner_backend().
- MarlinLoraRunnerCore (lora/lora_moe_runner_marlin.py) enables Marlin wNa16 GEMM for base experts when LoRA is active.

Bug fix:
- Changed <= to < (strict inequality), which fixes an IMA for Qwen3.

Accuracy Tests

All 24 parametrized tests pass (test_lora_moe_runner.py):
- test_lora_moe_runner_multi_expert (16 configs) — verifies the LoRA delta matches between the hook-based and baseline implementations
- test_lora_moe_runner_marlin (8 configs) — verifies the Marlin backend matches the Triton backend for base expert computation with LoRA
- test_marlin_lora_correctness.py — end-to-end correctness comparing the Marlin vs Triton LoRA backends

Checklist
Review and Merge Process
/tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci