Skip to content

[dev] [4/5] Qwen3.5 support: Interleaved MRoPE layout#4750

Merged
BestJuly merged 4 commits into
NVIDIA:devfrom
wplf:feat/mrope-interleaved-layout
May 15, 2026
Merged

[dev] [4/5] Qwen3.5 support: Interleaved MRoPE layout#4750
BestJuly merged 4 commits into
NVIDIA:devfrom
wplf:feat/mrope-interleaved-layout

Conversation

@wplf

@wplf wplf commented May 12, 2026

Copy link
Copy Markdown
Member

Qwen3.5 support series

This is part of a 5-PR series adding Qwen3.5-VL support, split for review clarity.

Dev PRs (this series):

Main PRs (corresponding mirrors):


Summary

Add the interleaved MRoPE layout used by Qwen3.5-VL, gated by a new config flag.

  • New TransformerConfig.mrope_interleaved: bool = False.
  • New MultimodalRotaryEmbedding(interleaved_mrope=...) argument.
  • New helper _apply_interleaved_mrope(freqs, mrope_section) that converts the per-channel outer-product layout (3, bs, seq_len, dim) into the interleaved single-tensor layout (bs, seq_len, dim).
  • GPTModel passes config.mrope_interleaved through to the embedding.

Why

MultimodalRotaryEmbedding currently produces the section-based T/H/W cycling used by Qwen2-VL. Qwen3.5-VL (and the matching HuggingFace Qwen3VLTextRotaryEmbedding.apply_interleaved_mrope unified on 2026-02-24) use a different layout where:

  • T freqs occupy stride-3 positions {0, 3, 6, ...}
  • H freqs occupy {1, 4, 7, ...}
  • W freqs occupy {2, 5, 8, ...}

Both layouts are now supported; the flag picks between them.

Risk

mrope_interleaved defaults to False, so existing Qwen2-VL / GPT users see no behavior change.

Test plan

  • Existing MRoPE tests still pass with mrope_interleaved=False (default).
  • New layout matches HF apply_interleaved_mrope reference on a fixed (position_ids, mrope_section) input.

Notes

This is the core prerequisite for an upcoming examples/multimodal_dev/ PR that adds Qwen3.5-VL.

🤖 Generated with Claude Code

@copy-pr-bot

copy-pr-bot Bot commented May 12, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@wplf

wplf commented May 13, 2026

Copy link
Copy Markdown
Member Author

/ok to test d6d1aff

`MultimodalRotaryEmbedding` already supports the original section-based
T/H/W cycling (Qwen2-VL style). Qwen3.5-VL and the HuggingFace
`Qwen3VLTextRotaryEmbedding.apply_interleaved_mrope` use a different
layout where H freqs occupy stride-3 positions {1,4,7,...} and W freqs
occupy {2,5,8,...}, with T at {0,3,6,...}.

Add a new `interleaved_mrope` flag on the embedding (default `False`,
preserves existing behavior) plus a `mrope_interleaved` config field on
`TransformerConfig`, and wire it through `GPTModel`.

Helper `_apply_interleaved_mrope` merges the per-channel outer-product
layout `(3, bs, seq_len, dim)` into the interleaved single-tensor
layout `(bs, seq_len, dim)`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Co-Authored-By: BestJuly <19769279+BestJuly@users.noreply.github.com>
@wplf wplf force-pushed the feat/mrope-interleaved-layout branch from d6d1aff to a9a9b50 Compare May 13, 2026 10:24
@wplf

wplf commented May 13, 2026

Copy link
Copy Markdown
Member Author

/ok to test a9a9b50

@BestJuly BestJuly enabled auto-merge May 14, 2026 01:57
@wplf

wplf commented May 14, 2026

Copy link
Copy Markdown
Member Author

/ok to test 188cc39

@wplf

wplf commented May 15, 2026

Copy link
Copy Markdown
Member Author

/ok to test 235e4a6

@BestJuly BestJuly added this pull request to the merge queue May 15, 2026
@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25907530418

@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25910357667

Merged via the queue into NVIDIA:dev with commit cfbd9df May 15, 2026
178 of 180 checks passed
SpencerGarnets added a commit to ai-blaise/Megatron-LM that referenced this pull request May 16, 2026
Upstream dev tip: 77c0f8c

Pulled commits:

- 77c0f8c [Dev][feat] Support A2A Overlap for Megatron-FSDP (NVIDIA#3796)

- 8195337 [dev] [3/5] Qwen3.5 support: SharedExpertMLP meta init (NVIDIA#4749)

- 2672ff5 [DEV] fix(megatron-fsdp): preserve non-meta tensors during meta materialization (NVIDIA#4155)

- cfbd9df [dev] [4/5] Qwen3.5 support: Interleaved MRoPE layout (NVIDIA#4750)

- df12802 [dev] Fix GDN DTensor splitting for FSDP checkpointing (NVIDIA#4799)

Resolution: zero conflicts; git auto-merged 12 shared files in megatron/core/{distributed,models,pipeline_parallel,transformer} and tests/unit_tests/a2a_overlap. No ai-blaise custom files touched.

Gates:

- git diff --check: clean

- conflict markers: none

- py_compile (16 changed .py files): OK

- indexcache: 27/28 pass; the 1 fail (test_nvfp4_non_blackwell_cuda_uses_reference_fallback) reproduces identically at the pre-merge base SHA (sglang occupies all 8 H200s in EXCLUSIVE_PROCESS mode -> cudaErrorDevicesUnavailable). 1 Blackwell-only test auto-skips on H200.

- transformer gdn/mtp/moe suite: 53 failed / 7 passed / 55 skipped / 5 errors -- IDENTICAL numbers at pre-merge base; all failures are the same environmental cudaErrorDevicesUnavailable.

- 2-rank torchrun layer-wise optimizer smoke: blocked (no free GPUs).

Custom preserved: StreamBP, IndexCache config, NVFP4 indexer (7e78f28), HISA topk1024 backward test (c628c13), pyproject emerging_optimizers v0.2.0 pin, mHC/MTP/MoE composition.
wplf added a commit to wplf/Megatron-LM that referenced this pull request Jun 4, 2026
…onfig entry

Match the merged dev shadow PR NVIDIA#4750 exactly:
- reflow torch.stack/torch.cat in MultimodalRotaryEmbedding.forward
- add 'mrope_interleaved': False to test_hybrid_moe_model config dict

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants