[Feature] Add initial support for sequence parallelism #1436

Closed

Ying1123 wants to merge 1 commit into main from seq-parallel


Conversation

@Ying1123
Contributor

@ZYHowell @ivanium Moved from #1041

Ying1123 changed the title from "Add initial support for sequence parallelism" to "[Feature] Add initial support for sequence parallelism" Sep 16, 2024
merrymercy mentioned this pull request Sep 17, 2024
Ying1123 marked this pull request as draft September 19, 2024 01:39
merrymercy mentioned this pull request Sep 22, 2024
@kuangdao

From the code I see that in the prefill stage, after attention, the output shape is [padded_total_num_tokens, q_head_num // SP_SIZE, head_dim], and the output then goes through RowSeqParallelLinear, which requires an all-reduce. The input of qkv_proj_linear is [padded_total_num_tokens, q_head_num, head_dim], which is not split by sp_size. I want to know why ring attention is not used here; ring attention seems better in both computation and communication.
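For illustration, here is a minimal single-process sketch of the shape flow described in the comment above. SP_SIZE and the tensor shapes come from the comment; the weight shard and the summed partial outputs are simplified stand-ins for SGLang's actual RowSeqParallelLinear and its all-reduce, not the real implementation.

```python
# Minimal sketch: row-parallel output projection after SP attention.
# Each SP worker holds q_head_num // SP_SIZE heads of the attention output;
# the all-reduce across SP workers is simulated by summing the partials.
import torch

SP_SIZE = 2
padded_total_num_tokens = 16
q_head_num, head_dim = 8, 64
hidden = q_head_num * head_dim

attn_out_per_worker = [
    torch.randn(padded_total_num_tokens, q_head_num // SP_SIZE, head_dim)
    for _ in range(SP_SIZE)
]

# Row-parallel weight: each worker owns the rows matching its head shard.
w_o = torch.randn(q_head_num * head_dim, hidden)
w_o_shards = w_o.chunk(SP_SIZE, dim=0)

partials = [
    attn_out_per_worker[r].reshape(padded_total_num_tokens, -1) @ w_o_shards[r]
    for r in range(SP_SIZE)
]
out = sum(partials)  # stands in for the all-reduce over the SP group
print(out.shape)     # torch.Size([16, 512])
```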

Comment on lines +273 to +284
For each SP worker, we have either (1) QKV of entire sequences:
q tensor: [padded_total_num_tokens, q_head_num // SP_SIZE, head_dim]
k tensor: [padded_total_num_tokens, k_head_num, head_dim]
v tensor: [padded_total_num_tokens, v_head_num, head_dim]
Or (2) Q of entire sequences and KV of the current SP shard:
q tensor: [padded_total_num_tokens, q_head_num // SP_SIZE, head_dim]
k tensor: [padded_sp_shard_num_tokens, k_head_num, head_dim]
v tensor: [padded_sp_shard_num_tokens, v_head_num, head_dim]

Case (1) saves cross-SP-worker communication, while case (2) saves the computation
to get K and V for entire sequences but needs extra computation in SP attn.
"""
Contributor

Edenzzzz commented Dec 3, 2024


(2) seems to be able to split the workload and overlap even with a single query. But just curious, does anyone have opinions on TreeAttention (just all-reduce the lse instead of sending KV), which seems optimized for decoding?
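For concreteness, a minimal single-process sketch of the lse-based combine that the TreeAttention idea relies on: every worker attends the full Q against only its local KV shard and keeps the per-query log-sum-exp, and the partial outputs are then merged using those lse values (in a distributed run, both the lse combination and the weighted-output sum would be reductions over the SP group), so no KV ever needs to be sent. Sizes and names are illustrative, not from this PR.

```python
# Sketch of merging per-shard attention outputs via log-sum-exp (lse).
import math
import torch

torch.manual_seed(0)
SP_SIZE, n_q, n_kv, d = 4, 8, 32, 64
q = torch.randn(n_q, d)
k = torch.randn(n_kv, d)
v = torch.randn(n_kv, d)

def partial_attn(q, k_shard, v_shard):
    scores = q @ k_shard.T / math.sqrt(d)          # [n_q, shard_len]
    lse = torch.logsumexp(scores, dim=-1)          # [n_q]
    out = torch.softmax(scores, dim=-1) @ v_shard  # [n_q, d]
    return out, lse

# Each "worker" sees the full q but only its own KV shard.
outs, lses = zip(*(partial_attn(q, ks, vs)
                   for ks, vs in zip(k.chunk(SP_SIZE), v.chunk(SP_SIZE))))

# In a distributed run, lse_total and the weighted sum below would be
# reductions across SP workers rather than local Python sums.
lse_total = torch.logsumexp(torch.stack(lses), dim=0)          # [n_q]
merged = sum(o * (l - lse_total).exp().unsqueeze(-1) for o, l in zip(outs, lses))

ref = torch.softmax(q @ k.T / math.sqrt(d), dim=-1) @ v
assert torch.allclose(merged, ref, atol=1e-5)                  # matches full attention
```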

Comment on lines +442 to +444
# TODO: in fact we can use all-to-all to gather the output and state here
# to collect only q head shards that are needed by the current SP worker.
# All-to-all will save communication and `merge_state` computation.
Contributor

Edenzzzz commented Dec 3, 2024


Later all-reduce in ColumnSeqParallelLinear?
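For reference, a single-process sketch of the saving the TODO above describes: with an all-gather, every SP worker receives the partial outputs for all q heads from every peer, while an all-to-all would deliver to worker r only the q-head shard it actually owns, cutting the received volume by SP_SIZE and shrinking the subsequent merge_state work accordingly. In a real run this would be a torch.distributed all-to-all over the SP group; the sizes here are illustrative.

```python
# Sketch: all-gather vs. all-to-all volume after SP attention.
import torch

SP_SIZE, n_tokens, q_head_num, head_dim = 4, 16, 8, 64
heads_per_worker = q_head_num // SP_SIZE

# partial[src]: worker src's attention output for ALL q heads over its KV shard
partial = [torch.randn(n_tokens, q_head_num, head_dim) for _ in range(SP_SIZE)]

# all-gather: every worker ends up holding all SP_SIZE partial tensors
gathered_numel = SP_SIZE * partial[0].numel()

# all-to-all: worker r receives only its own head slice from each peer
def head_slice(t, r):
    return t[:, r * heads_per_worker:(r + 1) * heads_per_worker, :]

received = {r: [head_slice(partial[src], r) for src in range(SP_SIZE)]
            for r in range(SP_SIZE)}
alltoall_numel = sum(t.numel() for t in received[0])

print(gathered_numel, alltoall_numel)  # 32768 vs 8192: 1/SP_SIZE of the data per worker
```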

@fangtaosong

fangtaosong commented Dec 31, 2024

Could this feature be integrated in the early months of 2025? And by the way, why not use ring attention, which offers better performance? @merrymercy @Ying1123

@Edenzzzz
Contributor

Edenzzzz commented Jan 1, 2025

Could this feature be integrated in the early months of 2025? And by the way, why not use ring attention, which offers better performance? @merrymercy @Ying1123

This PR already implements ring attention.

@fangtaosong

Could this feature be integrated in the early months of 2025? And by the way, why not use ring attention, which offers better performance? @merrymercy @Ying1123

This PR already implements ring attention.

However, the code heavily relies on tensor parallelism, and its layout [s, h//sp, d] does not seem equivalent to the [s//sp, h, d] layout of flash-linear-attention in xDit, in either computation or memory access.
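To make the layout comparison concrete, here is a single-process sketch of the two shardings: this PR keeps the whole sequence and shards heads ([s, h//sp, d]), while the xDit-style layout shards the sequence and keeps all heads ([s//sp, h, d]). The Ulysses-style all-to-all that would convert one into the other is simulated here by plain indexing; sizes are illustrative and not taken from either codebase.

```python
# Sketch: head-sharded [s, h//sp, d] vs. sequence-sharded [s//sp, h, d].
import torch

sp, s, h, d = 4, 16, 8, 64
full = torch.randn(s, h, d)

# layout A (this PR): every rank holds all s tokens but only h // sp heads
layout_a = [full[:, r * (h // sp):(r + 1) * (h // sp), :] for r in range(sp)]

# layout B (xDit-style): every rank holds s // sp tokens but all h heads
layout_b = [full[r * (s // sp):(r + 1) * (s // sp), :, :] for r in range(sp)]

# Converting A -> B for rank r: take rank r's token block from every head shard
# (this is the exchange an all-to-all would perform across ranks).
def a_to_b(shards, r):
    blocks = [shard[r * (s // sp):(r + 1) * (s // sp)] for shard in shards]
    return torch.cat(blocks, dim=1)

assert all(torch.equal(a_to_b(layout_a, r), layout_b[r]) for r in range(sp))
# Per-rank memory is the same (s * h * d / sp elements); what differs is which
# dimension the local attention kernel and the linear layers see.
```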

@HJSang

HJSang commented Jan 31, 2025

Is this still active? Looking forward to this change. It will be super helpful if we really want to handle long contexts.

@Edenzzzz
Contributor

Edenzzzz commented Feb 5, 2025

However, the code heavily relies on tensor parallelism, and its layout [s, h//sp, d] does not seem equivalent to the [s//sp, h, d] layout of flash-linear-attention in xDit, in either computation or memory access.

The reason for coupling with TP is probably that with pure ring attention we can only replicate Q on all ranks, which causes redundant computation, whereas sharding Q avoids that redundancy.
I think this could be solved with Tree attention (replicate Q, all-reduce the lse).
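For illustration, a single-process simulation of the ring-attention accumulation being discussed: a replicated Q attends to one KV shard per ring step, and the running output is rescaled with an online log-sum-exp. It also shows where the redundancy comes from: with Q replicated, every worker performs the full Q-against-all-KV computation over the course of the ring. Names and sizes are illustrative, not taken from this PR's implementation.

```python
# Sketch: ring attention with replicated Q, accumulated via online lse rescaling.
import math
import torch

torch.manual_seed(0)
sp, n_q, n_kv, d = 4, 8, 32, 64
q = torch.randn(n_q, d)                  # replicated on every worker
k = torch.randn(n_kv, d)
v = torch.randn(n_kv, d)
k_shards, v_shards = k.chunk(sp), v.chunk(sp)

def ring_attention(q, k_shards, v_shards):
    out = torch.zeros(q.shape[0], d)
    lse = torch.full((q.shape[0],), float("-inf"))
    for step in range(sp):               # one KV shard arrives per ring rotation
        ks, vs = k_shards[step], v_shards[step]
        scores = q @ ks.T / math.sqrt(d)
        blk_lse = torch.logsumexp(scores, dim=-1)
        blk_out = torch.softmax(scores, dim=-1) @ vs
        new_lse = torch.logaddexp(lse, blk_lse)
        out = (out * (lse - new_lse).exp().unsqueeze(-1)
               + blk_out * (blk_lse - new_lse).exp().unsqueeze(-1))
        lse = new_lse
    return out

ref = torch.softmax(q @ k.T / math.sqrt(d), dim=-1) @ v
assert torch.allclose(ring_attention(q, k_shards, v_shards), ref, atol=1e-5)
```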

zhaochenyang20 mentioned this pull request Mar 3, 2025
zhyncs mentioned this pull request Mar 4, 2025
@Edenzzzz
Contributor

Edenzzzz commented Mar 4, 2025

I'm interested in helping with finalizing this if possible :)

merrymercy closed this Apr 21, 2025
zhyncs deleted the seq-parallel branch July 22, 2025 05:05