22
33import logging
44import os
5- from typing import TYPE_CHECKING , Optional
5+ from typing import TYPE_CHECKING , Any , Optional
66
77from sglang .srt .layers .moe .moe_runner .base import (
88 FusedOpPool ,
1919 from sglang .srt .layers .moe .moe_runner .base import MoeQuantInfo
2020 from sglang .srt .layers .moe .token_dispatcher .base import CombineInput , DispatchOutput
2121 from sglang .srt .layers .moe .utils import MoeRunnerBackend
22+ from sglang .srt .lora .lora_moe_runners import LoRAHooks
23+
2224
2325logger = logging .getLogger (__name__ )
2426
@@ -37,18 +39,18 @@ def __init__(
3739 self .fused_func = None
3840
3941 if runner_backend .is_triton ():
40- if lora_enabled :
41- from sglang .srt .lora .lora_moe_runners import TritonRunnerCoreWithLoRA
42-
43- self .runner_core = TritonRunnerCoreWithLoRA (config )
44- else :
45- self .runner_core = TritonRunnerCore (config )
42+ self .runner_core = TritonRunnerCore (config )
4643 elif runner_backend .is_triton_kernels ():
4744 self .runner_core = TritonKernelsRunnerCore (config )
4845 elif runner_backend .is_deep_gemm ():
4946 self .runner_core = DeepGemmRunnerCore (config )
5047 elif runner_backend .is_marlin ():
51- self .runner_core = None # Marlin only supports fused path
48+ if lora_enabled :
49+ from sglang .srt .lora .lora_moe_runner_marlin import MarlinLoraRunnerCore
50+
51+ self .runner_core = MarlinLoraRunnerCore (config )
52+ else :
53+ self .runner_core = None # Marlin only supports fused path
5254 elif (
5355 runner_backend .is_flashinfer_trtllm ()
5456 or runner_backend .is_flashinfer_trtllm_routed ()
@@ -94,6 +96,41 @@ def run(
9496 return self .fused_func (dispatch_output , quant_info , self .config )
9597
9698 assert self .runner_core is not None
99+
100+ def _maybe_build_lora_hooks (_runner_input : Any ) -> LoRAHooks :
101+ if not self .lora_enabled or lora_info is None :
102+ return None
103+
104+ from sglang .srt .layers .moe .token_dispatcher .base import DispatchOutput
105+ from sglang .srt .lora .lora_moe_runners import build_lora_hooks
106+
107+ if isinstance (_runner_input , DispatchOutput ):
108+ hidden_states , topk_ids = (
109+ _runner_input .hidden_states ,
110+ _runner_input .topk_output .topk_ids ,
111+ )
112+ elif hasattr (_runner_input , "topk_ids" ):
113+ hidden_states , topk_ids = (
114+ _runner_input .hidden_states ,
115+ _runner_input .topk_ids ,
116+ )
117+ else :
118+ return None
119+
120+ return build_lora_hooks (
121+ hidden_states ,
122+ lora_info ,
123+ topk_ids ,
124+ )
125+
126+ # Runners that handle dispatch_output directly (e.g., MarlinRunnerCore)
127+ # bypass the pre-permute step and do their own alignment internally.
128+ if hasattr (self .runner_core , "run_from_dispatch" ):
129+ hooks = _maybe_build_lora_hooks (dispatch_output )
130+ return self .runner_core .run_from_dispatch (
131+ dispatch_output , quant_info , self .config , hooks = hooks
132+ )
133+
97134 dispatch_format = dispatch_output .format .value
98135 runner_format = self .runner_core .runner_backend .value
99136 self .pre_permute_func = PermuteMethodPool .get_pre_permute (
@@ -110,16 +147,11 @@ def run(
110147 dispatch_output , quant_info , self .config , running_state
111148 )
112149
113- # Pass lora_info to runner_core if LoRA is enabled
114- if self .lora_enabled :
115- runner_output = self .runner_core .run (
116- runner_input , quant_info , running_state , lora_info
117- )
118- else :
119- runner_output = self .runner_core .run (
120- runner_input , quant_info , running_state
121- )
150+ hooks = _maybe_build_lora_hooks (runner_input )
122151
152+ runner_output = self .runner_core .run (
153+ runner_input , quant_info , running_state , hooks = hooks
154+ )
123155 runner_format = self .runner_core .runner_backend .value
124156 combine_format = dispatch_output .format .value
125157 self .post_permute_func = PermuteMethodPool .get_post_permute (
0 commit comments