
feat: real cp support with relayout fix for qwen3.5 train/rollout mismatch #885

Merged
Zhichenzzz merged 17 commits into main from feat/qwen35_cp on Apr 11, 2026

Conversation

Contributor

@Zhichenzzz Zhichenzzz commented Apr 4, 2026

Summary

This PR addresses #878, adding real Context Parallelism (CP) support and fixing a train/rollout mismatch bug that was overlooked in PrimeIntellect prime-rl#2080.

Problem

When using hybrid linear attention with CP, there is a silent but critical layout mismatch between training and rollout: flash-linear-attention uses a zig-zag layout for CP, while Megatron uses a packed layout. The approach taken in prime-rl #2080 did not account for this incompatibility, resulting in undetected logprob divergence during training.

Solution

We identified this root cause and implemented proper shard and reshard operations to correctly bridge the zig-zag ↔ packed layout between flash-linear-attention and Megatron under CP, ensuring full consistency between train and rollout.
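To make the mismatch concrete, here is a minimal, self-contained sketch of the two layouts (the helper names `packed_shard` / `zigzag_shard` are illustrative only, not the actual shard/reshard APIs added in this PR):

```python
import torch

def packed_shard(x: torch.Tensor, rank: int, world: int) -> torch.Tensor:
    # Packed (Megatron-style) layout: rank r owns the contiguous r-th slice.
    return x.chunk(world, dim=1)[rank]

def zigzag_shard(x: torch.Tensor, rank: int, world: int) -> torch.Tensor:
    # Zig-zag layout (as used by flash-linear-attention CP): the sequence is
    # split into 2*W chunks and rank r owns chunks r and 2*W-1-r, which
    # balances causal-attention work across ranks.
    chunks = x.chunk(2 * world, dim=1)
    return torch.cat([chunks[rank], chunks[2 * world - 1 - rank]], dim=1)

# Example: seq_len 16, CP world size 2.
x = torch.arange(16).view(1, 16, 1)
print(packed_shard(x, rank=0, world=2).flatten().tolist())  # [0, 1, ..., 7]
print(zigzag_shard(x, rank=0, world=2).flatten().tolist())  # [0, 1, 2, 3, 12, 13, 14, 15]
# Mixing the two layouts silently misaligns tokens between train and rollout;
# resharding between them is what the fix in this PR provides.
```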

Results

After 10 training steps, logprob diff across configurations:

| Config | logprob_diff |
| --- | --- |
| CP=1, EP=4 | 0.0121 |
| CP=2, EP=4 | 0.0113 |
| CP=2, EP=8 | 0.0105 |

With the shard/reshard fix applied, logprob diff is further reduced to ~0.001. Thanks @guapisolo for the great help!

Key changes

- Upgrade flash-linear-attention to 0.4.2 (adds fla.ops.cp module)
- Pass cp_context to conv1d for correct boundary token handling
- Pass cp_context to chunk_gated_delta_rule for recurrent state CP
- Add use_native_cp flag to skip all-gather for DeltaNet layers
- Add setup_hybrid_cp() to configure GDN modules for native CP (see the sketch after this list)
- Add correctness test (torchrun) and e2e test for CP=2/4
- Add CP=2 EP=8 training script
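As a rough sketch of how setup_hybrid_cp() might wire things up: it only needs to attach the cp_group / cp_world_size state that _build_cp_context later reads (see the code quoted in the review threads below). The module-detection logic and the `is_gdn_layer` marker here are assumptions for illustration, not the actual implementation.

```python
import torch
import torch.distributed as dist

def setup_hybrid_cp(model: torch.nn.Module, cp_group: dist.ProcessGroup) -> None:
    """Sketch: attach CP state to every GDN/DeltaNet layer so each forward
    can build an fla CP context instead of falling back to an all-gather."""
    cp_world_size = dist.get_world_size(cp_group)
    for module in model.modules():
        # Hypothetical marker attribute; the real code selects the actual
        # GatedDeltaNet layer classes.
        if getattr(module, "is_gdn_layer", False):
            module.cp_group = cp_group
            module.cp_world_size = cp_world_size
            module.use_native_cp = True  # skip the all-gather fallback
```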

Fixes: #878

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request implements native Context Parallel (CP) for GatedDeltaNet layers in Qwen3.5 models, updating the flash-linear-attention dependency and adding initialization logic for hybrid CP. It also refines Hugging Face configuration loading and validation, particularly for MoE models, and introduces comprehensive E2E and correctness tests. Feedback highlights a critical issue where the CP context construction assumes a batch size of one, which would cause failures in multi-sequence batches, and suggests more robust validation for MoE intermediate sizes by checking moe_intermediate_size instead of skipping the check entirely.

Comment thread miles/utils/arguments.py Outdated
Comment on lines +2097 to +2111
is_moe = hasattr(hf_config, "num_experts") or hasattr(hf_config, "moe_intermediate_size")
checks = [
    ("hidden_size", "hidden_size", equal),
    ("num_attention_heads", "num_attention_heads", equal),
    ("num_hidden_layers", "num_layers", equal),
    ("tie_word_embeddings", "untie_embeddings_and_output_weights", lambda x, y: not x == y),
    (
        "rms_norm_eps",
        "norm_epsilon" if os.getenv("DEPRECATED_MEGATRON_COMPATIBLE", "0") == "1" else "layernorm_epsilon",
        equal,
    ),
    ("rope_theta", "rotary_base", equal),
]
if not is_moe:
    checks.insert(3, ("intermediate_size", "ffn_hidden_size", equal))

high

The current implementation skips the validation of the intermediate size for MoE models. For MoE models, ffn_hidden_size should be validated against moe_intermediate_size (if present) or intermediate_size. Hardcoding the skip for all MoE models reduces the robustness of the configuration check.

Suggested change
is_moe = hasattr(hf_config, "num_experts") or hasattr(hf_config, "moe_intermediate_size")
intermediate_size_attr = "moe_intermediate_size" if is_moe and hasattr(hf_config, "moe_intermediate_size") else "intermediate_size"
checks = [
    ("hidden_size", "hidden_size", equal),
    ("num_attention_heads", "num_attention_heads", equal),
    ("num_hidden_layers", "num_layers", equal),
    (intermediate_size_attr, "ffn_hidden_size", equal),
    ("tie_word_embeddings", "untie_embeddings_and_output_weights", lambda x, y: not x == y),
    (
        "rms_norm_eps",
        "norm_epsilon" if os.getenv("DEPRECATED_MEGATRON_COMPATIBLE", "0") == "1" else "layernorm_epsilon",
        equal,
    ),
    ("rope_theta", "rotary_base", equal),
]
References
  1. Model parameters should be retrieved from the model configuration rather than being hardcoded.

Comment thread miles_plugins/models/qwen3_5.py Outdated
Comment on lines +90 to +101
    def _build_cp_context(self, local_seq_len: int, device: torch.device):
        """Build fla CP context from the local (sharded) sequence length."""
        cp_group = getattr(self, "cp_group", None)
        if cp_group is None or build_cp_context is None:
            return None
        global_seq_len = local_seq_len * self.cp_world_size
        global_cu_seqlens = torch.tensor([0, global_seq_len], dtype=torch.int32, device=device)
        return build_cp_context(
            cu_seqlens=global_cu_seqlens,
            group=cp_group,
            conv1d_kernel_size=self.conv_kernel_size,
        )

high

The _build_cp_context method currently constructs global_cu_seqlens for only a single sequence, which will lead to incorrect behavior or crashes when the batch size is greater than 1. It should be updated to account for the batch size by creating a cu_seqlens tensor that covers all sequences in the batch.

    def _build_cp_context(self, batch_size: int, local_seq_len: int, device: torch.device):
        """Build fla CP context from the local (sharded) sequence length."""
        cp_group = getattr(self, "cp_group", None)
        if cp_group is None or build_cp_context is None:
            return None
        cp_world_size = getattr(self, "cp_world_size", 1)
        global_seq_len = local_seq_len * cp_world_size
        global_cu_seqlens = torch.arange(
            0, (batch_size + 1) * global_seq_len, step=global_seq_len, dtype=torch.int32, device=device
        )
        return build_cp_context(
            cu_seqlens=global_cu_seqlens,
            group=cp_group,
            conv1d_kernel_size=self.conv_kernel_size,
        )
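For concreteness, with illustrative values (batch_size=2, local_seq_len=2048, cp_world_size=2), the suggested `torch.arange` yields one cumulative boundary per equal-length sequence in the global, unsharded layout:

```python
import torch

batch_size, local_seq_len, cp_world_size = 2, 2048, 2  # illustrative values
global_seq_len = local_seq_len * cp_world_size
cu = torch.arange(0, (batch_size + 1) * global_seq_len, step=global_seq_len, dtype=torch.int32)
print(cu.tolist())  # [0, 4096, 8192]
```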
References
  1. Avoid hardcoding model dimensions; derive them from configuration or input tensor shapes instead.

Comment thread miles_plugins/models/qwen3_5.py Outdated
    ):
        batch_size, seq_len, _ = hidden_states.shape

        cp_context = self._build_cp_context(seq_len, hidden_states.device)

high

The call to _build_cp_context needs to pass the batch_size to correctly initialize the CP context for multi-sequence batches.

Suggested change
cp_context = self._build_cp_context(seq_len, hidden_states.device)
cp_context = self._build_cp_context(batch_size, seq_len, hidden_states.device)

Comment thread miles_plugins/models/qwen3_next.py Outdated
Comment on lines +83 to +94
    def _build_cp_context(self, local_seq_len: int, device: torch.device):
        """Build fla CP context from the local (sharded) sequence length."""
        cp_group = getattr(self, "cp_group", None)
        if cp_group is None or build_cp_context is None:
            return None
        global_seq_len = local_seq_len * self.cp_world_size
        global_cu_seqlens = torch.tensor([0, global_seq_len], dtype=torch.int32, device=device)
        return build_cp_context(
            cu_seqlens=global_cu_seqlens,
            group=cp_group,
            conv1d_kernel_size=self.conv_kernel_size,
        )

high

The _build_cp_context method ignores the batch size when constructing global_cu_seqlens. This will cause issues during Context Parallel execution if the batch size is greater than 1. The cu_seqlens should be generated based on the batch size.

    def _build_cp_context(self, batch_size: int, local_seq_len: int, device: torch.device):
        """Build fla CP context from the local (sharded) sequence length."""
        cp_group = getattr(self, "cp_group", None)
        if cp_group is None or build_cp_context is None:
            return None
        cp_world_size = getattr(self, "cp_world_size", 1)
        global_seq_len = local_seq_len * cp_world_size
        global_cu_seqlens = torch.arange(
            0, (batch_size + 1) * global_seq_len, step=global_seq_len, dtype=torch.int32, device=device
        )
        return build_cp_context(
            cu_seqlens=global_cu_seqlens,
            group=cp_group,
            conv1d_kernel_size=self.conv_kernel_size,
        )
References
  1. Avoid hardcoding model dimensions; derive them from configuration or input tensor shapes instead.

Comment thread miles_plugins/models/qwen3_next.py Outdated
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor = None,
    ):
        cp_context = self._build_cp_context(hidden_states.shape[1], hidden_states.device)

high

Update the call to _build_cp_context to include the batch size, ensuring the CP context is correctly built for batches with more than one sequence.

Suggested change
cp_context = self._build_cp_context(hidden_states.shape[1], hidden_states.device)
cp_context = self._build_cp_context(hidden_states.shape[0], hidden_states.shape[1], hidden_states.device)

@Zhichenzzz Zhichenzzz marked this pull request as ready for review April 5, 2026 22:09
Comment thread miles/utils/arguments.py Outdated
if "rope_theta" in hf_config.rope_parameters:
hf_config.rope_theta = hf_config.rope_parameters["rope_theta"]

is_moe = hasattr(hf_config, "num_experts") or hasattr(hf_config, "moe_intermediate_size")
Contributor Author

@guapisolo I added this line because the Qwen 3.5 hf_config has intermediate_size=5632, which is not equal to ffn_hidden_size.

Collaborator

Let me double check.

Collaborator

@guapisolo guapisolo Apr 6, 2026

I think we can remove this line? This line is not related to your motivation.

Contributor Author

Curious: a few weeks ago, when I first integrated Qwen 3.5, there were no assertion errors here (and I believe you tested it back then as well). Is this from a transformers version update or something else?

Comment thread miles/utils/arguments.py Outdated
("rope_theta", "rotary_base", equal),
]:
if hasattr(hf_config, hf_config_name):
if is_moe and hf_config_name == "intermediate_size":
Collaborator

On my side this line might be redundant? See https://huggingface.co/Qwen/Qwen3.5-35B-A3B/blob/main/config.json : there is no intermediate_size=5632 in the Qwen3.5 config.

Contributor Author

Oh, I found it... we should upgrade with pip install transformers==5.2.0.

Collaborator

I added a workaround to avoid this under transformers 4.57.1.

@Zhichenzzz Zhichenzzz requested a review from guapisolo April 7, 2026 02:58
Collaborator

@guapisolo guapisolo left a comment

LGTM. On my side, Qwen3.5-35B-A3B with CP=2 gives a logprob diff of 0.013.
Needs code owner review.

Comment thread scripts/run-qwen3.5-35B-A3B-mtp-cp2-ep8.sh Outdated
Comment thread miles/backends/training_utils/cp_utils.py Outdated
Comment thread miles/backends/megatron_utils/actor.py Outdated
@Zhichenzzz Zhichenzzz requested a review from yueming-yuan April 11, 2026 00:11
@Zhichenzzz Zhichenzzz changed the title feat: hybrid cp for qwen3.5 rl feat: real cp support with zig-zag/packed layout fix for qwen3.5 train/rollout mismatch Apr 11, 2026
@Zhichenzzz Zhichenzzz changed the title feat: real cp support with zig-zag/packed layout fix for qwen3.5 train/rollout mismatch feat: real cp support with relayout fix for qwen3.5 train/rollout mismatch Apr 11, 2026
@Zhichenzzz Zhichenzzz merged commit eb294e3 into main Apr 11, 2026
17 checks passed
@Zhichenzzz Zhichenzzz deleted the feat/qwen35_cp branch April 11, 2026 03:06
GuanxingLu pushed a commit to GuanxingLu/miles that referenced this pull request Apr 21, 2026
DavidBellamy added a commit to LLM360/miles that referenced this pull request Apr 21, 2026
…region clusters (#10)

* Revert "[BUGFIX] [P2PRDMA] Add rollout post-processing after P2PRDMA weight updates" (radixark#882)

* [Fix] fix ci (radixark#894)

* Avoid threading for ray getting object (radixark#886)

* Add explicit errors for unsupported Megatron profiles (radixark#887)

* Add nvfp4 quantizer files (radixark#907)

* Bump flash-linear-attention version to 0.4.2 (radixark#892)

* [BUGFIX] Invoke "post_process_quantization" by default after weight updating (radixark#890)

Co-authored-by: Yueming Yuan <yym022502@gmail.com>

* Add heartbeat and id to session server (radixark#866)

* fix: adding thin glm5 image to docker build + latest tag sync (radixark#871)

* Add consistent hashing routing policy for rollout (radixark#891)

Co-authored-by: Yueming Yuan <yueming@Mac.attlocal.net>

* [example] add retool v2 example with multi-turn framework interfaces (radixark#654)

Co-authored-by: GuanxingLu <gxlu02@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Expose rollout-batch-size, n-samples-per-prompt, global-batch-size as CLI args in swe-agent-v2 (radixark#954)

Co-authored-by: Shi Dong <shi.dong@radixark.ai>

* chore: remove obsolete swe-agent server.py and run-qwen3.sh (radixark#952)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Add weight staleness control for fully async rollout (radixark#958)

* Fix/pause generation mode (radixark#924)

Co-authored-by: Yueming Yuan <yym022502@gmail.com>

* [v0.5.10][1] Bump sglang to v0.5.10 (radixark#898)

* [v0.5.10][2] Fix apply_chat_template behavior for transformers >=5.0 (radixark#926)

Co-authored-by: guapisolo <guapisolo@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [v0.5.10][3] Fix processor return_tensors duplicate kwarg for transformers >=5.0 (radixark#927)

Co-authored-by: guapisolo <guapisolo@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* [v0.5.10][4] Fix _no_split_modules set not subscriptable in transformers >=5.0 (radixark#931)

* [v0.5.10][5] Disable piecewise cuda graph to avoid NVLS oom (radixark#935)

* [v0.5.10][6][FSDP] fix outdated weight update logic in FSDP (radixark#948)

Co-authored-by: guapisolo <guapisolo@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: maocheng23 <35615230+maocheng23@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* [v0.5.10][7][FSDP] move FSDP to experimental and disable by default (radixark#961)

* Add skiplist and more robust calculation on val (radixark#965)

* [fix] tiny fix debug rollout only in weight version check (radixark#967)

* feat: real cp support with relayout fix for qwen3.5 train/rollout mismatch (radixark#885)

* [AMD] Upgrade to sglv0.5.10 (radixark#973)

* switch model to actor (radixark#756)

* [fix] support general logic to bypass fp32 downcast and fix qwen35 A_log dtype (radixark#975)

Co-authored-by: yueming-yuan <yym022502@gmail.com>

* fix: populate prefix_cache_info in OpenAI/session rollout path (radixark#960)

* Remove prepare_harbor_tasks.py; use harbor-private adapters (radixark#982)

* [fix] Skip flush_cache in in_place mode and add fully async example (radixark#974)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* GLM47 full cmd for async and sync reasoning (radixark#986)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix: handle non-tool appended messages in TITO incremental tokenization (radixark#949)

Co-authored-by: Yanbin Jiang <jybsuper@gmail.com>

* [docker] Add sgl-model-gateway install and download .tar.gz assets (radixark#895)

* [ci] fix hf rate limit error by caching tokenizer loading (radixark#1014)

Co-authored-by: maocheng23 <35615230+maocheng23@users.noreply.github.com>

* Use load_generate_function in legacy sglang_rollout path (radixark#1016)

* Update CODEOWNERS to add new reviewers (radixark#1021)

* Support moe lora for gpt-oss (radixark#798)

Co-authored-by: Ethan (Yusheng) Su <yushengsu.thu@gmail.com>

* [fix] restore expert_bias to fp32 before bridge weight export (radixark#811)

* [chore] drop legacy transformers upgrade pin for glm47-flash and qwen35 (radixark#1018)

* [fix] Enforce param dtype before wrap ddp (radixark#992)

Co-authored-by: Zhichen Zeng <zczeng@uw.edu>

* [upgrade] update Megatron-Bridge source and LoRA CI to megatron e2e tests and  (radixark#1023)

* [CI] Drop --use-miles-router from R3 tests and add r3 comparasion test between sgl & miles router (radixark#1015)

* wandb: raise init_timeout, add retry wrapper, fix shared-mode init for cross-region clusters

In online + shared mode, both `init_wandb_primary` and `init_wandb_secondary`
make HTTPS round-trips to wandb cloud (login + run create/attach). On
high-latency cross-region clusters (e.g. Abu Dhabi MBZUAI ↔ wandb-cloud
US-West) with concurrent actor bursts, a single round-trip can exceed the
wandb SDK's 90s default `init_timeout` — tearing down the whole run
with a silent handshake abort. Observed on RL360 job 1564420, which
forced `WANDB_MODE=offline` as a global default ever since (see
https://github.com/LLM360/RL360/issues/87).

The issue's original diagnosis assumed a local primary↔secondary socket
handshake race. That's not how shared mode works — per wandb's own
feature PR (wandb/wandb#6882), each writer spawns
an independent wandb-core that talks to the cloud directly; aggregation
is server-side by run_id. No local socket exists. The failure mode is
pure network/latency, not a local readiness race.

Changes
-------

- Bump `init_timeout` to 300s for primary and secondary Settings.
  Configurable via `WANDB_INIT_TIMEOUT_SECS` env var for tuning.
- Wrap both init paths in a bounded exponential-backoff retry
  (`_wandb_init_with_retry`) that re-attempts on wandb.errors.CommError
  and wandb.errors.UsageError. 3 attempts with 5→10→20s backoff by
  default, tunable via `WANDB_INIT_RETRY_ATTEMPTS` /
  `WANDB_INIT_RETRY_BACKOFF_SECS`.
- Add `x_label` tagging per wandb distributed-training docs: primary
  gets `rank_<rank>_primary`, secondaries get `rank_<rank>_secondary`.
  Enables per-rank console-log filtering in the wandb UI.
- Drop `reinit=True` from secondary init_kwargs. Shared mode natively
  supports concurrent writers on a single run; `reinit=True` triggered
  stale-state warnings on secondary actors without functional benefit.

Followups this change enables
-----------------------------

- `WANDB_MODE=offline` can be removed from scale.yaml's extra_env
  default once a pilot run confirms online mode boots cleanly.
- The tmux-based `~/bin/wandb-sync-rl360.sh` workaround on David's M2
  account becomes obsolete (no more offline-only default).
- Near-realtime wandb dashboards replace the ~2-minute-lag offline
  sync; per-rank system metrics via x_label filtering.

---------

Co-authored-by: JD <jaedon.guo@gmail.com>
Co-authored-by: Ethan (Yusheng) Su <yushengsu.thu@gmail.com>
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Co-authored-by: Ziang Li <ziangli@umich.edu>
Co-authored-by: Zhichen Zeng <zczeng@uw.edu>
Co-authored-by: JensenFire <xinji1@microsoft.com>
Co-authored-by: Yueming Yuan <yym022502@gmail.com>
Co-authored-by: maocheng23 <35615230+maocheng23@users.noreply.github.com>
Co-authored-by: Douglas Yang <douglasyang88@gmail.com>
Co-authored-by: Yueming Yuan <yueming@Mac.attlocal.net>
Co-authored-by: Huapeng Zhou <73010314+PopSoda2002@users.noreply.github.com>
Co-authored-by: GuanxingLu <gxlu02@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Shi-Dong <Shi-Dong@users.noreply.github.com>
Co-authored-by: Shi Dong <shi.dong@radixark.ai>
Co-authored-by: Jiajun Li <48857426+guapisolo@users.noreply.github.com>
Co-authored-by: guapisolo <guapisolo@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Yuzhen Zhou <82826991+zyzshishui@users.noreply.github.com>
Co-authored-by: Yanbin Jiang <jybsuper@gmail.com>
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Yisheng Gong <yishenggong9437@gmail.com>
Development

Successfully merging this pull request may close these issues.

[Feature] Support real CP for Qwen 3.5
