-
Notifications
You must be signed in to change notification settings - Fork 5.8k
Expand file tree
/
Copy pathmoe.py
More file actions
executable file
·289 lines (262 loc) · 7.36 KB
/
moe.py
File metadata and controls
executable file
·289 lines (262 loc) · 7.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
from typing import Optional
import torch
def moe_align_block_size(
    topk_ids,
    num_experts,
    block_size,
    sorted_token_ids,
    experts_ids,
    num_tokens_post_pad,
    cumsum_buffer,
    pad_sorted_token_ids=False,
):
    """Dispatch to the ``sgl_kernel::moe_align_block_size`` custom op.

    All arguments are forwarded positionally, in declaration order, to the
    op's ``default`` overload. The wrapper returns ``None``.

    NOTE(review): results are presumably written into the caller-provided
    buffers (``sorted_token_ids``, ``experts_ids``, ``num_tokens_post_pad``,
    ``cumsum_buffer``). Confirm which tensors the op mutates against its
    registered schema.
    """
    op_args = (
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
        cumsum_buffer,
        pad_sorted_token_ids,
    )
    kernel = torch.ops.sgl_kernel.moe_align_block_size.default
    kernel(*op_args)
def topk_softmax(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    moe_softcapping: float = 0.0,
    correction_bias: Optional[torch.Tensor] = None,
) -> None:
    """
    Pick the top-k experts per token using a softmax over the gating logits.

    This is a thin wrapper over ``torch.ops.sgl_kernel.topk_softmax.default``.
    The op fills the two output buffers, and this function returns ``None``.

    Args:
        topk_weights: Output buffer for the selected weights [num_tokens, topk]
        topk_ids: Output buffer for the selected expert indices [num_tokens, topk]
        gating_output: Router logits [num_tokens, num_experts]
        renormalize: If True, renormalize the selected top-k weights
        moe_softcapping: Tanh softcapping value; 0.0 disables it
        correction_bias: Optional per-expert bias [num_experts]; must be
            float32 when given
    """
    kernel = torch.ops.sgl_kernel.topk_softmax.default
    kernel(
        topk_weights,
        topk_ids,
        gating_output,
        renormalize,
        moe_softcapping,
        correction_bias,
    )
def topk_sigmoid(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool = False,
    correction_bias: Optional[torch.Tensor] = None,
) -> None:
    """
    Pick the top-k experts per token using a sigmoid over the gating logits.

    This is a thin wrapper over ``torch.ops.sgl_kernel.topk_sigmoid.default``.
    The op fills the two output buffers, and this function returns ``None``.

    Args:
        topk_weights: Output buffer for the selected weights [num_tokens, topk]
        topk_ids: Output buffer for the selected expert indices [num_tokens, topk]
        gating_output: Router logits [num_tokens, num_experts]
        renormalize: If True, renormalize the selected top-k weights
        correction_bias: Optional per-expert bias [num_experts]; must be
            float32 when given
    """
    kernel = torch.ops.sgl_kernel.topk_sigmoid.default
    kernel(topk_weights, topk_ids, gating_output, renormalize, correction_bias)
def moe_sum_reduce(
    input_tensor,
    output_tensor,
    routed_scaling_factor=0,
):
    """Forward to ``torch.ops.sgl_kernel.moe_sum_reduce.default``.

    ``input_tensor``, ``output_tensor`` and ``routed_scaling_factor`` are
    passed through positionally. The wrapper returns ``None``, so any result
    presumably lands in ``output_tensor`` (confirm against the op schema).

    NOTE(review): ``routed_scaling_factor`` defaults to 0, while
    ``kimi_k2_moe_fused_gate`` defaults its scaling factor to 1.0. If the
    kernel multiplies the reduced sum by this factor unconditionally,
    relying on the default would produce zeros. Confirm how the kernel
    treats 0, and prefer passing the factor explicitly at call sites.
    """
    reduce_op = torch.ops.sgl_kernel.moe_sum_reduce.default
    reduce_op(input_tensor, output_tensor, routed_scaling_factor)
