-
Notifications
You must be signed in to change notification settings - Fork 5.8k
Expand file tree
/
Copy pathtopk.py
More file actions
1429 lines (1262 loc) · 50.6 KB
/
topk.py
File metadata and controls
1429 lines (1262 loc) · 50.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import annotations
import logging
import math
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import (
TYPE_CHECKING,
Callable,
NamedTuple,
Optional,
Protocol,
TypeGuard,
runtime_checkable,
)
import torch
import torch.nn.functional as F
try:
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
except ImportError:
pass
from sglang.jit_kernel.deepseek_v4 import mask_topk_ids
from sglang.srt.distributed import (
get_moe_expert_parallel_rank,
get_moe_expert_parallel_world_size,
get_tp_group,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.environ import envs
from sglang.srt.eplb import expert_location_dispatch
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location_dispatch import (
ExpertLocationDispatchInfo,
topk_ids_logical_to_physical,
)
from sglang.srt.layers.dp_attention import is_allocation_symmetric
from sglang.srt.layers.moe import get_moe_runner_backend
from sglang.srt.layers.moe.utils import is_deepep_class_backend
from sglang.srt.layers.utils import MultiPlatformOp
from sglang.srt.state_capturer.routed_experts import get_global_experts_capturer
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
get_compiler_backend,
is_cpu,
is_cuda,
is_hip,
is_musa,
is_npu,
is_xpu,
)
from sglang.srt.utils.patch_torch import register_fake_if_exists
if TYPE_CHECKING:
from sglang.srt.layers.quantization import QuantizationConfig
logger = logging.getLogger(__name__)

# Platform / backend feature flags, evaluated once at import time.
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_cpu = is_cpu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_xpu = is_xpu()
_is_npu = is_npu()
# aiter kernels are only used on HIP and only when SGLANG_USE_AITER is enabled.
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_musa = is_musa()
if _is_cuda:
    # Fused grouped-gate kernel (called by biased_grouped_topk_gpu).
    from sgl_kernel import moe_fused_gate

    try:
        from flashinfer.fused_moe import fused_topk_deepseek as _fused_topk_deepseek

        from sglang.srt.utils.custom_op import register_custom_op

        # Register flashinfer's kernel as a custom op. It returns nothing and
        # writes its results into the caller-provided topk_weights / topk_ids
        # (declared via mutates_args).
        @register_custom_op(
            op_name="fused_topk_deepseek",
            mutates_args=["topk_weights", "topk_ids"],
        )
        def fused_topk_deepseek(
            gating_output: torch.Tensor,
            correction_bias: torch.Tensor,
            num_expert_group: int,
            topk_group: int,
            topk: int,
            scaling_factor: float,
            topk_weights: torch.Tensor,
            topk_ids: torch.Tensor,
            renormalize: bool,
        ) -> None:
            """Forward every argument unchanged to flashinfer's ``fused_topk_deepseek``.

            Outputs are written in place into ``topk_weights`` and ``topk_ids``.
            """
            _fused_topk_deepseek(
                gating_output,
                correction_bias,
                num_expert_group,
                topk_group,
                topk,
                scaling_factor,
                topk_weights,
                topk_ids,
                renormalize,
            )

    except ImportError:
        # flashinfer (or the custom-op helper) is unavailable; callers test
        # `fused_topk_deepseek is not None` before taking this fast path.
        fused_topk_deepseek = None

    try:
        from sgl_kernel import kimi_k2_moe_fused_gate
    except ImportError as e:
        # Optional kernel; its absence is tolerated silently.
        pass
