Optimize Permute Kernel in DeepEP #4643
Conversation
```python
def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
```
`seg_indptr` can be initialized with `torch.empty` instead of `torch.zeros`, since every element is written afterward.
```python
deepep_compute_src2dst_triton_kernel[grid](
    reorder_ids, src2dst, topk_ids.numel(), num_minus_one, BLOCK_SIZE
)
# src2dst -= num_minus_one
```
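For readers following along, the mapping this launch computes can be sketched in plain Python (illustrative values; the real kernel does this in parallel over blocks): given `reorder_ids`, the stable argsort of the flattened `topk_ids`, `src2dst[src]` is the destination slot of source element `src` in the expert-sorted layout.

```python
# Plain-Python sketch (illustrative values) of the src2dst mapping that
# deepep_compute_src2dst_triton_kernel builds in parallel.
topk_ids = [2, 0, 1, 2, 0, 3]  # flattened expert ids, one per (token, slot) pair

# Stable argsort of topk_ids, as produced by torch.sort(..., stable=True)
reorder_ids = sorted(range(len(topk_ids)), key=lambda i: topk_ids[i])

# Invert the permutation: source position -> destination slot
src2dst = [0] * len(topk_ids)
for dst, src in enumerate(reorder_ids):
    src2dst[src] = dst

assert src2dst == [3, 0, 2, 4, 1, 5]
```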
```python
@triton.jit
def compute_src2dst_triton_kernel(
```
compute_src2dst_triton_kernel and deepep_compute_src2dst_triton_kernel are defined twice.
```python
@triton.jit
def deepep_compute_src2dst_triton_kernel(
```
Why is developing a Triton kernel necessary here? Is it faster?
```python
@triton.jit
def deepep_permute_triton_kernel(
```
```python
@triton.jit
def deepep_post_reorder_triton_kernel(
```

```python
output = torch.zeros(
    (num_tokens, hidden_states.shape[1]),
    device=hidden_states.device,
    dtype=hidden_states.dtype,
)
```
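As a rough sketch (plain Python, hypothetical shapes and weights, not the actual kernel), the post-reorder step gathers each token's top-k expert outputs from the permuted buffer via `src2dst` and accumulates them weighted by `topk_weights`:

```python
# Plain-Python sketch (illustrative values) of the per-token weighted sum
# that deepep_post_reorder_triton_kernel performs.
topk = 2
gateup_output = [10.0, 20.0, 30.0, 40.0]   # permuted expert outputs, hidden size 1
src2dst = [2, 0, 3, 1]                     # (token * topk + k) -> destination slot
topk_weights = [[0.5, 0.5], [0.25, 0.75]]  # per-token routing weights

output = []
for t in range(2):
    acc = 0.0
    for k in range(topk):
        dst = src2dst[t * topk + k]
        acc += topk_weights[t][k] * gateup_output[dst]
    output.append(acc)

# token 0: 0.5 * 30.0 + 0.5 * 10.0 = 20.0
assert output == [20.0, 25.0]
```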
```diff
 )
 if self.tp_size > 1:
-    recv_hidden_states, topk_idx, topk_weights, tokens_per_expert = (
+    recv_hidden_states, reorder_topk_ids, seg_indptr = (
```
Should we add some short comments on the meaning/examples of reorder_topk_ids and seg_indptr for readability?
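For example (illustrative values, plain Python rather than torch): `reorder_topk_ids` is the flattened `topk_ids` in sorted order, and `seg_indptr` is a CSR-style offset array, so expert `e`'s entries occupy reorder positions `[seg_indptr[e], seg_indptr[e + 1])`.

```python
# Illustrative example (plain Python, hypothetical values) of the two
# tensors produced by deepep_run_moe_deep_preprocess.
from itertools import accumulate

num_experts = 4
topk_ids = [2, 0, 1, 2, 0, 3]  # flattened: expert chosen per (token, slot) pair

reorder_topk_ids = sorted(topk_ids)          # expert ids in ascending order
counts = [reorder_topk_ids.count(e) for e in range(num_experts)]
seg_indptr = [0] + list(accumulate(counts))  # per-expert segment boundaries

assert reorder_topk_ids == [0, 0, 1, 2, 2, 3]
assert seg_indptr == [0, 2, 3, 5, 6]
# Expert 2's segment is positions [3, 5) of the reordered buffer
assert reorder_topk_ids[seg_indptr[2]:seg_indptr[3]] == [2, 2]
```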
Will there be further optimization plans for this permute kernel?

We will continue to optimize the permute kernel, but it is not our top priority at the moment.
The observed issue could potentially be attributed to the RoCE network configuration. To verify this hypothesis, we recommend running the inter-node communication test from DeepEP's validation suite, specifically the internode connectivity check.


Motivation
The current performance of DeepEP is suboptimal due to the low efficiency of PyTorch's native permute function, which is used for formatting data before and after DeepEP communication. To address this limitation, we have implemented high-efficiency Triton kernels that significantly improve overall performance.
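As a rough illustration of what is being replaced (plain Python, hypothetical shapes; the actual code operates on torch tensors), the pre-communication permute is a gather that duplicates each token once per selected expert and groups the copies by expert:

```python
# Plain-Python sketch (illustrative values) of the permute that groups
# token copies by expert id before DeepEP communication.
hidden_states = [[1.0], [2.0], [3.0]]  # 3 tokens, hidden size 1
topk_ids = [[2, 0], [0, 1], [1, 2]]    # each token routed to top-2 experts
topk = 2

flat = [e for row in topk_ids for e in row]                    # [2, 0, 0, 1, 1, 2]
reorder_ids = sorted(range(len(flat)), key=lambda i: flat[i])  # stable argsort

# Gather each (token, slot) copy into expert-sorted order
permuted = [hidden_states[src // topk] for src in reorder_ids]
assert permuted == [[1.0], [2.0], [2.0], [3.0], [1.0], [3.0]]
```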
Performance on H20
Single Node
Command
Multi Node
Command
Modifications
Checklist