[Feature] Add initial support for sequence parallelism#1436
Conversation
Force-pushed from c263cb3 to 71c8afe.
From the code, I see that in the prefill stage, after attention, the output has shape [padded_total_num_tokens, q_head_num // SP_SIZE, head_dim], and it is then fed through RowSeqParallelLinear, which needs an all-reduce. The input of qkv_proj_linear is [padded_total_num_tokens, q_head_num, head_dim], which is not split by sp_size. I want to know why ring attention is not used here; ring attention seems better in both computation and communication.
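For context, here is a minimal sketch of the row-parallel output projection being discussed, assuming a Megatron-style layout where each SP rank holds the weight slice for its head shard and an all-reduce sums the partial results. It is illustrative only, not the PR's actual `RowSeqParallelLinear`:

```python
import torch
import torch.distributed as dist

def row_parallel_o_proj(attn_out_shard: torch.Tensor, w_shard: torch.Tensor) -> torch.Tensor:
    """Illustrative sketch, not the PR's RowSeqParallelLinear.

    attn_out_shard: [padded_total_num_tokens, (q_head_num // SP_SIZE) * head_dim]
    w_shard:        [(q_head_num // SP_SIZE) * head_dim, hidden_size]
    """
    partial = attn_out_shard @ w_shard              # partial result from this rank's head shard
    dist.all_reduce(partial, op=dist.ReduceOp.SUM)  # sum partial results across SP ranks
    return partial
```

The all-reduce in the last step is the communication cost the question above refers to.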
For each SP worker, we have either (1) QKV of entire sequences:
    q tensor: [padded_total_num_tokens, q_head_num // SP_SIZE, head_dim]
    k tensor: [padded_total_num_tokens, k_head_num, head_dim]
    v tensor: [padded_total_num_tokens, v_head_num, head_dim]
Or (2) Q of entire sequences and KV of the current SP shard:
    q tensor: [padded_total_num_tokens, q_head_num // SP_SIZE, head_dim]
    k tensor: [padded_sp_shard_num_tokens, k_head_num, head_dim]
    v tensor: [padded_sp_shard_num_tokens, v_head_num, head_dim]

Case (1) saves cross-SP-worker communication, while case (2) saves the computation of K and V for the entire sequences but needs extra computation in SP attention.
"""
(2) seems to be able to split the workload and overlap even with a single query. But just curious: does anyone have opinions on TreeAttention (all-reduce the LSE instead of sending KV), which seems optimized for decoding?
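For reference, a rough sketch of the LSE-based merge behind the TreeAttention idea mentioned above: each worker attends over its local KV shard, and partial outputs are then combined using their log-sum-exp values, so only outputs and LSEs (not KV) need to cross workers. The function name is hypothetical and this is not code from this PR:

```python
import torch

def merge_partial_attn(outs: list[torch.Tensor], lses: list[torch.Tensor]) -> torch.Tensor:
    """outs[i]: [num_q_tokens, num_heads, head_dim] partial attention output of KV shard i
    lses[i]: [num_q_tokens, num_heads] log-sum-exp of shard i's attention scores"""
    lse_stack = torch.stack(lses, dim=0)               # [num_shards, tokens, heads]
    lse_global = torch.logsumexp(lse_stack, dim=0)     # global normalizer per query/head
    weights = torch.exp(lse_stack - lse_global)        # per-shard softmax correction
    out_stack = torch.stack(outs, dim=0)               # [num_shards, tokens, heads, dim]
    return (weights.unsqueeze(-1) * out_stack).sum(dim=0)
```

In a distributed setting, the LSEs and weighted partial outputs would be combined with an all-reduce (or all-gather) instead of the local stacking shown here.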
# TODO: in fact we can use all-to-all to gather the output and state here
# to collect only q head shards that are needed by the current SP worker.
# All-to-all will save communication and `merge_state` computation.
Later all-reduce in ColumnSeqParallelLinear?
Could this feature be integrated in the early months of 2025? And by the way, why not use ring attention, which offers better performance? @merrymercy @Ying1123
This PR already implements ring attention.
However, the code heavily relies on tensor-parallelism, and its layout [s, h//sp, d] seems to be inequivalent to flash-linear-attention in xDit, [s//sp, h, d], in both computation and memory access.
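To illustrate the layout difference referred to here, a rough all-to-all sketch that converts a head-sharded [s, h // sp, d] shard into a sequence-sharded [s // sp, h, d] shard, the way Ulysses-style sequence parallelism does. Assumptions: s divisible by sp_size, one chunk per peer rank; this is not code from this PR or from xDit:

```python
import torch
import torch.distributed as dist

def heads_sharded_to_seq_sharded(x: torch.Tensor, sp_size: int) -> torch.Tensor:
    """x: [s, h // sp, d] on each rank; returns [s // sp, h, d] on each rank."""
    s, h_shard, d = x.shape
    # Split the sequence dimension into sp_size chunks, one to send to each peer rank.
    send = [t.contiguous() for t in x.reshape(sp_size, s // sp_size, h_shard, d).unbind(0)]
    recv = [torch.empty_like(send[0]) for _ in range(sp_size)]
    dist.all_to_all(recv, send)
    # Each received chunk covers this rank's sequence shard but a different head shard,
    # so concatenating along the head dimension yields all heads for the local tokens.
    return torch.cat(recv, dim=1)  # [s // sp, h, d]
```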
Is this still active? Looking forward to this change. It will be super helpful if we really want to handle long context.
The reason for coupling with TP is probably that if we use pure ring attention, we can only replicate Q on all ranks, which causes redundant computation, while if we shard Q there is no redundancy.
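A hedged sketch of the pure ring-attention structure described here: each rank keeps its Q (replicated or sharded) and rotates its KV shard around the ring, merging the per-block partial outputs by their log-sum-exp. `partial_attn` is a hypothetical kernel returning a partial output and its LSE, and real implementations overlap the communication with compute:

```python
import torch
import torch.distributed as dist

def ring_exchange(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Send the local KV block to the next rank and receive the previous rank's block.
    recv = torch.empty_like(t)
    ops = [dist.P2POp(dist.isend, t.contiguous(), (rank + 1) % world_size),
           dist.P2POp(dist.irecv, recv, (rank - 1) % world_size)]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return recv

def merge_by_lse(outs, lses):
    # Combine per-block partial outputs using their log-sum-exp normalizers.
    lse = torch.stack(lses)                              # [blocks, tokens, heads]
    w = torch.exp(lse - torch.logsumexp(lse, dim=0))     # per-block softmax correction
    return (w.unsqueeze(-1) * torch.stack(outs)).sum(0)

def ring_attention(q, k, v, rank: int, world_size: int):
    outs, lses = [], []
    k_cur, v_cur = k, v
    for step in range(world_size):
        out, lse = partial_attn(q, k_cur, v_cur)  # hypothetical kernel: partial output + LSE
        outs.append(out)
        lses.append(lse)
        if step + 1 < world_size:
            k_cur = ring_exchange(k_cur, rank, world_size)
            v_cur = ring_exchange(v_cur, rank, world_size)
    return merge_by_lse(outs, lses)
```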
I'm interested in helping with finalizing this if possible :)
@ZYHowell @ivanium Moved from #1041