if _is_cuda or _is_hip or _is_xpu:
from sgl_kernel import topk_softmax
try:
from sgl_kernel import topk_sigmoid
except ImportError:
pass
if _use_aiter:
try:
from aiter import biased_grouped_topk as aiter_biased_grouped_topk
from aiter.fused_moe import fused_topk as aiter_fused_topk
except ImportError:
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
if _is_musa:
try:
from mate import moe_fused_gate
except ImportError as e:
raise ImportError("mate is required for the biased grouped topk.")
from sglang.srt.hardware_backend.musa.kernels.topk import topk_sigmoid, topk_softmax
# -------------------------------- TopKConfig ---------------------------------------
@dataclass
class TopKConfig:
    """Routing hyper-parameters consumed by the top-k selection paths.

    Built by ``TopK.__init__``. Not frozen: ``TopK.forward_native`` and
    ``TopK.forward_cuda`` overwrite ``torch_native`` at call time.
    """

    # Experts selected per token, including any fused shared experts.
    top_k: int
    # Use group-limited routing (TopK asserts topk_group / num_expert_group are set).
    use_grouped_topk: bool = False
    # Number of expert groups kept per token when grouped routing is used.
    topk_group: Optional[int] = None
    # Number of equal-size groups the experts are split into.
    num_expert_group: Optional[int] = None
    # Rescale the selected weights so they sum to one per token.
    renormalize: bool = True
    # Number of shared experts appended to the routed selection.
    num_fused_shared_experts: int = 0
    # Optional user routing callable; its consumer is not visible in this module chunk.
    custom_routing_function: Optional[Callable] = None
    # Per-expert bias added to the scores for expert *selection* only.
    correction_bias: Optional[torch.Tensor] = None
    # Set by TopK: True in forward_native, False on forward_cuda's STANDARD path.
    torch_native: bool = False
    # Scaling factor for routed-expert weights.
    routed_scaling_factor: Optional[float] = None
    # Multiply the final weights by routed_scaling_factor inside the top-k op.
    apply_routed_scaling_factor_on_output: bool = False
    # Scaling factor for fused shared experts (AMD platform, per TopK's docstring).
    fused_shared_experts_scaling_factor: Optional[float] = None
    # When not None, forces the TopKOutputFormat chosen in TopK.forward_cuda.
    output_format: Optional[TopKOutputFormat] = None
    # Score function name; helpers here accept "softmax", "sigmoid" and
    # (biased_topk_impl) "sqrtsoftplus".
    scoring_func: str = "softmax"
# -------------------------------- TopKOutput ---------------------------------------
class TopKOutputChecker:
    """Static type-guard predicates that narrow a ``TopKOutput`` to one concrete format.

    Every predicate is a plain ``isinstance`` test against the matching NamedTuple.
    """

    @staticmethod
    def _matches(topk_output, output_cls) -> bool:
        """Shared ``isinstance`` check used by the public predicates."""
        return isinstance(topk_output, output_cls)

    @staticmethod
    def format_is_standard(topk_output: TopKOutput) -> TypeGuard[StandardTopKOutput]:
        """True iff ``topk_output`` is a ``StandardTopKOutput``."""
        return TopKOutputChecker._matches(topk_output, StandardTopKOutput)

    @staticmethod
    def format_is_triton_kernels(
        topk_output: TopKOutput,
    ) -> TypeGuard[TritonKernelTopKOutput]:
        """True iff ``topk_output`` is a ``TritonKernelTopKOutput``."""
        return TopKOutputChecker._matches(topk_output, TritonKernelTopKOutput)

    @staticmethod
    def format_is_bypassed(topk_output: TopKOutput) -> TypeGuard[BypassedTopKOutput]:
        """True iff ``topk_output`` is a ``BypassedTopKOutput``."""
        return TopKOutputChecker._matches(topk_output, BypassedTopKOutput)
class TopKOutputFormat(IntEnum):
    """Tag identifying which concrete top-k output container is in use.

    Values are written out explicitly and equal what ``auto()`` assigns
    (1, 2, 3 in declaration order).
    """

    STANDARD = 1
    TRITON_KERNEL = 2
    BYPASSED = 3
@runtime_checkable
class TopKOutput(Protocol):
    """Protocol for top-k outputs in different formats.

    Any object exposing a ``format`` property that returns a ``TopKOutputFormat``
    satisfies it. ``@runtime_checkable`` allows ``isinstance`` checks, which only
    test for the attribute's presence.
    """

    @property
    def format(self) -> TopKOutputFormat:
        """The format of the output."""
        ...
class StandardTopKOutput(NamedTuple):
    """Standard top-k output format.

    Field order is part of the tuple interface (callers may unpack positionally).
    """

    # Routing weight for each selected expert, one row per token.
    topk_weights: torch.Tensor
    # Selected expert ids, same shape as topk_weights (TopK.empty_topk_output uses -1 fill).
    topk_ids: torch.Tensor
    # Router logits carried alongside the selection.
    router_logits: torch.Tensor

    @property
    def format(self) -> TopKOutputFormat:
        """Always ``TopKOutputFormat.STANDARD``."""
        return TopKOutputFormat.STANDARD
class TritonKernelTopKOutput(NamedTuple):
    """Triton kernel top-k output format.

    Holds, in order, the three values returned by ``triton_kernels.routing.routing``
    (see ``TopK.forward_cuda``).
    """

    routing_data: RoutingData
    gather_indx: GatherIndx
    scatter_indx: ScatterIndx

    @property
    def format(self) -> TopKOutputFormat:
        """Always ``TopKOutputFormat.TRITON_KERNEL``."""
        return TopKOutputFormat.TRITON_KERNEL