def moe_sum(
    input_tensor: torch.Tensor,
    output_tensor: torch.Tensor,
):
    """Forward both tensors to ``torch.ops.sgl_kernel.moe_sum.default``.

    The wrapper returns ``None``. NOTE(review): the op presumably writes its
    reduction into ``output_tensor``; confirm against the op schema.
    """
    sum_op = torch.ops.sgl_kernel.moe_sum.default
    sum_op(input_tensor, output_tensor)
def moe_fused_gate(
    input_tensor,
    bias,
    num_expert_group,
    topk_group,
    topk,
    num_fused_shared_experts=0,
    routed_scaling_factor=0,
    apply_routed_scaling_factor_on_output=False,
):
    """Hierarchical (two-level) top-k expert selection in one fused kernel.

    Experts are split into ``num_expert_group`` groups. Each group's weight
    is the sum of its top-2 expert weights; the best ``topk_group`` groups
    are chosen, and the final ``topk`` experts are picked from within them.

    Constraints (per the kernel's current limitations):
      * the number of experts is taken from ``input_tensor``'s shape and
        must be a power of 2;
      * it must be divisible by ``num_expert_group``;
      * experts per group (#experts / num_expert_group) must be <= 32.
    For unsupported configurations, use ``biased_grouped_topk`` from
    ``sglang.srt.layers.moe.topk`` instead.

    Args:
        num_fused_shared_experts: if > 0, the last several experts are
            replaced with shared experts. The shared experts are divided by
            ``routed_scaling_factor`` so that this cancels out when the
            combined routed + shared output is later scaled, leaving the
            shared experts unscaled.
        routed_scaling_factor: if > 0, expert weights are scaled by it.
        apply_routed_scaling_factor_on_output: if True, the output is
            scaled by ``routed_scaling_factor``.

    Returns:
        Whatever ``torch.ops.sgl_kernel.moe_fused_gate.default`` returns.
    """
    fused_gate = torch.ops.sgl_kernel.moe_fused_gate.default
    return fused_gate(
        input_tensor,
        bias,
        num_expert_group,
        topk_group,
        topk,
        num_fused_shared_experts,
        routed_scaling_factor,
        apply_routed_scaling_factor_on_output,
    )
def kimi_k2_moe_fused_gate(
    input_tensor,
    bias,
    topk,
    renormalize=True,
    routed_scaling_factor=1.0,
    apply_routed_scaling_factor_on_output=False,
):
    """
    Fused top-k gate specialised for Kimi K2 (a single expert group).

    Because every expert is in one group (num_expert_group=1), the kernel
    skips the grouped top-k stage entirely.

    Args:
        input_tensor: Gating output [num_tokens, num_experts]
        bias: Correction bias [num_experts]
        topk: Number of experts selected per token
        renormalize: Whether the selected top-k weights are renormalized
        routed_scaling_factor: Scaling factor for the expert weights
        apply_routed_scaling_factor_on_output: If True, scale the output by
            the scaling factor

    Returns:
        Tuple ``(topk_weights, topk_ids)``:
          - topk_weights: [num_tokens, topk] float32 tensor
          - topk_ids: [num_tokens, topk] int32 tensor
    """
    gate_args = (
        input_tensor,
        bias,
        topk,
        renormalize,
        routed_scaling_factor,
        apply_routed_scaling_factor_on_output,
    )
    return torch.ops.sgl_kernel.kimi_k2_moe_fused_gate.default(*gate_args)
