[Fix] Try to fix feat/qwen35 CP by updating zigzag #957
Conversation
Code Review
This pull request introduces hybrid Context Parallel (CP) support for GatedDeltaNet modules by implementing P2P exchange mechanisms to convert between zigzag and sequential layouts. Key additions include the `_ZigzagSequentialExchange` autograd function and utility functions for CP context management. Feedback highlights several critical issues: the layout conversion logic is hardcoded for a CP world size of 2, the sequence splitting mechanism fails to handle odd-length sequences, and the calculation of local sequence boundaries incorrectly assumes that sequence lengths are always multiples of the CP size.
```python
class _ZigzagSequentialExchange(torch.autograd.Function):
    """P2P exchange to convert between zigzag and sequential CP layouts.

    For CP=2, zigzag rank 0 holds [sub_0, sub_3] and rank 1 holds [sub_1, sub_2].
    Sequential rank 0 needs [sub_0, sub_1] and rank 1 needs [sub_2, sub_3].
    This exchanges the misplaced sub-chunk between ranks via a single sendrecv.
    """

    @staticmethod
    def forward(ctx, send_buf, cp_group, cp_rank):
        ctx.cp_group = cp_group
        ctx.cp_rank = cp_rank
        recv_buf = torch.empty_like(send_buf)
        peer = 1 - cp_rank
        if cp_rank == 0:
            dist.send(send_buf.contiguous(), group_dst=peer, group=cp_group)
            dist.recv(recv_buf, group_src=peer, group=cp_group)
        else:
            dist.recv(recv_buf, group_src=peer, group=cp_group)
            dist.send(send_buf.contiguous(), group_dst=peer, group=cp_group)
        return recv_buf

    @staticmethod
    def backward(ctx, grad_recv):
        grad_send = torch.empty_like(grad_recv)
        peer = 1 - ctx.cp_rank
        if ctx.cp_rank == 0:
            dist.send(grad_recv.contiguous(), group_dst=peer, group=ctx.cp_group)
            dist.recv(grad_send, group_src=peer, group=ctx.cp_group)
        else:
            dist.recv(grad_send, group_src=peer, group=ctx.cp_group)
            dist.send(grad_recv.contiguous(), group_dst=peer, group=ctx.cp_group)
        return grad_send, None, None
```
The `_ZigzagSequentialExchange` class and its associated logic are hardcoded for a Context Parallel (CP) world size of 2. Specifically, the peer calculation `peer = 1 - cp_rank` (line 54) and the assumption that each rank holds exactly two sub-chunks per sequence will fail or produce incorrect results if `cp_world_size > 2`. Since this logic is triggered whenever `hybrid_cp` is enabled and `cp_world_size > 1`, it will cause runtime errors or silent data corruption in configurations with CP > 2.
References
- Avoid hardcoding model dimensions or parameters; derive them from configuration or input tensor shapes instead.
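For reference, below is a minimal sketch of how the exchange could be generalized to arbitrary CP sizes, in the spirit of the follow-up commit at the end of this thread (`batch_isend_irecv` with `group_peer` and a sub-chunk routing table). This is a hypothetical illustration, not the PR's code: it handles a single sample (no packed `cu_seqlens`), assumes all sub-chunks have the same shape, and assumes a PyTorch version where `P2POp` accepts `group_peer`.

```python
import torch
import torch.distributed as dist


def _zigzag_to_sequential_routes(cp_size):
    """List of (global_chunk_id, zigzag_owner, sequential_owner) for chunks that move."""
    routes = []
    for chunk in range(2 * cp_size):
        src = chunk if chunk < cp_size else 2 * cp_size - 1 - chunk  # zigzag owner
        dst = chunk // 2                                             # sequential owner
        if src != dst:
            routes.append((chunk, src, dst))
    return routes


def zigzag_to_sequential(local_chunks, cp_group, cp_rank, cp_size):
    """local_chunks: dict {global_chunk_id: tensor} held by this rank in zigzag layout.

    Returns this rank's sequential-layout tensor. Assumes the sequence length is
    divisible by 2 * cp_size, so every sub-chunk has the same shape.
    """
    ops, recv_bufs = [], {}
    template = next(iter(local_chunks.values()))
    for chunk, src, dst in _zigzag_to_sequential_routes(cp_size):
        if src == cp_rank:
            ops.append(dist.P2POp(dist.isend, local_chunks[chunk].contiguous(),
                                  group=cp_group, group_peer=dst))
        elif dst == cp_rank:
            buf = torch.empty_like(template)
            recv_bufs[chunk] = buf
            ops.append(dist.P2POp(dist.irecv, buf, group=cp_group, group_peer=src))
    if ops:
        for req in dist.batch_isend_irecv(ops):
            req.wait()
    # Keep sub-chunks already in place, add received ones, concatenate in global order.
    kept = {c: t for c, t in local_chunks.items() if c // 2 == cp_rank}
    kept.update(recv_bufs)
    return torch.cat([kept[c] for c in sorted(kept)], dim=0)
```

Because the routing table depends only on `cp_size`, every rank builds the same list, so the paired isend/irecv operations between any two ranks line up in the same order and `batch_isend_irecv` avoids the fixed send/recv ordering that ties the current class to CP=2.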
```python
def _zigzag_to_sequential(hidden_states, local_cu_seqlens, cp_group, cp_rank):
    """Convert zigzag CP layout to sequential via P2P exchange (CP=2 only).

    Rank 0 zigzag: [sub_0, sub_3] → sequential: [sub_0, sub_1]
    Rank 1 zigzag: [sub_1, sub_2] → sequential: [sub_2, sub_3]
    """
    # Split each sample into ascending (first half) and descending (second half)
    keep_parts, send_parts = [], []
    for i in range(len(local_cu_seqlens) - 1):
        start, end = local_cu_seqlens[i], local_cu_seqlens[i + 1]
        mid = (start + end) // 2
        if cp_rank == 0:
            keep_parts.append(hidden_states[start:mid])  # sub_0
            send_parts.append(hidden_states[mid:end])    # sub_3 → send to rank 1
        else:
            send_parts.append(hidden_states[start:mid])  # sub_1 → send to rank 0
            keep_parts.append(hidden_states[mid:end])    # sub_2

    send_buf = torch.cat(send_parts, dim=0)
    recv_buf = _ZigzagSequentialExchange.apply(send_buf, cp_group, cp_rank)

    # Reassemble: both ranks → [keep, recv]
    result = []
    offset = 0
    for i in range(len(local_cu_seqlens) - 1):
        chunk_len = (local_cu_seqlens[i + 1] - local_cu_seqlens[i]) // 2
        result.append(keep_parts[i])
        result.append(recv_buf[offset : offset + chunk_len])
        offset += chunk_len
    return torch.cat(result, dim=0)
```
The `_zigzag_to_sequential` function (and similarly `_sequential_to_zigzag`) does not correctly handle sequences with odd lengths. The use of integer division `mid = (start + end) // 2` (line 86) and `chunk_len = ... // 2` (line 101) assumes that each sample is split into two equal halves. If a sample has an odd number of tokens in the local buffer (e.g., 5 tokens), `mid` will be 2, resulting in 2 tokens kept and 3 tokens sent. However, `chunk_len` will be 2, causing the reassembly logic to slice only 2 tokens from the received buffer, losing the third token and corrupting the offsets for all subsequent samples in the packed batch.
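One way to make the reassembly robust to odd local lengths is to book-keep the per-sample keep/send/receive token counts explicitly rather than re-deriving them with `// 2`. The helper below is a hypothetical sketch (not the PR's fix) that follows the same CP=2 convention as `_zigzag_to_sequential` above and assumes both ranks hold the same number of local tokens per sample.

```python
def split_lengths(local_len: int, cp_rank: int):
    """Keep/send/recv token counts for one local sample under the CP=2 zigzag layout."""
    first = local_len // 2        # tokens in the first local half
    second = local_len - first    # tokens in the second local half (larger when odd)
    if cp_rank == 0:
        # Keep sub_0, send sub_3; the peer (rank 1) sends its first half (sub_1).
        return first, second, first
    # Keep sub_2, send sub_1; the peer (rank 0) sends its second half (sub_3).
    return second, first, second
```

For a 5-token local sample this gives rank 1 a receive length of 3, so reassembly would slice 3 tokens from `recv_buf` instead of the 2 implied by the `// 2` split; the receive buffer would likewise need to be sized from the peer's actual send length rather than `torch.empty_like(send_buf)`.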
```python
if mpu.get_context_parallel_world_size() > 1 and not self.hybrid_cp:
if mpu.get_context_parallel_world_size() > 1 and self.hybrid_cp:
    cp_size = mpu.get_context_parallel_world_size()
    local_cu_seqlens = cu_seqlens // cp_size
```
Calculating `local_cu_seqlens = cu_seqlens // cp_size` assumes that every sequence boundary in the packed batch is a multiple of `cp_size`. In a packed sequence scenario, individual sequences can have arbitrary lengths. If a sequence length is not a multiple of `cp_size`, the calculated local boundaries will not align with the actual data distribution across CP ranks, leading to incorrect slicing in `_zigzag_to_sequential` and potential loss of tokens.
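If the equal-split assumption is meant to be a precondition rather than something to generalize away, a cheap guard is to validate it before deriving the local boundaries. A minimal sketch (hypothetical helper, not part of the PR):

```python
import torch

def validated_local_cu_seqlens(cu_seqlens: torch.Tensor, cp_size: int) -> torch.Tensor:
    """Derive per-rank boundaries, failing loudly if the zigzag assumptions break."""
    seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
    # Each sample must split into 2 * cp_size equal sub-chunks for cu_seqlens // cp_size
    # (and the later mid-point split) to be exact; otherwise pad sequences upstream.
    if not torch.all(seq_lens % (2 * cp_size) == 0):
        raise ValueError(
            f"Packed sequence lengths {seq_lens.tolist()} must be multiples of "
            f"2 * cp_size = {2 * cp_size} for the zigzag hybrid-CP path."
        )
    return cu_seqlens // cp_size
```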
- Replace CP=2-only P2P exchange with general implementation
- Use batch_isend_irecv with group_peer for correct group-local ranks
- Support arbitrary CP sizes via sub-chunk routing table
- Update correctness test to pass global cu_seqlens
No description provided.