class BypassedTopKOutput(NamedTuple):
    """Bypassed top-k output format.

    No top-k is computed. The raw inputs and routing config are passed through
    unchanged; presumably the MoE backend performs routing itself (TopK.forward_cuda
    selects this format for the flashinfer trtllm / mxfp4 backends).
    """

    hidden_states: torch.Tensor
    router_logits: torch.Tensor
    topk_config: TopKConfig
    # Number of real (non-padding) tokens, if padding is present.
    num_token_non_padded: Optional[torch.Tensor] = None
    # Logical-to-physical expert mapping info, if expert-location dispatch is active.
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None

    @property
    def format(self) -> TopKOutputFormat:
        """Always ``TopKOutputFormat.BYPASSED``."""
        return TopKOutputFormat.BYPASSED
# -------------------------------- TopK ---------------------------------------
class TopK(MultiPlatformOp):
    """
    Router top-k expert selection for MoE layers.

    The per-platform entry points are ``forward_native``, ``forward_cuda``,
    ``forward_cpu`` and ``forward_npu``. Which one runs is decided by the
    ``MultiPlatformOp`` base class (defined elsewhere).

    Parameters:
    --top_k: The total number of top experts selected per token, including the fused shared expert(s).
    --num_fused_shared_experts: number of shared experts fused into the routed selection; can be active in both TP and EP mode.
    --routed_scaling_factor: the scaling factor for routed experts in topk_weights.
    --fused_shared_experts_scaling_factor: scaling factor for fused shared experts on AMD-platform.
    --output_format: if set, forces the TopKOutputFormat returned by forward_cuda.
    --is_fp4_experts: only consulted for the flashinfer_mxfp4 backend (see forward_cuda).
    --quant_config: accepted but not referenced by this class.
    """

    def __init__(
        self,
        top_k: int,
        *,
        layer_id: Optional[int] = None,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        renormalize: bool = True,
        num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        correction_bias: Optional[torch.Tensor] = None,
        quant_config: Optional[QuantizationConfig] = None,
        routed_scaling_factor: Optional[float] = None,
        apply_routed_scaling_factor_on_output: Optional[bool] = False,
        output_format: Optional[TopKOutputFormat] = None,
        fused_shared_experts_scaling_factor: Optional[float] = None,
        is_fp4_experts: bool = False,
    ):
        # NOTE(review): an older note here said scoring_func was unused. It is now
        # forwarded into TopKConfig below; confirm select_experts consumes it.
        # Background: https://github.com/sgl-project/sglang/pull/4505
        super().__init__()
        if use_grouped_topk:
            # Grouped routing needs both the group count and the number of groups kept.
            assert num_expert_group is not None and topk_group is not None
        self.layer_id = layer_id
        # flashinfer_mxfp4 backend only: True -> STANDARD (Mxfp4FlashinferTrtllmMoEMethod
        # consumes), False -> BYPASSED (flashinfer's own mxfp4 kernel). No-op otherwise.
        self.is_fp4_experts = is_fp4_experts
        # One mutable config object shared by every forward_* call.
        # forward_native and forward_cuda overwrite its torch_native flag.
        self.topk_config = TopKConfig(
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            num_fused_shared_experts=num_fused_shared_experts,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            routed_scaling_factor=routed_scaling_factor,
            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
            fused_shared_experts_scaling_factor=fused_shared_experts_scaling_factor,
            output_format=output_format,
            scoring_func=scoring_func,
        )

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        *,
        num_token_non_padded: Optional[torch.Tensor] = None,
        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    ) -> TopKOutput:
        """Pure-PyTorch routing through ``select_experts``.

        Sets ``self.topk_config.torch_native = True`` before delegating. The
        change persists on the shared config object.
        """
        self.topk_config.torch_native = True
        return select_experts(
            hidden_states=hidden_states,
            layer_id=self.layer_id,
            router_logits=router_logits,
            topk_config=self.topk_config,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
        )

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        *,
        num_token_non_padded: Optional[torch.Tensor] = None,
        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    ) -> TopKOutput:
        """CUDA routing; the output container depends on the MoE runner backend.

        Format precedence:
        1. An explicit ``topk_config.output_format``.
        2. TRITON_KERNEL for the triton_kernels backend.
        3. BYPASSED for flashinfer_trtllm, or for flashinfer_mxfp4 when
           ``is_fp4_experts`` is False.
        4. Otherwise STANDARD.
        """
        if self.topk_config.output_format is not None:
            output_format = self.topk_config.output_format
        elif get_moe_runner_backend().is_triton_kernels():
            output_format = TopKOutputFormat.TRITON_KERNEL
        elif get_moe_runner_backend().is_flashinfer_trtllm() or (
            get_moe_runner_backend().is_flashinfer_mxfp4() and not self.is_fp4_experts
        ):
            output_format = TopKOutputFormat.BYPASSED
        else:
            output_format = TopKOutputFormat.STANDARD

        if output_format == TopKOutputFormat.TRITON_KERNEL:
            # renormalize=True is equivalent to sm_first=False
            routing_data, gather_idx, scatter_idx = routing(
                router_logits,
                self.topk_config.top_k,
                sm_first=not self.topk_config.renormalize,
            )
            return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
        elif output_format == TopKOutputFormat.BYPASSED:
            # No routing here: hand the raw inputs and config downstream untouched.
            return BypassedTopKOutput(
                hidden_states=hidden_states,
                router_logits=router_logits,
                topk_config=self.topk_config,
                num_token_non_padded=num_token_non_padded,
                expert_location_dispatch_info=expert_location_dispatch_info,
            )
        else:
            self.topk_config.torch_native = False
            # Allocate the outputs inside the TP group's symmetric-memory context
            # unless symmetric allocation is disabled.
            with use_symmetric_memory(
                get_tp_group(), disabled=not is_allocation_symmetric()
            ):
                topk_output = select_experts(
                    hidden_states=hidden_states,
                    layer_id=self.layer_id,
                    router_logits=router_logits,
                    topk_config=self.topk_config,
                    num_token_non_padded=num_token_non_padded,
                    expert_location_dispatch_info=expert_location_dispatch_info,
                )
            return topk_output

    def forward_cpu(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        *,
        num_token_non_padded: Optional[torch.Tensor] = None,
        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    ) -> TopKOutput:
        """CPU routing through ``select_experts``.

        Unlike forward_native and forward_cuda, this leaves
        ``topk_config.torch_native`` unchanged.
        """
        return select_experts(
            hidden_states=hidden_states,
            layer_id=self.layer_id,
            router_logits=router_logits,
            topk_config=self.topk_config,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
        )

    def forward_npu(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        *,
        num_token_non_padded: Optional[torch.Tensor] = None,
        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    ) -> TopKOutput:
        """NPU routing; delegates to ``fused_topk_npu``, which is imported lazily."""
        from sglang.srt.hardware_backend.npu.moe.topk import fused_topk_npu

        return fused_topk_npu(
            hidden_states=hidden_states,
            router_logits=router_logits,
            topk_config=self.topk_config,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
            layer_id=self.layer_id,
        )

    def empty_topk_output(self, device: torch.device) -> TopKOutput:
        """Return a zero-token ``StandardTopKOutput`` allocated on ``device``.

        The column count is ``top_k - num_fused_shared_experts``, i.e. routed experts only.
        """
        topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts
        with use_symmetric_memory(
            get_tp_group(), disabled=not is_allocation_symmetric()
        ):
            topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
            topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device)
        # FIXME: router_logits should be of size (0, num_experts)
        router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
        return StandardTopKOutput(topk_weights, topk_ids, router_logits)
