moe fused_gate_kernel support n_share_experts_fusion param #5125
Conversation
```diff
  for (int ii = 0; ii < topk; ++ii) {
    int64_t const idx = topk * thread_row + ii;
-   output_ptr[idx] = static_cast<float>(static_cast<T>(output_ptr[idx]) / static_cast<T>(output_sum));
+   output_ptr[idx] = output_ptr[idx] / output_sum;
```
wondering why we remove this static cast
`output_ptr` itself is of type `float`, and `output_sum` is also of type `float`. This `static_cast` is completely unnecessary.
I see, then I wonder why the original code was like that...
My guess is they wanted to emulate some accuracy loss, but that looks pretty weird as well...
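For what it's worth, the effect of the removed round-trip is easy to see in isolation (a standalone sketch, assuming `T` is e.g. `half`):

```python
import torch

x = torch.rand(8, dtype=torch.float32)
s = x.sum()

new_way = x / s                            # the new code: pure fp32 division
old_way = (x.half() / s.half()).float()    # the old code: round-trip through T = half

# the round-trip loses mantissa bits, so the results differ slightly
print((new_way - old_way).abs().max())
```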
fzyzcjy left a comment:
briefly checked and only some nits
```diff
  num_expert_group: int = 0,
  topk_group: int = 0,
+ n_share_experts_fusion: int = 0,
+ routed_scaling_factor: float = 2.5,
```
nit: wondering whether we should make it default to `nan` or `None`, because 2.5 is very specific to DeepSeek. If some place in the code accidentally forgets to pass in the variable, we will silently get a 2.5 that is hard to debug. On the other hand, if we set it to `None`, then if one day another model enables fusion but we forget to pass in a value, we will get a big error and immediately know it is a bug and where it came from.
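i.e., something like the following sketch (the loud `ValueError` check is a hypothetical illustration of the suggestion, not code from the PR):

```python
from typing import Optional

def select_experts(
    # ... other routing args elided ...
    n_share_experts_fusion: int = 0,
    routed_scaling_factor: Optional[float] = None,  # instead of a DeepSeek-specific 2.5
):
    if n_share_experts_fusion > 0 and routed_scaling_factor is None:
        # fail loudly at the call site instead of silently scaling by 2.5
        raise ValueError(
            "routed_scaling_factor must be set when n_share_experts_fusion > 0"
        )
```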
```diff
+ n_share_experts_fusion = 0
+ if global_server_args_dict["n_share_experts_fusion"] is not None:
+     n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+ routed_scaling_factor = global_server_args_dict["routed_scaling_factor"]
```
hmm, wondering whether it is good to have it as a global variable... what about passing it in as a normal variable, especially because `select_experts` already adds this arg
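For illustration, the shape of the suggestion (a minimal sketch; `forward_gate` and the stub `select_experts` body are hypothetical stand-ins for the real call sites):

```python
import torch

def select_experts(router_logits, top_k, n_share_experts_fusion=0):
    # stand-in for the real routing entry point, which already accepts this arg
    return router_logits.topk(top_k, dim=-1)

def forward_gate(router_logits, top_k, server_args: dict):
    # resolve the config once at the boundary and hand it down explicitly,
    # rather than having select_experts read global_server_args_dict itself
    fusion = server_args.get("n_share_experts_fusion") or 0
    return select_experts(router_logits, top_k, n_share_experts_fusion=fusion)
```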
```diff
  warmups: Optional[str] = None
- n_share_experts_fusion: int = 0
+ n_share_experts_fusion: Optional[int] = None
+ routed_scaling_factor: Optional[float] = None
```
nit: it seems that the scaling factor will be picked up from the model config, so maybe we do not need to pass it in as an arg
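e.g., roughly (a hypothetical helper; DeepSeek-V3 HF configs do define `routed_scaling_factor`, and the 1.0 fallback is an assumption of this sketch):

```python
def resolve_routed_scaling_factor(hf_config) -> float:
    # DeepSeek-V3 configs define routed_scaling_factor (2.5); falling back to
    # 1.0 (no extra scaling) for other models is an assumption of this sketch
    return getattr(hf_config, "routed_scaling_factor", 1.0)
```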
```diff
  )
+ elif n_share_experts_fusion is None:
+     n_share_experts_fusion = tp_size
+     global_server_args_dict["n_share_experts_fusion"] = tp_size
```
nit: it seems this function is impure (indeed half pure and half impure); what about making it a pure function, i.e. not doing mutations inside it, to reduce cognitive overhead
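A minimal sketch of the idea, covering just the defaulting branch shown above (the helper name is hypothetical):

```python
def resolve_n_share_experts_fusion(arg_value, tp_size):
    # pure: just compute and return; no write to global_server_args_dict here
    return tp_size if arg_value is None else arg_value

# the caller then performs the single mutation in one obvious place
server_args_dict = {"n_share_experts_fusion": resolve_n_share_experts_fusion(None, tp_size=8)}
```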
```diff
+ } else if (thread_group_idx == 0) {
+   // If not using shared experts, add the last weight to output_sum
+   output_sum += output_ptr[topk * thread_row + (topk - 1)];
+ }
```
About the `////////////////////// Topk //////////////////////` section: it seems we are still spending time computing `real_topk + num_fusion` things. For example, for llama4, it may be top2 instead of top1.

Thus what about:

- Suppose we remove the name `topk`, and have `topk_excluding_share_expert_fusion` or `topk_with_shared_expert_fusion` etc. (to avoid users getting confused about which one is which)
- In the `////////////////////// Topk //////////////////////` section above, we use `topk_excluding_share_expert_fusion` to compute everything (thus only top1 in llama4)
- Then here we do not need this else-if as well (see the sketch below)
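A plain-PyTorch sketch of the suggested structure (illustrative only; the shared-slot weight follows the DeepSeek-style convention of routed-sum divided by `routed_scaling_factor`, which is an assumption about the kernel's semantics):

```python
import torch

def fused_gate_ref(scores, topk_excluding_share_expert_fusion,
                   n_share_experts_fusion, num_experts, routed_scaling_factor):
    # Top-k runs over routed experts only (so top1 for llama4, not top2),
    # and the normalization loop never has to special-case the last slot.
    weights, ids = scores.topk(topk_excluding_share_expert_fusion, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)
    if n_share_experts_fusion > 0:
        n = scores.size(0)
        # append one slot pointing at a shared expert; no else-if needed here
        shared_ids = torch.randint(num_experts, num_experts + n_share_experts_fusion, (n, 1))
        shared_w = weights.sum(dim=-1, keepdim=True) / routed_scaling_factor
        ids = torch.cat([ids, shared_ids], dim=-1)
        weights = torch.cat([weights, shared_w], dim=-1)
    return weights, ids
```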
```diff
+ valid_max = num_experts + n_share_experts_fusion
+ shared_indices = original_indices[:, -1]
+ shared_ref_indices = original_ref_indices[:, -1]
+ if shared_indices is not None:
```
nit: wondering when it will be `None`, since it looks like `original_indices[:, -1]` is always a tensor
```diff
@@ -57,6 +59,8 @@ __device__ void moe_fused_gate_impl(
      int64_t num_rows,
      int64_t topk_group,
```
question: is it mathematically incorrect, for the shared expert fusion thing, if we do `topk += n_share_experts_fusion` but keep `topk_group` unchanged? I have not dug into how it is computed, so just a quick question. If so, the quickest way to fix it may be applying the nits above, i.e. letting the topk logic compute `topk_excluding_share_expert_fusion` instead of the one including it.
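For context, a plain-PyTorch reference of what `topk_group` gates (illustrative only, following the common DeepSeek-style group-limited routing in which a group is scored by its max logit; not the kernel's exact implementation):

```python
import torch

def group_limited_topk(scores, topk, num_expert_group, topk_group):
    # topk_group limits WHICH expert groups may contribute; the final top-k
    # (plus any fused shared slot appended afterwards) then selects within
    # the surviving groups, so it is a different knob from topk itself.
    n_tokens, n_experts = scores.shape
    group_scores = scores.view(n_tokens, num_expert_group, -1).amax(dim=-1)
    top_groups = group_scores.topk(topk_group, dim=-1).indices
    mask = torch.zeros_like(group_scores, dtype=torch.bool)
    mask.scatter_(1, top_groups, True)
    mask = mask.repeat_interleave(n_experts // num_expert_group, dim=1)
    return scores.masked_fill(~mask, float("-inf")).topk(topk, dim=-1)
```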
Using this latest submission, combined with my added NextN (fused shared experts opened), on a dual-node A800 environment, the accuracy scores are as follows:
Motivation
moe `fused_gate_kernel` support `n_share_experts_fusion` param

UnitTest
GSM8K benchmark
Benchmark on H200
main branch (fused shared experts closed):

```
python3 -m sglang.launch_server --model /DeepSeek-V3 --tp 8 --trust-remote-code --port 30001 --disable-shared-experts-fusion
python3 -m sglang.bench_serving --backend sglang --num-prompts 300 --request-rate 4 --port 30001
```

```
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    4.0
Max reqeuest concurrency:                not set
Successful requests:                     300
Benchmark duration (s):                  128.59
Total input tokens:                      95293
Total generated tokens:                  60411
Total generated tokens (retokenized):    60143
Request throughput (req/s):              2.33
Input token throughput (tok/s):          741.08
Output token throughput (tok/s):         469.81
Total token throughput (tok/s):          1210.89
Concurrency:                             62.43
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   26757.12
Median E2E Latency (ms):                 20392.68
---------------Time to First Token----------------
Mean TTFT (ms):                          811.22
Median TTFT (ms):                        588.20
P99 TTFT (ms):                           3221.68
---------------Inter-Token Latency----------------
Mean ITL (ms):                           129.83
Median ITL (ms):                         51.89
P95 ITL (ms):                            620.79
P99 ITL (ms):                            1170.19
Max ITL (ms):                            2894.62
==================================================
```

main branch (fused shared experts opened):

```
python3 -m sglang.launch_server --model /DeepSeek-V3 --tp 8 --trust-remote-code --port 30001
python3 -m sglang.bench_serving --backend sglang --num-prompts 300 --request-rate 4 --port 30001
```

```
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    4.0
Max reqeuest concurrency:                not set
Successful requests:                     300
Benchmark duration (s):                  125.97
Total input tokens:                      95293
Total generated tokens:                  60411
Total generated tokens (retokenized):    60149
Request throughput (req/s):              2.38
Input token throughput (tok/s):          756.50
Output token throughput (tok/s):         479.59
Total token throughput (tok/s):          1236.09
Concurrency:                             64.55
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   27102.78
Median E2E Latency (ms):                 20786.24
---------------Time to First Token----------------
Mean TTFT (ms):                          3336.97
Median TTFT (ms):                        569.67
P99 TTFT (ms):                           18203.60
---------------Inter-Token Latency----------------
Mean ITL (ms):                           118.92
Median ITL (ms):                         51.17
P95 ITL (ms):                            559.23
P99 ITL (ms):                            652.11
Max ITL (ms):                            22040.79
==================================================
```

pr (fused shared experts opened and use `moe_fused_gate` kernel):

```
python3 -m sglang.launch_server --model /DeepSeek-V3 --tp 8 --trust-remote-code --port 30001
python3 -m sglang.bench_serving --backend sglang --num-prompts 300 --request-rate 4 --port 30001
```

```
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    4.0
Max reqeuest concurrency:                not set
Successful requests:                     300
Benchmark duration (s):                  121.25
Total input tokens:                      95293
Total generated tokens:                  60411
Total generated tokens (retokenized):    60140
Request throughput (req/s):              2.47
Input token throughput (tok/s):          785.92
Output token throughput (tok/s):         498.23
Total token throughput (tok/s):          1284.15
Concurrency:                             58.32
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   23570.69
Median E2E Latency (ms):                 18271.27
---------------Time to First Token----------------
Mean TTFT (ms):                          2831.89
Median TTFT (ms):                        420.14
P99 TTFT (ms):                           19307.38
---------------Inter-Token Latency----------------
Mean ITL (ms):                           103.78
Median ITL (ms):                         49.22
P95 ITL (ms):                            253.46
P99 ITL (ms):                            441.31
Max ITL (ms):                            18028.75
==================================================
```

Total token throughput (tok/s):

| Configuration | Total token throughput (tok/s) |
| --- | --- |
| main branch (fused shared experts closed) | 1210.89 |
| main branch (fused shared experts opened) | 1236.09 |
| pr (fused shared experts opened, `moe_fused_gate` kernel) | 1284.15 |