Skip to content

[mimo] Support bridge fan-out for variable modality tokens#5062

Merged
ko3n1g merged 9 commits into
NVIDIA:mainfrom
liding-nv:mimo-fanout-var-tokens
Jun 1, 2026
Merged

[mimo] Support bridge fan-out for variable modality tokens#5062
ko3n1g merged 9 commits into
NVIDIA:mainfrom
liding-nv:mimo-fanout-var-tokens

Conversation

@liding-nv

Copy link
Copy Markdown
Contributor

Summary

Non-colocated bridge fan-out previously sent equal batch-dim slices to every peer. For VLM/AVLM data the per-sample modality-token counts vary, so the LLM received the wrong embeddings. Tag encoder outputs with per-sample split sizes and have the bridge honor them when fanning activations/gradients out.

  • MimoModel._attach_modality_split_sizes annotates flat encoder outputs with _mimo_bridge_split_sizes derived from special_token_ids.
  • MimoModel._empty_encoder_output returns a (0, hidden_size) placeholder when a text-only sample reaches a non-colocated encoder rank; _has_encoder_tokens guards against silent loss of modality data.
  • BridgeCommunicator._split_tensor_at_batch_dim consumes _mimo_bridge_split_sizes (per-peer or per-sample) with sum/length checks.
  • BridgeCommunicator._communicate_shapes accepts a per-peer tensor list via the new _as_per_peer_tensors helper.
  • schedules.backward_step_multimodule unwraps single-element tensor lists so the per-peer paths feed back cleanly.

Tests

CPU-only unit tests added inside existing files:

  • test_bridge_communicator.py::TestBridgeCommunicatorSplitMetadata_split_tensor_at_batch_dim sum-mismatch + num_splits=1short-circuit; _as_per_peer_tensors broadcast / pass-through / count-mismatch.
  • test_mimo_model.py::TestMimoModelFanoutHelpers_attach_modality_split_sizes happy path + skip branches; _has_encoder_tokens presence/absence.

liding-nv added 2 commits May 29, 2026 08:15
Signed-off-by: Li Ding <liding@nvidia.com>
Signed-off-by: Li Ding <liding@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented May 29, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@liding-nv liding-nv requested a review from yashaswikarnati May 29, 2026 15:38
@liding-nv

Copy link
Copy Markdown
Contributor Author

/ok to test ffe26d5

Signed-off-by: Li Ding <liding@nvidia.com>
@liding-nv

Copy link
Copy Markdown
Contributor Author

/ok to test 6006d40

# Apply grad scaling if needed (for last stage only).
for module_name in output_tensor.keys():
if output_tensor_grad[module_name] is None and config.grad_scale_func is not None:
output_tensor_grad_module = _unwrap_single_tensor_list(output_tensor_grad[module_name])

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be I'm missing something.. would we need this change in schedules

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. We hit this in the Qwen3.5-27B MIMO validation run: language TP4/PP4/DP2, image encoder TP1/PP1/DP1, with variable_seq_lengths=True.

In that path, get_tensor_shapes() returns [()]. The regular P2P communicator then returns a single-element grad list, so the multimodule dict can contain:

{"language": [grad_tensor]}

backward_step_multimodule() already unwraps this shape for input_tensor, but not for output_tensor_grad. Without this, we pass [grad_tensor] into backward instead of grad_tensor.

So this is just making multimodule backward handle the same single-tensor-list convention as the existing single-module schedule.

Signed-off-by: Li Ding <liding@nvidia.com>
Signed-off-by: Li Ding <liding@nvidia.com>
@liding-nv

Copy link
Copy Markdown
Contributor Author

/ok to test 8b29dc9

Signed-off-by: Li Ding <liding@nvidia.com>
@liding-nv

Copy link
Copy Markdown
Contributor Author

/ok to test fe27b34

@svcnvidia-nemo-ci svcnvidia-nemo-ci added Approved All necessary approvals have been made and removed Final Review PR is in the "final review" stage labels Jun 1, 2026
@ko3n1g ko3n1g added this pull request to the merge queue Jun 1, 2026
@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/26781230312

Merged via the queue into NVIDIA:main with commit de030fc Jun 1, 2026
83 of 85 checks passed
copy-pr-bot Bot pushed a commit that referenced this pull request Jun 12, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Approved All necessary approvals have been made complexity: medium core_r0.16.0 Cherry-pick label for core_r0.16.0 release branch

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants