Skip to content

Commit 6ed9c53

Browse files
committed
Add raw logits topk
1 parent 863f9d5 commit 6ed9c53

1 file changed

Lines changed: 42 additions & 9 deletions

File tree

  • python/sglang/srt/layers/moe

python/sglang/srt/layers/moe/topk.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
)
3030

3131
import torch
32+
import torch.nn.functional as F
3233

3334
try:
3435
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
@@ -443,6 +444,25 @@ def scoring_func_impl(gating_output: torch.Tensor) -> torch.Tensor:
443444
return topk_weights, topk_ids
444445

445446

447+
def fused_topk_softmax_torch_raw_logits(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    """Select top-k experts per token using the raw (pre-softmax) router logits.

    Unlike the softmax-first routing paths, top-k selection here runs directly
    on the unnormalized logits; a softmax is applied afterwards, over only the
    k selected logits, and only when ``renormalize`` is set.

    Args:
        hidden_states: Token activations; only ``shape[0]`` (the token count)
            is used, as a sanity check against ``gating_output``.
        gating_output: Router logits of shape ``(num_tokens, num_experts)``.
        topk: Number of experts to select per token.
        renormalize: If True, softmax the selected logits so each row of the
            returned weights sums to 1.

    Returns:
        Tuple ``(topk_weights, topk_ids)`` of shape ``(num_tokens, topk)``
        with dtypes float32 and int32 respectively. Expert order within a
        row is unspecified (``sorted=False``).
    """
    assert (
        hidden_states.shape[0] == gating_output.shape[0]
    ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"

    _, topk_ids = torch.topk(gating_output, k=topk, dim=-1, sorted=False)
    # Gather the raw logits of the selected experts. Use dim=-1 to stay
    # consistent with the topk call above (the original mixed dim=-1 / dim=1).
    topk_weights = gating_output.float().gather(-1, topk_ids)
    if renormalize:
        topk_weights = F.softmax(topk_weights, dim=-1, dtype=torch.float32)

    # topk_weights is already float32 (from .float() / softmax), so no extra
    # cast is needed; downstream kernels expect int32 expert ids.
    return topk_weights, topk_ids.to(torch.int32)
464+
465+
446466
def fused_topk_cpu(
447467
hidden_states: torch.Tensor,
448468
gating_output: torch.Tensor,
@@ -1030,15 +1050,28 @@ def select_experts(
10301050
)
10311051
elif custom_routing_function is None:
10321052
assert not apply_routed_scaling_factor_on_output, "Not implemented"
1033-
# Qwen3MOE uses fused_topk
1034-
topk_weights, topk_ids = fused_topk(
1035-
hidden_states=hidden_states,
1036-
gating_output=router_logits,
1037-
topk=num_routed_topk if _use_aiter else top_k,
1038-
renormalize=renormalize,
1039-
correction_bias=correction_bias,
1040-
scoring_func=scoring_func,
1041-
)
1053+
if (
1054+
get_moe_runner_backend().is_flashinfer_trtllm_routed()
1055+
and scoring_func == "softmax"
1056+
and correction_bias is None
1057+
):
1058+
# flashinfer_trtllm_routed uses raw-logits topk
1059+
topk_weights, topk_ids = fused_topk_softmax_torch_raw_logits(
1060+
hidden_states=hidden_states,
1061+
gating_output=router_logits,
1062+
topk=num_routed_topk if _use_aiter else top_k,
1063+
renormalize=renormalize,
1064+
)
1065+
else:
1066+
# Qwen3MOE uses fused_topk
1067+
topk_weights, topk_ids = fused_topk(
1068+
hidden_states=hidden_states,
1069+
gating_output=router_logits,
1070+
topk=num_routed_topk if _use_aiter else top_k,
1071+
renormalize=renormalize,
1072+
correction_bias=correction_bias,
1073+
scoring_func=scoring_func,
1074+
)
10421075
else:
10431076
assert (
10441077
num_token_non_padded is None

0 commit comments

Comments
 (0)