def fp8_blockwise_scaled_grouped_mm(
    output,
    a_ptrs,
    b_ptrs,
    out_ptrs,
    a_scales_ptrs,
    b_scales_ptrs,
    a,
    b,
    scales_a,
    scales_b,
    stride_a,
    stride_b,
    stride_c,
    layout_sfa,
    layout_sfb,
    problem_sizes,
    expert_offsets,
    workspace,
):
    """Invoke ``torch.ops.sgl_kernel.fp8_blockwise_scaled_grouped_mm.default``.

    All 18 arguments are forwarded positionally in declaration order. The
    wrapper returns ``None``.

    NOTE(review): judging by the name, the grouped matmul result is written
    into ``output`` and the ``*_ptrs`` / ``workspace`` tensors serve as
    scratch. Confirm against the op schema before relying on that.
    """
    grouped_mm = torch.ops.sgl_kernel.fp8_blockwise_scaled_grouped_mm.default
    # The op is called positionally, so this tuple's order must match the
    # registered schema exactly.
    operands = (
        output,
        a_ptrs,
        b_ptrs,
        out_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        a,
        b,
        scales_a,
        scales_b,
        stride_a,
        stride_b,
        stride_c,
        layout_sfa,
        layout_sfb,
        problem_sizes,
        expert_offsets,
        workspace,
    )
    grouped_mm(*operands)
def prepare_moe_input(
    topk_ids,
    expert_offsets,
    problem_sizes1,
    problem_sizes2,
    input_permutation,
    output_permutation,
    num_experts,
    n,
    k,
    blockscale_offsets: Optional[torch.Tensor] = None,
):
    """Forward to ``torch.ops.sgl_kernel.prepare_moe_input.default``.

    The wrapper returns ``None``. NOTE(review): the op presumably fills the
    offset, problem-size and permutation tensors in place; confirm against
    its schema.

    Args:
        blockscale_offsets: optional tensor, ``None`` by default. It is the
            last parameter of this Python signature, but the op expects it
            third, right after ``expert_offsets``.
    """
    prepare_op = torch.ops.sgl_kernel.prepare_moe_input.default
    # Reorder: the op's schema places blockscale_offsets third.
    prepare_op(
        topk_ids,
        expert_offsets,
        blockscale_offsets,
        problem_sizes1,
        problem_sizes2,
        input_permutation,
        output_permutation,
        num_experts,
        n,
        k,
    )
def apply_shuffle_mul_sum(
    input,
    output,
    permutation,
    factors,
):
    """Forward to ``torch.ops.sgl_kernel.apply_shuffle_mul_sum.default``.

    Arguments are passed positionally as ``(input, output, permutation,
    factors)``; the wrapper returns ``None``.

    NOTE(review): the name suggests ``input`` is permuted, multiplied by
    ``factors`` and summed into ``output``. Confirm against the kernel.
    """
    # `input` shadows the builtin, but it is part of the public signature.
    shuffle_op = torch.ops.sgl_kernel.apply_shuffle_mul_sum.default
    shuffle_op(input, output, permutation, factors)
def fused_qk_norm_rope(
    qkv: torch.Tensor,
    num_heads_q: int,
    num_heads_k: int,
    num_heads_v: int,
    head_dim: int,
    eps: float,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    base: float,
    is_neox: bool,
    position_ids: torch.Tensor,
    factor: float,
    low: float,
    high: float,
    attention_factor: float,
    rotary_dim: Optional[int] = None,
) -> None:
    """Run the fused Q/K normalization + rotary-embedding op on ``qkv``.

    Every argument is forwarded positionally to
    ``torch.ops.sgl_kernel.fused_qk_norm_rope``, and the function returns
    ``None``. The only value computed here is the rotary dimension: when
    ``rotary_dim`` is ``None``, ``head_dim`` is passed in its place.

    NOTE(review): presumably ``qkv`` is modified in place. ``q_weight`` and
    ``k_weight`` are assumed to be norm weights, and ``factor``, ``low``,
    ``high`` and ``attention_factor`` rope-scaling parameters. Confirm all of
    this against the kernel.

    NOTE(review): unlike the sibling wrappers in this module, this one calls
    the op packet rather than its ``.default`` overload, which leaves
    overload resolution to PyTorch at call time.
    """
    effective_rotary_dim = head_dim if rotary_dim is None else rotary_dim
    torch.ops.sgl_kernel.fused_qk_norm_rope(
        qkv,
        num_heads_q,
        num_heads_k,
        num_heads_v,
        head_dim,
        eps,
        q_weight,
        k_weight,
        base,
        is_neox,
        position_ids,
        factor,
        low,
        high,
        attention_factor,
        effective_rotary_dim,
    )