Commit 8da1cfb

[lora][moe] Decoupled LoRA MoE backend with Marlin support (#21858)
1 parent 78043d4 commit 8da1cfb

12 files changed: 1626 additions & 540 deletions

python/sglang/srt/layers/moe/moe_runner/base.py

Lines changed: 6 additions & 2 deletions
@@ -2,7 +2,7 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Optional, Tuple, TypeGuard
+from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, TypeGuard
 
 import torch
 
@@ -82,7 +82,11 @@ def __init__(self, config: MoeRunnerConfig):
 
     @abstractmethod
     def run(
-        self, runner_input: RunnerInput, quant_info: MoeQuantInfo, running_state: dict
+        self,
+        runner_input: RunnerInput,
+        quant_info: MoeQuantInfo,
+        running_state: dict,
+        hooks: Optional[Any] = None,
     ) -> RunnerOutput:
         pass
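
Every runner core now takes an optional hooks argument. The concrete container, LoRAHooks, lives in sglang.srt.lora.lora_moe_runners and is not shown on this page; below is a minimal sketch, assuming only what the Triton runner's call sites further down imply (a truthy object exposing after_gate_up and after_down callables).

from dataclasses import dataclass
from typing import Callable, Optional

import torch

# Hook signature inferred from the Triton runner's call sites:
# (gemm_input, gemm_output, topk_weights, topk_ids) -> None, with the hook
# mutating the GEMM output in place.
MoeHook = Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], None]


@dataclass
class LoRAHooksSketch:
    after_gate_up: Optional[MoeHook] = None  # after the w13 (gate/up) GEMM
    after_down: Optional[MoeHook] = None  # after the w2 (down) GEMM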

python/sglang/srt/layers/moe/moe_runner/deep_gemm.py

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional
 
 import torch
 
@@ -120,6 +120,7 @@ def run(
         runner_input: DeepGemmRunnerInput,
         quant_info: DeepGemmMoeQuantInfo,
         running_state: dict,
+        hooks: Optional[Any] = None,
     ) -> DeepGemmRunnerOutput:
         if not runner_input.use_masked_gemm:
             hidden_states = self._run_contiguous_gemm(

python/sglang/srt/layers/moe/moe_runner/runner.py

Lines changed: 49 additions & 17 deletions
@@ -2,7 +2,7 @@
 
 import logging
 import os
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 from sglang.srt.layers.moe.moe_runner.base import (
     FusedOpPool,
@@ -19,6 +19,8 @@
     from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo
     from sglang.srt.layers.moe.token_dispatcher.base import CombineInput, DispatchOutput
    from sglang.srt.layers.moe.utils import MoeRunnerBackend
+    from sglang.srt.lora.lora_moe_runners import LoRAHooks
+
 
 logger = logging.getLogger(__name__)
 
@@ -37,18 +39,18 @@ def __init__(
         self.fused_func = None
 
         if runner_backend.is_triton():
-            if lora_enabled:
-                from sglang.srt.lora.lora_moe_runners import TritonRunnerCoreWithLoRA
-
-                self.runner_core = TritonRunnerCoreWithLoRA(config)
-            else:
-                self.runner_core = TritonRunnerCore(config)
+            self.runner_core = TritonRunnerCore(config)
         elif runner_backend.is_triton_kernels():
             self.runner_core = TritonKernelsRunnerCore(config)
         elif runner_backend.is_deep_gemm():
             self.runner_core = DeepGemmRunnerCore(config)
         elif runner_backend.is_marlin():
-            self.runner_core = None  # Marlin only supports fused path
+            if lora_enabled:
+                from sglang.srt.lora.lora_moe_runner_marlin import MarlinLoraRunnerCore
+
+                self.runner_core = MarlinLoraRunnerCore(config)
+            else:
+                self.runner_core = None  # Marlin only supports fused path
         elif (
             runner_backend.is_flashinfer_trtllm()
             or runner_backend.is_flashinfer_trtllm_routed()
@@ -94,6 +96,41 @@ def run(
             return self.fused_func(dispatch_output, quant_info, self.config)
 
         assert self.runner_core is not None
+
+        def _maybe_build_lora_hooks(_runner_input: Any) -> LoRAHooks:
+            if not self.lora_enabled or lora_info is None:
+                return None
+
+            from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutput
+            from sglang.srt.lora.lora_moe_runners import build_lora_hooks
+
+            if isinstance(_runner_input, DispatchOutput):
+                hidden_states, topk_ids = (
+                    _runner_input.hidden_states,
+                    _runner_input.topk_output.topk_ids,
+                )
+            elif hasattr(_runner_input, "topk_ids"):
+                hidden_states, topk_ids = (
+                    _runner_input.hidden_states,
+                    _runner_input.topk_ids,
+                )
+            else:
+                return None
+
+            return build_lora_hooks(
+                hidden_states,
+                lora_info,
+                topk_ids,
+            )
+
+        # Runners that handle dispatch_output directly (e.g., MarlinRunnerCore)
+        # bypass the pre-permute step and do their own alignment internally.
+        if hasattr(self.runner_core, "run_from_dispatch"):
+            hooks = _maybe_build_lora_hooks(dispatch_output)
+            return self.runner_core.run_from_dispatch(
+                dispatch_output, quant_info, self.config, hooks=hooks
+            )
+
         dispatch_format = dispatch_output.format.value
         runner_format = self.runner_core.runner_backend.value
         self.pre_permute_func = PermuteMethodPool.get_pre_permute(
@@ -110,16 +147,11 @@ def run(
             dispatch_output, quant_info, self.config, running_state
         )
 
-        # Pass lora_info to runner_core if LoRA is enabled
-        if self.lora_enabled:
-            runner_output = self.runner_core.run(
-                runner_input, quant_info, running_state, lora_info
-            )
-        else:
-            runner_output = self.runner_core.run(
-                runner_input, quant_info, running_state
-            )
+        hooks = _maybe_build_lora_hooks(runner_input)
 
+        runner_output = self.runner_core.run(
+            runner_input, quant_info, running_state, hooks=hooks
+        )
         runner_format = self.runner_core.runner_backend.value
         combine_format = dispatch_output.format.value
         self.post_permute_func = PermuteMethodPool.get_post_permute(
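
_maybe_build_lora_hooks duck-types its input: a raw DispatchOutput carries its top-k ids nested under topk_output, while an already-permuted runner input exposes topk_ids directly; anything else yields no hooks and the runner proceeds without LoRA. A small illustration with stand-in objects (SimpleNamespace here is only a stand-in for the real dataclasses):

from types import SimpleNamespace

import torch

hidden = torch.randn(4, 8)
ids = torch.zeros((4, 2), dtype=torch.int64)

# Shape 1: DispatchOutput-like, ids nested under topk_output.
dispatch_like = SimpleNamespace(
    hidden_states=hidden,
    topk_output=SimpleNamespace(topk_ids=ids),
)
# Shape 2: permuted RunnerInput-like, ids exposed directly.
runner_input_like = SimpleNamespace(hidden_states=hidden, topk_ids=ids)

# These are the two attribute paths _maybe_build_lora_hooks reads.
assert dispatch_like.topk_output.topk_ids is ids
assert runner_input_like.topk_ids is ids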

python/sglang/srt/layers/moe/moe_runner/triton.py

Lines changed: 22 additions & 4 deletions
@@ -3,7 +3,7 @@
 import functools
 import os
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional
 
 import torch
 import triton.language as tl
@@ -124,6 +124,7 @@ def run(
         runner_input: TritonRunnerInput,
         quant_info: TritonMoeQuantInfo,
         running_state: dict,
+        hooks: Optional[Any] = None,
     ) -> TritonRunnerOutput:
 
         # TODO: move these functions to the triton runner
@@ -206,6 +207,11 @@ def run(
             block_shape=block_shape,
         )
 
+        if hooks and hooks.after_gate_up:
+            hooks.after_gate_up(
+                hidden_states, intermediate_cache1, topk_weights, topk_ids
+            )
+
         intermediate_cache2 = torch.empty(
             (M * topk_ids.shape[1], N // 2),
             device=hidden_states.device,
@@ -258,13 +264,16 @@ def run(
         else:
             out_hidden_states = torch.empty_like(hidden_states)
 
+        # When LoRA hooks are present, always write to intermediate_cache3
+        # so the hook can modify it before reduction.
+        _use_intermediate = not no_combine and (topk_ids.shape[1] != 1 or hooks)
         invoke_fused_moe_kernel(
             intermediate_cache2,
             w2,
             b2,
             (
                 intermediate_cache3
-                if not no_combine and topk_ids.shape[1] != 1
+                if _use_intermediate
                 else out_hidden_states.unsqueeze(0)
             ),
             a2_scale,
@@ -287,14 +296,23 @@ def run(
             block_shape=block_shape,
         )
 
+        if hooks and hooks.after_down:
+            hooks.after_down(
+                intermediate_cache2, intermediate_cache3, topk_weights, topk_ids
+            )
+
         if routed_scaling_factor is None:
             routed_scaling_factor = 1.0
 
         if no_combine:
             pass
         elif _is_cuda:
-            if topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0:
-                pass  # we write directly into out_hidden_states
+            if (
+                topk_ids.shape[1] == 1
+                and routed_scaling_factor == 1.0
+                and not _use_intermediate
+            ):
+                pass  # we wrote directly into out_hidden_states
             elif topk_ids.shape[1] == 2 and routed_scaling_factor == 1.0:
                 torch.add(
                     intermediate_cache3[:, 0],
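
The two hook points bracket the expert GEMMs: after_gate_up fires once the w13 (gate/up) GEMM has filled intermediate_cache1 and before the activation, and after_down fires once the w2 (down) GEMM has filled intermediate_cache3 and before the weighted top-k reduction, which is why _use_intermediate forces writes into intermediate_cache3 whenever hooks are set. A sketch of a compatible hook pair, assuming only the call signatures visible above (the real LoRA delta math comes from build_lora_hooks and is not shown here):

def after_gate_up_sketch(hidden_states, intermediate_cache1, topk_weights, topk_ids):
    # Runs between the gate/up GEMM and the activation; a real LoRA hook
    # adds the per-expert LoRA delta to intermediate_cache1 in place.
    pass


def after_down_sketch(intermediate_cache2, intermediate_cache3, topk_weights, topk_ids):
    # Runs between the down GEMM and the top-k reduction; in-place edits
    # to intermediate_cache3 flow into the combined output.
    pass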

python/sglang/srt/layers/moe/moe_runner/triton_kernels.py

Lines changed: 2 additions & 1 deletion
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 
@@ -84,6 +84,7 @@ def run(
         runner_input: TritonKernelsRunnerInput,
         quant_info: TritonKernelsQuantInfo,
         running_state: dict,
+        hooks: Optional[Any] = None,
     ) -> TritonKernelsRunnerOutput:
         from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
             triton_kernel_fused_experts,

python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 7 additions & 7 deletions
@@ -682,16 +682,13 @@ def get_moe_scheme(
                 logger.info_once("Using CompressedTensorsWNA16TritonMoE (ROCm)")
                 return CompressedTensorsWNA16TritonMoE(self)
             else:
-                from sglang.srt.server_args import get_global_server_args
-
-                server_args = get_global_server_args()
-                if server_args and server_args.enable_lora:
+                moe_backend = get_moe_runner_backend()
+                if moe_backend.is_triton():
                     logger.info_once(
-                        "Using CompressedTensorsWNA16TritonMoEMethod "
-                        "(LoRA requires triton-compatible MoE weights)"
+                        "Using CompressedTensorsWNA16TritonMoE "
+                        "(moe_runner_backend=triton)"
                    )
                     return CompressedTensorsWNA16TritonMoE(self)
-
                 logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
                 return CompressedTensorsWNA16MoE(self)
         else:
@@ -1010,6 +1007,9 @@ def create_moe_runner(
     def get_triton_quant_info(self, layer: torch.nn.Module):
         return layer.scheme.get_triton_quant_info(layer)
 
+    def get_marlin_quant_info(self, layer: torch.nn.Module):
+        return layer.scheme.get_marlin_quant_info(layer)
+
     def apply(
         self,
         layer: torch.nn.Module,
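
Scheme selection is now keyed off the configured MoE runner backend rather than the enable_lora flag, which is the decoupling this commit's title refers to: enabling LoRA no longer forces the Triton-compatible WNA16 weights. A condensed restatement of the branch above (illustrative sketch, not the actual code):

from sglang.srt.layers.moe.utils import get_moe_runner_backend


def pick_wna16_moe_scheme_sketch():
    # Triton backend explicitly selected -> Triton-compatible WNA16 scheme;
    # otherwise the Marlin scheme is used, which now also supports LoRA
    # via MarlinLoraRunnerCore.
    if get_moe_runner_backend().is_triton():
        return "CompressedTensorsWNA16TritonMoE"
    return "CompressedTensorsWNA16MoE"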

python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py

Lines changed: 17 additions & 0 deletions
@@ -354,6 +354,23 @@ def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
         self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.MARLIN, moe_runner_config)
+
+    def get_marlin_quant_info(self, layer):
+        from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo
+
+        return MarlinMoeQuantInfo(
+            w13_qweight=layer.w13_weight_packed,
+            w2_qweight=layer.w2_weight_packed,
+            w13_scales=layer.w13_weight_scale,
+            w2_scales=layer.w2_weight_scale,
+            w13_g_idx_sort_indices=getattr(layer, "w13_g_idx_sort_indices", None),
+            w2_g_idx_sort_indices=getattr(layer, "w2_g_idx_sort_indices", None),
+            weight_bits=self.num_bits,
+            w13_g_idx=getattr(layer, "w13_weight_g_idx", None),
+            w2_g_idx=getattr(layer, "w2_weight_g_idx", None),
+            is_k_full=self.is_k_full,
+        )
 
     def apply_weights(
         self,
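
The MarlinMoeQuantInfo container is defined in sglang/srt/layers/moe/moe_runner/marlin.py, which is not shown on this page. A sketch of its shape, with field names taken from the keyword arguments above and types as assumptions:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class MarlinMoeQuantInfoSketch:
    w13_qweight: torch.Tensor  # packed gate/up projection weights
    w2_qweight: torch.Tensor  # packed down projection weights
    w13_scales: torch.Tensor
    w2_scales: torch.Tensor
    w13_g_idx_sort_indices: Optional[torch.Tensor]
    w2_g_idx_sort_indices: Optional[torch.Tensor]
    weight_bits: int
    w13_g_idx: Optional[torch.Tensor]
    w2_g_idx: Optional[torch.Tensor]
    is_k_full: bool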

python/sglang/srt/lora/layers.py

Lines changed: 31 additions & 8 deletions
@@ -808,13 +808,20 @@ def __init__(
             getattr(base_layer.moe_runner_config, "gemm1_alpha", None) is not None
         )
 
-        # initialize triton_lora moe runner for batches with lora enabled
+        # Initialize triton_lora moe runner for batches with lora enabled
         from sglang.srt.layers.moe import MoeRunnerBackend
         from sglang.srt.layers.moe.moe_runner.runner import MoeRunner
-
-        qm = base_layer.quant_method
-        if hasattr(qm, "runner") and qm.runner is not None:
-            runner_backend = qm.runner.runner_backend
+        from sglang.srt.layers.moe.utils import get_moe_runner_backend
+
+        # Determine runner backend: prefer server arg, fall back to quant method's runner
+        global_backend = get_moe_runner_backend()
+        if not global_backend.is_auto():
+            runner_backend = global_backend
+        elif (
+            hasattr(base_layer.quant_method, "runner")
+            and base_layer.quant_method.runner is not None
+        ):
+            runner_backend = base_layer.quant_method.runner.runner_backend
         else:
             runner_backend = MoeRunnerBackend.TRITON
 
@@ -824,8 +831,25 @@ def __init__(
             lora_enabled=True,
         )
 
-        # Pre-compute quant info for efficiency (weights don't change during inference)
-        self._quant_info = base_layer.quant_method.get_triton_quant_info(base_layer)
+        if runner_backend.is_marlin():
+            from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
+                CompressedTensorsFusedMoEMethod,
+            )
+
+            assert isinstance(
+                base_layer.quant_method, CompressedTensorsFusedMoEMethod
+            ), (
+                f"Marlin MoE backend requires CompressedTensorsFusedMoEMethod, "
+                f"got {type(base_layer.quant_method).__name__}"
+            )
+            self._quant_info = base_layer.quant_method.get_marlin_quant_info(base_layer)
+        elif runner_backend.is_triton():
+            assert base_layer.quant_method is not None, "Quant method must be set"
+            self._quant_info = base_layer.quant_method.get_triton_quant_info(base_layer)
+        else:
+            raise NotImplementedError(
+                f"LoRA MoE not supported for backend {runner_backend}"
+            )
 
     def set_lora_info(
         self,
@@ -876,7 +900,6 @@ def _get_lora_info(self):
             num_experts=self.base_layer.num_experts,
             experts_shared_outer_loras=self.experts_shared_outer_loras,
             cg_buffers=cg_buffers,
-            has_active_lora=batch_info.has_active_lora,
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
             hidden_size=getattr(self.base_layer, "hidden_size", 0),
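
Backend resolution for the LoRA MoE runner now follows a fixed precedence: an explicit (non-auto) moe_runner_backend server argument wins, then the backend already chosen by the quant method's runner, then Triton as the default. A condensed restatement (illustrative sketch only):

def resolve_lora_moe_backend_sketch(global_backend, quant_method, default_backend):
    # 1. An explicit server argument (anything but AUTO) wins.
    if not global_backend.is_auto():
        return global_backend
    # 2. Otherwise follow the quant method's already-constructed runner.
    runner = getattr(quant_method, "runner", None)
    if runner is not None:
        return runner.runner_backend
    # 3. Fall back to the default (MoeRunnerBackend.TRITON above).
    return default_backend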
