
[Fix] try fix feat/qwen35 cp by upd zigzag #957

Merged
Zhichenzzz merged 3 commits into feat/qwen35_cp from fix/feat/qwen35_cp on Apr 7, 2026

Conversation

@guapisolo (Collaborator)

No description provided.

@gemini-code-assist bot left a comment (Contributor)


Code Review

This pull request introduces hybrid Context Parallel (CP) support for GatedDeltaNet modules by implementing P2P exchange mechanisms to convert between zigzag and sequential layouts. Key additions include the _ZigzagSequentialExchange autograd function and utility functions for CP context management. Feedback highlights several critical issues: the layout conversion logic is hardcoded for a CP world size of 2, the sequence splitting mechanism fails to handle odd-length sequences, and the calculation of local sequence boundaries incorrectly assumes that sequence lengths are always multiples of the CP size.

Comment thread: miles_plugins/models/hf_attention.py (Outdated)
Comment on lines +41 to +73
class _ZigzagSequentialExchange(torch.autograd.Function):
    """P2P exchange to convert between zigzag and sequential CP layouts.

    For CP=2, zigzag rank 0 holds [sub_0, sub_3] and rank 1 holds [sub_1, sub_2].
    Sequential rank 0 needs [sub_0, sub_1] and rank 1 needs [sub_2, sub_3].
    This exchanges the misplaced sub-chunk between ranks via a single sendrecv.
    """

    @staticmethod
    def forward(ctx, send_buf, cp_group, cp_rank):
        ctx.cp_group = cp_group
        ctx.cp_rank = cp_rank
        recv_buf = torch.empty_like(send_buf)
        peer = 1 - cp_rank
        if cp_rank == 0:
            dist.send(send_buf.contiguous(), group_dst=peer, group=cp_group)
            dist.recv(recv_buf, group_src=peer, group=cp_group)
        else:
            dist.recv(recv_buf, group_src=peer, group=cp_group)
            dist.send(send_buf.contiguous(), group_dst=peer, group=cp_group)
        return recv_buf

    @staticmethod
    def backward(ctx, grad_recv):
        grad_send = torch.empty_like(grad_recv)
        peer = 1 - ctx.cp_rank
        if ctx.cp_rank == 0:
            dist.send(grad_recv.contiguous(), group_dst=peer, group=ctx.cp_group)
            dist.recv(grad_send, group_src=peer, group=ctx.cp_group)
        else:
            dist.recv(grad_send, group_src=peer, group=ctx.cp_group)
            dist.send(grad_recv.contiguous(), group_dst=peer, group=ctx.cp_group)
        return grad_send, None, None

Severity: high

The _ZigzagSequentialExchange class and its associated logic are hardcoded for a Context Parallel (CP) world size of 2. Specifically, the peer calculation peer = 1 - cp_rank (line 54) and the assumption that each rank holds exactly two sub-chunks per sequence will fail or produce incorrect results if cp_world_size > 2. Since this logic is triggered whenever hybrid_cp is enabled and cp_world_size > 1, it will cause runtime errors or silent data corruption in configurations with CP > 2.

References
  1. Avoid hardcoding model dimensions or parameters; derive them from configuration or input tensor shapes instead.
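
As a sketch of the kind of generalization suggested here (hypothetical helper names, not code from this PR): in the zigzag layout each sequence is cut into 2 * cp_size sub-chunks, rank r holds sub-chunks r and 2 * cp_size - 1 - r, and the sequential layout gives rank r sub-chunks 2r and 2r + 1. A small routing computation derived from cp_size would replace the hardcoded peer = 1 - cp_rank:

```python
# Hypothetical sketch: derive zigzag <-> sequential sub-chunk ownership for any cp_size.
def zigzag_subchunks_for_rank(rank: int, cp_size: int) -> tuple[int, int]:
    # Zigzag layout: rank r holds the r-th and the mirrored (2*cp_size - 1 - r)-th sub-chunk.
    return rank, 2 * cp_size - 1 - rank

def sequential_owner_of_subchunk(sub_idx: int) -> int:
    # Sequential layout: rank r holds sub-chunks 2r and 2r + 1.
    return sub_idx // 2

# Example, cp_size = 4: rank 0 holds zigzag sub-chunks (0, 7); sub-chunk 7 belongs to
# sequential rank 3, so rank 0 must send it to rank 3 rather than to a fixed "1 - rank" peer.
```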

Comment thread: miles_plugins/models/hf_attention.py (Outdated)
Comment on lines +76 to +105
def _zigzag_to_sequential(hidden_states, local_cu_seqlens, cp_group, cp_rank):
    """Convert zigzag CP layout to sequential via P2P exchange (CP=2 only).

    Rank 0 zigzag: [sub_0, sub_3] → sequential: [sub_0, sub_1]
    Rank 1 zigzag: [sub_1, sub_2] → sequential: [sub_2, sub_3]
    """
    # Split each sample into ascending (first half) and descending (second half)
    keep_parts, send_parts = [], []
    for i in range(len(local_cu_seqlens) - 1):
        start, end = local_cu_seqlens[i], local_cu_seqlens[i + 1]
        mid = (start + end) // 2
        if cp_rank == 0:
            keep_parts.append(hidden_states[start:mid])  # sub_0
            send_parts.append(hidden_states[mid:end])    # sub_3 → send to rank 1
        else:
            send_parts.append(hidden_states[start:mid])  # sub_1 → send to rank 0
            keep_parts.append(hidden_states[mid:end])    # sub_2

    send_buf = torch.cat(send_parts, dim=0)
    recv_buf = _ZigzagSequentialExchange.apply(send_buf, cp_group, cp_rank)

    # Reassemble: both ranks → [keep, recv]
    result = []
    offset = 0
    for i in range(len(local_cu_seqlens) - 1):
        chunk_len = (local_cu_seqlens[i + 1] - local_cu_seqlens[i]) // 2
        result.append(keep_parts[i])
        result.append(recv_buf[offset : offset + chunk_len])
        offset += chunk_len
    return torch.cat(result, dim=0)
Severity: high

The _zigzag_to_sequential function (and similarly _sequential_to_zigzag) does not correctly handle sequences with odd lengths. The use of integer division mid = (start + end) // 2 (line 86) and chunk_len = ... // 2 (line 101) assumes that each sample is split into two equal halves. If a sample has an odd number of tokens in the local buffer (e.g., 5 tokens), mid will be 2, resulting in 2 tokens kept and 3 tokens sent. However, chunk_len will be 2, causing the reassembly logic to slice only 2 tokens from the received buffer, losing the 3rd token and corrupting the offsets for all subsequent samples in the packed batch.
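
One defensive variant (a sketch assuming both ranks compute the same local_cu_seqlens, not the code merged here) is to record the per-sample split sizes explicitly instead of recomputing them with // 2 during reassembly:

```python
# Hypothetical sketch: compute explicit (keep_len, send_len) pairs per sample so that
# odd-length local samples are not silently truncated when slicing the received buffer.
def split_sizes(local_cu_seqlens: list[int], cp_rank: int) -> list[tuple[int, int]]:
    sizes = []
    for i in range(len(local_cu_seqlens) - 1):
        start, end = local_cu_seqlens[i], local_cu_seqlens[i + 1]
        mid = (start + end) // 2
        first, second = mid - start, end - mid   # second half takes the extra token if odd
        # Rank 0 keeps the first half and sends the second; rank 1 does the opposite.
        sizes.append((first, second) if cp_rank == 0 else (second, first))
    return sizes

# Reassembly would then slice recv_buf with the peer's send sizes; for CP=2 these equal
# this rank's keep sizes, since both ranks split each sample at the same midpoint.
```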

if mpu.get_context_parallel_world_size() > 1 and not self.hybrid_cp:
if mpu.get_context_parallel_world_size() > 1 and self.hybrid_cp:
    cp_size = mpu.get_context_parallel_world_size()
    local_cu_seqlens = cu_seqlens // cp_size
Severity: medium

Calculating local_cu_seqlens = cu_seqlens // cp_size assumes that every sequence boundary in the packed batch is a multiple of cp_size. In a packed sequence scenario, individual sequences can have arbitrary lengths. If a sequence length is not a multiple of cp_size, the calculated local boundaries will not align with the actual data distribution across CP ranks, leading to incorrect slicing in _zigzag_to_sequential and potential loss of tokens.
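
A minimal sketch of a stricter boundary computation (assuming the zigzag layout splits every packed sequence into 2 * cp_size sub-chunks; local_cu_seqlens_from_global is a hypothetical helper, not the PR's code):

```python
import torch

def local_cu_seqlens_from_global(cu_seqlens: torch.Tensor, cp_size: int) -> torch.Tensor:
    # Per-sequence lengths of the packed batch.
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    # Fail loudly instead of silently misaligning when a length is not divisible.
    if not torch.all(seqlens % (2 * cp_size) == 0).item():
        raise ValueError(
            f"each packed sequence length must be divisible by 2 * cp_size = {2 * cp_size}, "
            f"got {seqlens.tolist()}"
        )
    # Each rank holds seqlen / cp_size tokens of every sequence.
    local_seqlens = seqlens // cp_size
    local_cu = torch.cumsum(local_seqlens, dim=0, dtype=cu_seqlens.dtype)
    return torch.cat([cu_seqlens.new_zeros(1), local_cu])
```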

Commits:
- Replace CP=2-only P2P exchange with general implementation
- Use batch_isend_irecv with group_peer for correct group-local ranks
- Support arbitrary CP sizes via sub-chunk routing table
- Update correctness test to pass global cu_seqlens
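
For illustration, a minimal sketch of what the batched exchange described in these commits could look like (the helper name exchange_subchunks and the buffer layout are assumptions, not taken from the diff; group-local ranks are converted explicitly here rather than via a group_peer argument):

```python
import torch.distributed as dist

# Hypothetical sketch: post all sends/receives for one zigzag <-> sequential conversion at once.
# send_bufs / recv_bufs map a group-local peer rank to the tensor exchanged with that peer.
def exchange_subchunks(send_bufs: dict, recv_bufs: dict, cp_group) -> None:
    ops = []
    for group_peer, buf in send_bufs.items():
        peer = dist.get_global_rank(cp_group, group_peer)  # group-local -> global rank
        ops.append(dist.P2POp(dist.isend, buf.contiguous(), peer, group=cp_group))
    for group_peer, buf in recv_bufs.items():
        peer = dist.get_global_rank(cp_group, group_peer)
        ops.append(dist.P2POp(dist.irecv, buf, peer, group=cp_group))
    for req in dist.batch_isend_irecv(ops):
        req.wait()
```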
@Zhichenzzz merged commit ea7fafa into feat/qwen35_cp on Apr 7, 2026
14 of 15 checks passed
@Zhichenzzz deleted the fix/feat/qwen35_cp branch on April 7, 2026 at 23:12