# ------------------------------- TopK implementation -------------------------------------
def fused_topk_torch_native(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    correction_bias: Optional[torch.Tensor] = None,
    scoring_func: str = "softmax",
):
    """Pure-PyTorch reference top-k routing.

    Args:
        hidden_states: token activations. Only its leading (token) dimension is
            used, and only on the no-bias path, to check it matches ``gating_output``.
        gating_output: router logits indexed ``[token, expert]``.
        topk: number of experts selected per token.
        renormalize: if True, divide the selected weights by their per-token sum.
        correction_bias: optional per-expert bias. It is added to the scores for
            *selection* only; the returned weights are the unbiased scores.
        scoring_func: ``"softmax"`` or ``"sigmoid"``.

    Returns:
        ``(topk_weights, topk_ids)``. The ids keep ``torch.topk``'s int64 dtype.
        Weights are float32 on the no-bias path and use ``gating_output``'s
        dtype on the bias path.

    Raises:
        ValueError: for an unknown ``scoring_func``.
    """

    def scoring_func_impl(gating_output: torch.Tensor) -> torch.Tensor:
        if scoring_func == "softmax":
            return gating_output.softmax(dim=-1)
        elif scoring_func == "sigmoid":
            return gating_output.sigmoid()
        else:
            raise ValueError(f"Invalid scoring function: {scoring_func}")

    if correction_bias is not None:
        n_routed_experts = gating_output.shape[-1]
        scores = scoring_func_impl(gating_output)
        # The bias only steers which experts are picked, not their weights.
        scores_for_choice = scores.view(
            -1, n_routed_experts
        ) + correction_bias.unsqueeze(0)
        topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
        topk_weights = scores.gather(1, topk_ids)
    else:
        assert (
            hidden_states.shape[0] == gating_output.shape[0]
        ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
        # torch.topk allocates its own outputs, so no pre-allocation is needed.
        scores = scoring_func_impl(gating_output.float())
        topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
def fused_topk_softmax_torch_raw_logits(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    """Pick experts by raw logit; optionally softmax over just the picked logits.

    The top-k is taken on the un-normalized ``gating_output``. With
    ``renormalize=False`` the returned weights are the raw selected logits (as
    float32). With ``renormalize=True`` a softmax is applied across the k
    selected logits only.

    Returns:
        ``(weights float32, ids int32)``. The per-row order is unspecified
        because ``sorted=False``.
    """
    assert (
        hidden_states.shape[0] == gating_output.shape[0]
    ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
    chosen = torch.topk(gating_output, k=topk, dim=-1, sorted=False).indices
    picked_logits = gating_output.float().gather(1, chosen)
    if renormalize:
        picked_logits = picked_logits.softmax(dim=-1, dtype=torch.float32)
    return picked_logits.to(torch.float32), chosen.to(torch.int32)
def fused_topk_cpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    correction_bias: torch.Tensor = None,
    scoring_func: str = "softmax",
):
    """Softmax top-k routing through the ``sgl_kernel`` CPU extension op.

    ``correction_bias`` and ``scoring_func`` are accepted for signature parity
    with the other top-k helpers but are not passed to the kernel.

    Returns:
        ``(topk_weights, topk_ids)`` as produced by ``topk_softmax_cpu``.
    """
    cpu_topk_softmax = torch.ops.sgl_kernel.topk_softmax_cpu
    weights, ids = cpu_topk_softmax(
        hidden_states=hidden_states,
        gating_output=gating_output,
        topk=topk,
        renormalize=renormalize,
    )
    return weights, ids
def apply_topk_weights_cpu(need_apply, topk_weights, inputs):
    """Optionally pre-multiply ``inputs`` by ``topk_weights``.

    When ``need_apply`` is truthy, returns ``inputs * topk_weights`` (weights cast
    to the inputs' dtype) together with an all-ones float32 weight tensor, so the
    weights are not applied a second time downstream. Otherwise both arguments
    are returned unchanged.
    """
    if need_apply:
        # TODO: fuse below processing in fused_experts_cpu kernel
        scaled_inputs = inputs * topk_weights.to(inputs.dtype)
        neutral_weights = torch.ones_like(topk_weights, dtype=torch.float32)
        return scaled_inputs, neutral_weights
    return inputs, topk_weights
def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    correction_bias: Optional[torch.Tensor] = None,
    scoring_func: str = "softmax",
):
    """Kernel-backed (non-grouped) top-k routing.

    Pre-allocates ``[num_tokens, topk]`` float32 weights and int32 ids, then
    dispatches on ``scoring_func`` and on whether aiter is enabled:

    * softmax + aiter  -> ``aiter_fused_topk`` (its return value is used)
    * softmax          -> ``topk_softmax`` (fills the buffers in place)
    * sigmoid + aiter + correction_bias -> ``aiter_biased_grouped_topk`` with a
      single group (fills in place)
    * sigmoid          -> ``topk_sigmoid`` (fills in place)

    Raises:
        ValueError: for any other ``scoring_func``.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    num_tokens, _ = hidden_states.shape
    out_device = hidden_states.device
    weights = torch.empty(num_tokens, topk, dtype=torch.float32, device=out_device)
    ids = torch.empty(num_tokens, topk, dtype=torch.int32, device=out_device)

    if scoring_func == "softmax" and _use_aiter:
        # Use fused_topk instead of topk_softmax to auto dispatch to the correct kernel
        weights, ids = aiter_fused_topk(
            hidden_states,
            gating_output,
            topk,
            renormalize,
            topk_ids=ids,
            topk_weights=weights,
        )
    elif scoring_func == "softmax":
        topk_softmax(weights, ids, gating_output, renormalize)
    elif scoring_func == "sigmoid" and _use_aiter and correction_bias is not None:
        aiter_biased_grouped_topk(
            gating_output,
            correction_bias.to(dtype=gating_output.dtype),
            weights,
            ids,
            num_expert_group=1,
            topk_group=1,
            need_renorm=renormalize,
        )
    elif scoring_func == "sigmoid":
        topk_sigmoid(weights, ids, gating_output, renormalize, correction_bias)
    else:
        raise ValueError(f"Invalid scoring function: {scoring_func}")
    return weights, ids
# This is used by the Deepseek V2/V3/R1 series models
@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
def grouped_topk_gpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
    """Group-limited softmax top-k routing in plain PyTorch (compiled).

    The experts (columns of ``gating_output``) are split into
    ``num_expert_group`` contiguous equal-size groups. Each group is scored by its
    best expert, the ``topk_group`` best groups are kept, and ``topk`` experts are
    chosen among them.

    When ``num_fused_shared_experts`` is non-zero, the last output column is
    replaced by a random id in ``[num_experts, num_experts + num_fused_shared_experts)``.
    If ``routed_scaling_factor`` is set, that column's weight becomes
    ``sum(routed weights) / routed_scaling_factor``.

    Returns:
        ``(topk_weights float32, topk_ids int32)``, each with ``topk`` columns.

    NOTE(review): num_expert_group / topk_group default to None but are used
    unconditionally, so callers must always pass them.
    NOTE(review): apply_routed_scaling_factor_on_output=True with
    routed_scaling_factor=None would fail on the in-place multiply.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    scores = torch.softmax(gating_output, dim=-1)
    num_token = scores.shape[0]
    num_experts = scores.shape[1]
    # Score each group by its single best expert.
    group_scores = (
        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    # Broadcast the per-group keep mask to every expert in the group.
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    # Experts in dropped groups get score 0 so they are not selected.
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    # Sorted when a shared expert is fused, so the last column is the
    # lowest-ranked routed pick, which is the one overwritten below.
    topk_weights, topk_ids = torch.topk(
        tmp_scores,
        k=topk,
        dim=-1,
        sorted=(True if num_fused_shared_experts > 0 else False),
    )
    if num_fused_shared_experts:
        topk_ids[:, -1] = torch.randint(
            low=num_experts,
            high=num_experts + num_fused_shared_experts,
            size=(topk_ids.size(0),),
            dtype=topk_ids.dtype,
            device=topk_ids.device,
        )
        if routed_scaling_factor is not None:
            topk_weights[:, -1] = (
                topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
            )

    if renormalize:
        # Normalize by the routed columns only; the shared-expert column is excluded.
        topk_weights_sum = (
            topk_weights.sum(dim=-1, keepdim=True)
            if num_fused_shared_experts == 0
            else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
        )
        topk_weights = topk_weights / topk_weights_sum
        if apply_routed_scaling_factor_on_output:
            topk_weights *= routed_scaling_factor

    topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
    return topk_weights, topk_ids
def grouped_topk_cpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
    """Grouped top-k routing through the ``sgl_kernel`` CPU extension op.

    Output-side scaling is not supported on this path. Passing
    ``apply_routed_scaling_factor_on_output=True`` fails the assertion before
    the kernel is looked up.
    """
    assert not apply_routed_scaling_factor_on_output
    cpu_grouped_topk = torch.ops.sgl_kernel.grouped_topk_cpu
    # num_token_non_padded must be None since it is not supported in kernel
    return cpu_grouped_topk(
        hidden_states,
        gating_output,
        topk,
        renormalize,
        num_expert_group,
        topk_group,
        num_fused_shared_experts,
        routed_scaling_factor,
        num_token_non_padded=None,
    )
@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
def kimi_k2_biased_topk_impl(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    routed_scaling_factor: Optional[float] = None,
    apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
    """
    Optimized version for num_expert_group=1 case (e.g., Kimi K2 with 384 experts).
    Simplifies the grouped topk logic by removing unnecessary group masking operations.
    Note: This function assumes num_fused_shared_experts=0.

    Experts are chosen by ``sigmoid(gating_output) + correction_bias`` and
    weighted by the unbiased sigmoid scores.

    Returns:
        ``(topk_weights float32, topk_ids int32)``; per-row order is unspecified
        because ``sorted=False``.

    NOTE(review): apply_routed_scaling_factor_on_output=True with
    routed_scaling_factor=None would fail on the in-place multiply.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    scores = gating_output.sigmoid()
    num_token = scores.shape[0]

    # When num_expert_group=1, no need for group masking
    # Directly compute scores with correction bias
    tmp_scores = scores.view(num_token, -1) + correction_bias.unsqueeze(0)

    # Directly select topk experts (no need to sort since num_fused_shared_experts=0)
    _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
    # Weights come from the unbiased scores; the bias only affects selection.
    topk_weights = scores.gather(1, topk_ids)

    if renormalize:
        topk_weights_sum = topk_weights.sum(dim=-1, keepdim=True)
        topk_weights = topk_weights / topk_weights_sum
        if apply_routed_scaling_factor_on_output:
            topk_weights *= routed_scaling_factor

    topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
    return topk_weights, topk_ids
@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
def biased_topk_impl(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    scoring_func: str = "sigmoid",
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
    """Ungrouped, bias-corrected top-k routing in plain PyTorch (compiled).

    Experts are *chosen* by ``scores + correction_bias`` but *weighted* by the
    unbiased ``scores``.

    Args:
        hidden_states: only its leading (token) dimension is checked against
            ``gating_output``.
        gating_output: router logits indexed ``[token, expert]``.
        correction_bias: per-expert bias added to the scores for selection only.
        topk: number of output columns. When ``num_fused_shared_experts > 0``
            this includes the shared-expert slot.
        renormalize: divide the weights by the sum of the routed weights.
        scoring_func: ``"sigmoid"`` or ``"sqrtsoftplus"``.
        num_fused_shared_experts: when non-zero, the last column is replaced by
            a random id in ``[num_experts, num_experts + num_fused_shared_experts)``.
        routed_scaling_factor: the shared-expert weight is
            ``sum(routed) / routed_scaling_factor``. Also used as the output
            multiplier when ``apply_routed_scaling_factor_on_output`` is set.
        num_token_non_padded: rows at or beyond this count are masked
            (set to -1 on the non-kernel path).
        expert_location_dispatch_info: used to map logical to physical expert ids.
        apply_routed_scaling_factor_on_output: multiply the final weights by
            ``routed_scaling_factor``. A None factor is treated as 1.0, matching
            the ``moe_fused_gate`` call in ``biased_grouped_topk_gpu``.

    Returns:
        ``(topk_weights float32, topk_ids int32)``.

    Raises:
        ValueError: for an unsupported ``scoring_func``.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    if scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    elif scoring_func == "sqrtsoftplus":
        scores = torch.nn.functional.softplus(gating_output).sqrt()
    else:
        # Previously `scores` was left unbound here and the call died later
        # with an opaque UnboundLocalError.
        raise ValueError(f"Invalid scoring function: {scoring_func}")
    num_token = scores.shape[0]
    num_experts = scores.shape[1]
    # The bias only steers which experts are picked, not their weights.
    scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
    # Sorted when a shared expert is fused, so the last column is the
    # lowest-ranked routed pick, which is the one overwritten below.
    _, topk_ids = torch.topk(
        scores_for_choice,
        k=topk,
        dim=-1,
        sorted=num_fused_shared_experts > 0,
    )
    topk_weights = scores.gather(1, topk_ids)

    if num_fused_shared_experts:
        topk_ids[:, -1] = torch.randint(
            low=num_experts,
            high=num_experts + num_fused_shared_experts,
            size=(topk_ids.size(0),),
            dtype=topk_ids.dtype,
            device=topk_ids.device,
        )
        if routed_scaling_factor is not None:
            topk_weights[:, -1] = (
                topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
            )

    if renormalize:
        # Normalize by the routed columns only; the shared-expert column is excluded.
        topk_weights_sum = (
            topk_weights.sum(dim=-1, keepdim=True)
            if num_fused_shared_experts == 0
            else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
        )
        topk_weights = topk_weights / topk_weights_sum
        if apply_routed_scaling_factor_on_output and routed_scaling_factor is not None:
            topk_weights *= routed_scaling_factor

    topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
    return topk_weights, topk_ids
def biased_topk_jit_kernel_impl(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk: int,
renormalize: bool,
scoring_func: str = "sigmoid",
num_fused_shared_experts: int = 0,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
from sglang.jit_kernel.moe_fused_gate import moe_fused_gate
topk_weights, topk_ids = moe_fused_gate(
gating_output,
correction_bias,
topk=topk,
scoring_func=scoring_func,
num_fused_shared_experts=num_fused_shared_experts,
renormalize=renormalize,
routed_scaling_factor=routed_scaling_factor,
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
)
topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
_mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
return topk_weights, topk_ids
@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
def biased_grouped_topk_impl(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
    """Group-limited, bias-corrected sigmoid top-k routing in plain PyTorch (compiled).

    Selection uses ``sigmoid(gating_output) + correction_bias``. Each of the
    ``num_expert_group`` contiguous groups is scored by the sum of its two best
    selection scores, and the ``topk_group`` best groups are kept. Experts
    outside those groups are masked to ``-inf``, and ``topk`` experts are then
    chosen. Weights are gathered from the unbiased sigmoid scores.

    Shared-expert handling and renormalization follow ``grouped_topk_gpu``:
    the last column is overwritten with a random shared-expert id, and
    normalization uses the routed columns only.

    Returns:
        ``(topk_weights float32, topk_ids int32)``.

    NOTE(review): apply_routed_scaling_factor_on_output=True with
    routed_scaling_factor=None would fail on the in-place multiply.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    scores = gating_output.sigmoid()
    num_token = scores.shape[0]
    num_experts = scores.shape[1]
    scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
    # Group score = sum of the top-2 biased scores inside each group.
    group_scores = (
        scores_for_choice.view(num_token, num_expert_group, -1)
        .topk(2, dim=-1)[0]
        .sum(dim=-1)
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    # -inf (not 0) because biased scores can be negative.
    tmp_scores = scores_for_choice.masked_fill(
        ~score_mask.bool(), float("-inf")
    )  # [n, e]
    _, topk_ids = torch.topk(
        tmp_scores,
        k=topk,
        dim=-1,
        sorted=(True if num_fused_shared_experts > 0 else False),
    )
    topk_weights = scores.gather(1, topk_ids)

    if num_fused_shared_experts:
        topk_ids[:, -1] = torch.randint(
            low=num_experts,
            high=num_experts + num_fused_shared_experts,
            size=(topk_ids.size(0),),
            dtype=topk_ids.dtype,
            device=topk_ids.device,
        )
        if routed_scaling_factor is not None:
            topk_weights[:, -1] = (
                topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
            )

    if renormalize:
        topk_weights_sum = (
            topk_weights.sum(dim=-1, keepdim=True)
            if num_fused_shared_experts == 0
            else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
        )
        topk_weights = topk_weights / topk_weights_sum
        if apply_routed_scaling_factor_on_output:
            topk_weights *= routed_scaling_factor

    topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
    return topk_weights, topk_ids
def is_power_of_two(n):
    """Return True if ``n`` is a positive power of two.

    For Python ints (including bool) an exact bit test is used. The previous
    ``math.log2`` check rounds through a float, so large non-powers such as
    ``2**60 + 1`` were misreported as powers of two. Other numeric types (e.g.
    floats) keep the original ``log2`` behaviour.
    """
    if isinstance(n, int):
        return n > 0 and (n & (n - 1)) == 0
    return n > 0 and math.log2(n).is_integer()
def _mask_topk_ids_padded_region(
topk_ids: torch.Tensor,
num_token_non_padded: Optional[torch.Tensor] = None,
) -> None:
if num_token_non_padded is None:
return
# TODO: let the kernel support other dtypes
if _is_cuda and topk_ids.dtype == torch.int32:
mask_topk_ids(topk_ids, num_token_non_padded)
else:
indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device)
topk_ids[indices >= num_token_non_padded, :] = -1
# NOTE(review): unlike the other compiled helpers in this file, this decorator
# omits `disable=_is_npu`; confirm whether that is intentional.
@torch.compile(dynamic=True, backend=get_compiler_backend())
def _biased_grouped_topk_postprocess(
    topk_ids, expert_location_dispatch_info, num_token_non_padded
):
    """Map logical expert ids to physical ids, then mask padded token rows.

    Returns the remapped ids. The padding mask is applied in place to that
    remapped tensor.
    """
    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
    return topk_ids
def biased_grouped_topk_gpu(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
routed_scaling_factor: Optional[float] = None,
apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
num_tokens = gating_output.shape[0]
num_experts = gating_output.shape[1]
experts_per_group = (
num_experts // num_expert_group if num_expert_group else num_experts
)
# topk for routed experts only (shared experts are appended separately below)
topk_routed = topk - num_fused_shared_experts
if (
_is_cuda
and fused_topk_deepseek is not None
and is_power_of_two(num_experts)
# flashinfer constraints (applied to routed experts only)
and topk_routed <= 8
and topk_group <= num_expert_group
and topk_group * num_expert_group >= topk_routed
and (
(experts_per_group <= 32 and experts_per_group * topk_group <= 128)
if num_expert_group > 1
else num_experts <= 384
)
):
# Pre-allocate output tensors (flashinfer mutates them in-place)
topk_weights = torch.empty(
(num_tokens, topk_routed), dtype=torch.float32, device=gating_output.device
)
topk_ids = torch.empty(
(num_tokens, topk_routed), dtype=torch.int32, device=gating_output.device
)
# flashinfer always applies the scaling_factor internally
scaling_factor = 1.0
if routed_scaling_factor is not None and apply_routed_scaling_factor_on_output:
scaling_factor = routed_scaling_factor
# flashinfer's fused_topk_deepseek
fused_topk_deepseek(
gating_output.to(dtype=torch.float32),
correction_bias,
num_expert_group,
topk_group,
topk_routed,
scaling_factor,
topk_weights,
topk_ids,
True,
)
if num_fused_shared_experts > 0:
# Append shared expert columns: ID = num_experts (first shared slot),
# weight = sum(routed) / scaling_factor (matching biased_grouped_topk_impl).
# DeepEP fusion will overwrite both in _remap_topk_ids_for_deepep_fusion.
topk_ids = F.pad(topk_ids, (0, num_fused_shared_experts), value=num_experts)
topk_weights = F.pad(topk_weights, (0, num_fused_shared_experts))
if routed_scaling_factor is not None:
topk_weights[:, topk_routed:] = (
topk_weights[:, :topk_routed].sum(dim=-1, keepdim=True)
/ routed_scaling_factor
)
return topk_weights, topk_ids
elif (
_is_cuda
# moe_fused_gate kernel ensures that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
and experts_per_group <= 32
and is_power_of_two(num_experts)
):
topk_weights, topk_ids = moe_fused_gate(
gating_output.to(dtype=torch.float32),
correction_bias,
num_expert_group,
topk_group,
topk,
num_fused_shared_experts,
routed_scaling_factor if routed_scaling_factor is not None else 1.0,
apply_routed_scaling_factor_on_output,
)
return topk_weights, topk_ids
elif _use_aiter:
assert not apply_routed_scaling_factor_on_output, "Not implemented"
token = gating_output.shape[0]
device = gating_output.device
assert (