|
29 | 29 | ) |
30 | 30 |
|
31 | 31 | import torch |
| 32 | +import torch.nn.functional as F |
32 | 33 |
|
33 | 34 | try: |
34 | 35 | from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing |
@@ -443,6 +444,25 @@ def scoring_func_impl(gating_output: torch.Tensor) -> torch.Tensor: |
443 | 444 | return topk_weights, topk_ids |
444 | 445 |
|
445 | 446 |
|
def fused_topk_softmax_torch_raw_logits(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    """Select top-k experts directly from raw router logits.

    Unlike the usual softmax-then-topk routing, this picks the top-k experts
    on the *raw* logits and, when ``renormalize`` is set, applies softmax over
    only the k gathered logits — so the returned weights sum to 1 across the
    selected experts rather than across all experts.

    Args:
        hidden_states: Token activations; only ``shape[0]`` (token count) is
            used, to validate agreement with ``gating_output``.
        gating_output: Raw router logits, shape ``(num_tokens, num_experts)``.
        topk: Number of experts to select per token.
        renormalize: If True, softmax the gathered top-k logits; if False,
            return the selected experts' raw logits (as float32).

    Returns:
        Tuple ``(topk_weights, topk_ids)`` with dtypes float32 and int32.
        NOTE: within a row, the ordering of ``topk_ids`` is unspecified
        because ``sorted=False`` is passed to ``torch.topk``.
    """
    assert (
        hidden_states.shape[0] == gating_output.shape[0]
    ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"

    # Top-k on the raw (possibly low-precision) logits; sorted order is not
    # required by downstream consumers.
    _, topk_ids = torch.topk(gating_output, k=topk, dim=-1, sorted=False)
    # Gather the winners' logits in float32 for a numerically stable softmax.
    topk_weights = gating_output.float().gather(1, topk_ids)
    if renormalize:
        topk_weights = F.softmax(topk_weights, dim=-1, dtype=torch.float32)

    # topk_weights is already float32 here; only the ids need a cast.
    return topk_weights, topk_ids.to(torch.int32)
| 465 | + |
446 | 466 | def fused_topk_cpu( |
447 | 467 | hidden_states: torch.Tensor, |
448 | 468 | gating_output: torch.Tensor, |
@@ -1030,15 +1050,28 @@ def select_experts( |
1030 | 1050 | ) |
1031 | 1051 | elif custom_routing_function is None: |
1032 | 1052 | assert not apply_routed_scaling_factor_on_output, "Not implemented" |
1033 | | - # Qwen3MOE uses fused_topk |
1034 | | - topk_weights, topk_ids = fused_topk( |
1035 | | - hidden_states=hidden_states, |
1036 | | - gating_output=router_logits, |
1037 | | - topk=num_routed_topk if _use_aiter else top_k, |
1038 | | - renormalize=renormalize, |
1039 | | - correction_bias=correction_bias, |
1040 | | - scoring_func=scoring_func, |
1041 | | - ) |
| 1053 | + if ( |
| 1054 | + get_moe_runner_backend().is_flashinfer_trtllm_routed() |
| 1055 | + and scoring_func == "softmax" |
| 1056 | + and correction_bias is None |
| 1057 | + ): |
| 1058 | + # flashinfer_trtllm_routed uses raw-logits topk |
| 1059 | + topk_weights, topk_ids = fused_topk_softmax_torch_raw_logits( |
| 1060 | + hidden_states=hidden_states, |
| 1061 | + gating_output=router_logits, |
| 1062 | + topk=num_routed_topk if _use_aiter else top_k, |
| 1063 | + renormalize=renormalize, |
| 1064 | + ) |
| 1065 | + else: |
| 1066 | + # Qwen3MOE uses fused_topk |
| 1067 | + topk_weights, topk_ids = fused_topk( |
| 1068 | + hidden_states=hidden_states, |
| 1069 | + gating_output=router_logits, |
| 1070 | + topk=num_routed_topk if _use_aiter else top_k, |
| 1071 | + renormalize=renormalize, |
| 1072 | + correction_bias=correction_bias, |
| 1073 | + scoring_func=scoring_func, |
| 1074 | + ) |
1042 | 1075 | else: |
1043 | 1076 | assert ( |
1044 | 1077 | num_token_non_padded is None |
|
0 commit comments