[mimo] Support bridge fan-out for variable modality tokens#5062
Conversation
Signed-off-by: Li Ding <liding@nvidia.com>
|
/ok to test ffe26d5 |
|
/ok to test 6006d40 |
| # Apply grad scaling if needed (for last stage only). | ||
| for module_name in output_tensor.keys(): | ||
| if output_tensor_grad[module_name] is None and config.grad_scale_func is not None: | ||
| output_tensor_grad_module = _unwrap_single_tensor_list(output_tensor_grad[module_name]) |
There was a problem hiding this comment.
May be I'm missing something.. would we need this change in schedules
There was a problem hiding this comment.
Yes. We hit this in the Qwen3.5-27B MIMO validation run: language TP4/PP4/DP2, image encoder TP1/PP1/DP1, with variable_seq_lengths=True.
In that path, get_tensor_shapes() returns [()]. The regular P2P communicator then returns a single-element grad list, so the multimodule dict can contain:
{"language": [grad_tensor]}
backward_step_multimodule() already unwraps this shape for input_tensor, but not for output_tensor_grad. Without this, we pass [grad_tensor] into backward instead of grad_tensor.
So this is just making multimodule backward handle the same single-tensor-list convention as the existing single-module schedule.
Signed-off-by: Li Ding <liding@nvidia.com>
|
/ok to test 8b29dc9 |
|
/ok to test fe27b34 |
|
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/26781230312 |
Signed-off-by: Li Ding <liding@nvidia.com>
Summary
Non-colocated bridge fan-out previously sent equal batch-dim slices to every peer. For VLM/AVLM data the per-sample modality-token counts vary, so the LLM received the wrong embeddings. Tag encoder outputs with per-sample split sizes and have the bridge honor them when fanning activations/gradients out.
MimoModel._attach_modality_split_sizesannotates flat encoder outputs with_mimo_bridge_split_sizesderived fromspecial_token_ids.MimoModel._empty_encoder_outputreturns a(0, hidden_size)placeholder when a text-only sample reaches a non-colocated encoder rank;_has_encoder_tokensguards against silent loss of modality data.BridgeCommunicator._split_tensor_at_batch_dimconsumes_mimo_bridge_split_sizes(per-peer or per-sample) with sum/length checks.BridgeCommunicator._communicate_shapesaccepts a per-peer tensor list via the new_as_per_peer_tensorshelper.schedules.backward_step_multimoduleunwraps single-element tensor lists so the per-peer paths feed back cleanly.Tests
CPU-only unit tests added inside existing files:
test_bridge_communicator.py::TestBridgeCommunicatorSplitMetadata—_split_tensor_at_batch_dimsum-mismatch +num_splits=1short-circuit;_as_per_peer_tensorsbroadcast / pass-through / count-mismatch.test_mimo_model.py::TestMimoModelFanoutHelpers—_attach_modality_split_sizeshappy path + skip branches;_has_encoder_tokenspresence/absence.