support for training qwen3 vl with dist train #2367
📝 Walkthrough
This PR adds comprehensive Qwen3VL vision-language model support, including distributed training infrastructure, new vision model components, training step implementations with packed sequence handling, and updates to distributed initialization and process group management.
Sequence Diagram(s)
sequenceDiagram
participant Trainer
participant Setup
participant InitDist as Initialize<br/>Distributed
participant VisionModule
participant LanguageModule
participant Communicator
Trainer->>Setup: setup()
Setup->>InitDist: torch_dist_init()
alt use_dist_train enabled
InitDist->>InitDist: Create HyperCommGrid for vision
InitDist->>InitDist: Create HyperCommGrid for language
InitDist->>Communicator: Build MultiModulePipelineCommunicator
InitDist->>Setup: Return (pg_collection, grid_dict, pg_collection_dict)
Setup->>Setup: Assign vision/language add flags based on rank membership
Setup->>Setup: Wire communicator to model config
else use_dist_train disabled
InitDist->>Communicator: Build P2PCommunicator
InitDist->>Setup: Return pg_collection
Setup->>Setup: Set add_encoder/decoder to True
end
Setup->>Trainer: Return configured setup with communicator
Trainer->>Trainer: Initialize model (Qwen3VL)
Note over VisionModule,LanguageModule: Vision & Language modules<br/>only active based on add flags
Trainer->>Trainer: Training loop
Trainer->>VisionModule: forward() with vision inputs
VisionModule-->>Trainer: vision embeddings + deepstack features
Trainer->>LanguageModule: forward() with language inputs + vision embeddings
LanguageModule-->>Trainer: logits
Trainer->>Communicator: Handle inter-module communication if needed
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks | ✅ 3 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (3 passed)
Actionable comments posted: 17
Note
Due to the large number of review comments, Critical, Major severity comments were prioritized as inline comments.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (7)
src/megatron/bridge/training/vlm_step.py (1)
145-192: 🛠️ Refactor suggestion | 🟠 Major
`if True:` creates unreachable dead code; clean up before merging.
The `else` branch (lines 159-192) is unreachable. This looks like a debug/WIP artifact. Either remove the dead branch or replace `if True:` with the intended condition.
Note: the dead `else` branch also has a latent bug at line 190, where `if attn is not None:` sits outside the `if tokens_or_input is not None:` block, making `target_len` potentially undefined. Since it's dead code, this is moot for now.
Proposed cleanup
 # When using pipeline parallelism, ensure fixed shapes equal to cfg.model.seq_length
-if True:
-    seq_len = cfg.model.seq_length
+seq_len = cfg.model.seq_length

-    tokens_or_input = batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids")
-    tokens_or_input = pad_or_truncate_2d_to_len(tokens_or_input, seq_len, seq_len, pad_value=0)
-    ...
-else:
-    # No PP: pad sequence length to nearest multiple of 128 for efficiency (capped at model seq_length)
-    ...
+tokens_or_input = batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids")
+tokens_or_input = pad_or_truncate_2d_to_len(tokens_or_input, seq_len, seq_len, pad_value=0)
+...
As per coding guidelines: "If code is commented out, include a comment describing its usage and why it is commented out; otherwise remove it as debug code before merging."
src/megatron/bridge/training/initialize.py (2)
134-214: ⚠️ Potential issue | 🟠 Major
`torch_dist_init` “triple return” contract is inconsistent across branches (`lazy_init` / `skip_mpu_initialization`).
`torch_dist_init()` is annotated to return `(ProcessGroupCollection, grid_dict, pg_collection_dict)`, but it can return:
- `(None, None, None)` when `skip_mpu_initialization=True`, and
- `(finish_mpu_init, None, None)` when `dist_config.lazy_init=True` (callable in slot 0),
and `finish_mpu_init()` is annotated as returning `ProcessGroupCollection` but returns the full 3-tuple.
This mismatch makes it very easy for downstream code to treat the first element as a `ProcessGroupCollection` and crash (especially after the new destructuring pattern in setup code).
Consider making the return type explicit and self-describing (e.g., a small dataclass/NamedTuple with `pg_collection`, `grid_dict`, `pg_collection_dict`, and optional `finish_mpu_init`), or at minimum fix the type hints and docstrings so call sites can branch safely on `callable(pg_collection)` and `pg_collection is None`.
As per coding guidelines, "Use type hints for function arguments and return types" and "Use 'T | None' for nullable types instead of 'Optional[T]'".
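One way to make the return contract self-describing, as suggested above, is a small NamedTuple plus a helper that call sites use to resolve the lazy and skipped cases. This is a minimal sketch under assumptions: `DistInitResult` and `resolve_dist_init` are hypothetical names that do not appear in the PR, and the field types are left as `Any`/`dict` because the real `ProcessGroupCollection` and grid types are not reproduced here.

```python
from __future__ import annotations

from typing import Any, Callable, NamedTuple


class DistInitResult(NamedTuple):
    """Hypothetical container for what torch_dist_init() hands back."""

    pg_collection: Any | None = None
    grid_dict: dict | None = None
    pg_collection_dict: dict | None = None
    # Set only in the lazy-init case; calling it performs the deferred init.
    finish_mpu_init: Callable[[], "DistInitResult"] | None = None


def resolve_dist_init(result: DistInitResult) -> DistInitResult:
    """Return a result whose pg_collection is usable, or raise a clear error."""
    if result.finish_mpu_init is not None:
        # Lazy init: run the deferred step to obtain the real process groups.
        result = result.finish_mpu_init()
    if result.pg_collection is None:
        # Covers the skip_mpu_initialization case, where nothing was created.
        raise RuntimeError("distributed init did not produce a pg_collection")
    return result
```

A call site would then read named fields from the resolved result instead of guessing what sits in slot 0 of a bare tuple.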
363-375: ⚠️ Potential issue | 🟡 Minor
Fix implicit Optional type hints and use built-in tuple generic (ruff RUF013).
Parameters `world_size` and `rank_offset` use implicit Optional syntax and should use union types. The return type should use the built-in `tuple` generic instead of `Tuple` from `typing`.
Proposed diff
 def _create_pg_collection(
     model_config: GPTModelProvider | T5ModelProvider,
     num_distributed_optimizer_instances: int,
     get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
     get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
-    world_size: int = None,
-    rank_offset: int = None,
-) -> Tuple[ProcessGroupCollection, HyperCommGrid]:
+    world_size: int | None = None,
+    rank_offset: int | None = None,
+) -> tuple[ProcessGroupCollection, HyperCommGrid]:
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_config.py (1)
15-119: ⚠️ Potential issue | 🟠 Major
Remove unused `torch` import, fix assertion error handling, and avoid mutating `hf_config` in-place.
Issues to address:
- `import torch` is unused (line 20 shows it's imported but never referenced; only `torch.nn.functional` is used).
- `assert config.vision_model_type is None, ValueError(...)` is incorrect usage: it raises `AssertionError` with a `ValueError` object as the message, and assertions can be disabled at runtime. Use `raise NotImplementedError(...)` instead to fail fast with a clear error.
- Mutating `hf_config` in-place (depth, hidden_size, num_heads, etc.) couples the function to the input object and risks unintended side effects if `vision_transformer_config` is reused elsewhere. Store these values in local variables or construct a copy of the config to avoid mutation.
Proposed diff
-import torch
 import torch.nn.functional as F
@@
 def get_vision_model_config(config: Qwen3VLTransformerConfig, hf_config):
+    """Populate a Qwen3VLTransformerConfig instance with vision-model settings.
+
+    Note: This function mutates and returns `config`.
+    """
     config.num_moe_experts = None
     config.expert_model_parallel_size = 1
     config.moe_ffn_hidden_size = None
@@
     if config.vision_model_type == "vit_2b":
-        hf_config.depth = 45
-        hf_config.hidden_size = 1536
-        hf_config.num_heads = 16
-        hf_config.intermediate_size = 8960
-        hf_config.patch_size = 16
-        hf_config.spatial_merge_size = 2
-        if hasattr(hf_config, "head_dim"):
-            hf_config.head_dim = 96
+        hf_depth = 45
+        hf_hidden_size = 1536
+        hf_num_heads = 16
+        hf_intermediate_size = 8960
+        hf_patch_size = 16
+        hf_spatial_merge_size = 2
+        hf_head_dim = 96
     else:
-        assert config.vision_model_type is None, ValueError(f"support only vit_2b, but got {config.vision_model_type}")
+        raise NotImplementedError(f"Only vision_model_type='vit_2b' is supported, got: {config.vision_model_type}")
@@
-    config.num_layers = hf_config.depth
-    config.ffn_hidden_size = hf_config.intermediate_size
-    config.num_attention_heads = hf_config.num_heads  # num_heads
+    config.num_layers = hf_depth
+    config.ffn_hidden_size = hf_intermediate_size
+    config.num_attention_heads = hf_num_heads  # num_heads
@@
-    config.hidden_size = hf_config.hidden_size  # hidden_size
+    config.hidden_size = hf_hidden_size  # hidden_size
@@
-    config.patch_size = hf_config.patch_size
+    config.patch_size = hf_patch_size
@@
-    config.spatial_merge_size = hf_config.spatial_merge_size
+    config.spatial_merge_size = hf_spatial_merge_size
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py (1)
30-126: ⚠️ Potential issue | 🟠 Major
Register `inv_freq` as a buffer and fix type hints to use `|` instead of `Optional`.
Two issues need fixing:
- `inv_freq` must be a buffer: currently stored as a plain tensor attribute (lines 65-67), it won't track device moves when `module.to(device)` is called and won't serialize/deserialize properly with `state_dict`. The codebase already uses this pattern (see `utils.py` line 84). Additionally, specifying `device=torch.cuda.current_device()` at initialization is problematic: the tensor should be created on the default device and follow module movements.
- Type hints: per coding guidelines, use `T | None` instead of `Optional[T]`. Lines 52 and 54 should use the modern union syntax.
Proposed diff
 def __init__(
     self,
     kv_channels: int,
     rotary_percent: float,
     rotary_interleaved: bool = False,
-    seq_len_interpolation_factor: Optional[float] = None,
+    seq_len_interpolation_factor: float | None = None,
     rotary_base: int = 10000,
-    cp_group: torch.distributed.ProcessGroup = None,
+    cp_group: torch.distributed.ProcessGroup | None = None,
 ) -> None:
     super().__init__()
     dim = kv_channels
     if rotary_percent < 1.0:
         dim = int(dim * rotary_percent)
     self.rotary_interleaved = rotary_interleaved
     assert not self.rotary_interleaved, "only support qwen3vl"
     self.seq_len_interpolation_factor = seq_len_interpolation_factor
-    self.inv_freq = 1.0 / (
-        rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device()) / dim)
-    )
+    inv_freq = 1.0 / (rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+    self.register_buffer("inv_freq", inv_freq, persistent=False)
     self.is_thd_format = False
     self.cp_group = cp_group
src/megatron/bridge/training/setup.py (1)
52-167: ⚠️ Potential issue | 🔴 Critical
`setup()`: handle lazy-init and skip-mpu-init returns before accessing `pg_collection.pp`
When `lazy_init=True`, `initialize_megatron` returns a callable as the first element of the tuple; when `skip_mpu_initialization=True`, it returns `None`. Directly accessing `pg_collection.pp` on line 167 will crash in both cases with `AttributeError` or `TypeError`.
Add type checking after unpacking:
- If the first element is callable, invoke it to get the real `(pg_collection, grid_dict, pg_collection_dict)`
- If `pg_collection` is still `None`, raise a clear error
Proposed fix
 pg_collection, grid_dict, pg_collection_dict = initialize_megatron(
     cfg=cfg,
     get_embedding_ranks=get_embedding_ranks,
     get_position_embedding_ranks=get_position_embedding_ranks,
     restart_store=restart_store,
 )
+if callable(pg_collection):
+    pg_collection, grid_dict, pg_collection_dict = pg_collection()
+
+if pg_collection is None:
+    raise RuntimeError("initialize_megatron did not return a ProcessGroupCollection (pg_collection=None)")
+
 if hasattr(cfg.model, "use_dist_train") and cfg.model.use_dist_train:
tests/unit_tests/models/qwen_vl/modelling_qwen3_vl/test_model.py (1)
314-323: ⚠️ Potential issue | 🟠 Major
Pass `pg_collection` to all `Qwen3VLModel` instances.
`Qwen3VLModel` now dereferences `pg_collection.*`; the constructors for `model_no_decoder` and `model_no_pre` will raise when `pg_collection` is omitted. Reuse the collection already created in the test.
🛠️ Proposed fix
 model_no_decoder = Qwen3VLModel(
     vision_transformer_config=vision_transformer_config,
     language_transformer_config=language_transformer_config,
     language_transformer_layer_spec=language_model_layer_spec,
     parallel_output=True,
     pre_process=True,
     post_process=True,
     add_encoder=True,
     add_decoder=False,
+    pg_collection=pg_collection,
 )
@@
 model_no_pre = Qwen3VLModel(
     vision_transformer_config=vision_transformer_config,
     language_transformer_config=language_transformer_config,
     language_transformer_layer_spec=language_model_layer_spec,
     parallel_output=True,
     pre_process=False,
     post_process=True,
     add_encoder=True,
     add_decoder=True,
+    pg_collection=pg_collection,
 )
Also applies to: 366-375
🤖 Fix all issues with AI agents
In `@examples/recipes/decentralized_pg/pretrain_qwen3_vl_simple.py`:
- Line 37: Replace the import of forward_step from the generic vlm_step with the
Qwen3-VL specific implementation: change the import to use
megatron.bridge.training.qwen3vl_step so the script uses forward_step
implemented in qwen3vl_step; this ensures the Qwen3-VL assertions, data-format
handling (bshd vs thd), position_ids logic, packed sequence handling, and
multimodal input injection (pixel_values and image_grid_thw) are applied instead
of the generic vlm_step behavior.
In `@examples/recipes/qwen_vl/finetune_qwen_vl.py`:
- Around line 106-108: The file unconditionally imports forward_step from
megatron.bridge.training.qwen3vl_step which asserts the model is Qwen3-VL and
breaks Qwen2.5-VL runs; move that import into main() after you determine
recipe/model_family and conditionally import forward_step from the correct
module (if model_family == "Qwen3-VL" import from
megatron.bridge.training.qwen3vl_step else import from
megatron.bridge.training.vlm_step) so the assertion isn’t triggered at import
time, and remove the unused from functools import partial import which is
misplaced after first-party imports.
In `@src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py`:
- Around line 109-114: Guard against pg_collection being None before
dereferencing: check if pg_collection is not None before accessing
pg_collection.cp/tp/pp and assigning self.pg_collection, self.cp_group,
self.tp_group, self.pp_group; if pg_collection is None set those group
attributes to None (or appropriate defaults) and adjust the assert (or replace
it with an explicit check) so you only call hasattr(self.pg_collection, "embd")
when self.pg_collection is not None. Ensure you update any code paths that
assume these groups exist to handle the None/default case.
In `@src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py`:
- Around line 84-92: The constructor dereferences pg_collection.cp when creating
self.rotary_pos_emb (Qwen3VLMultimodalRotaryEmbedding) but pg_collection may be
None; guard this by checking/initializing pg_collection before use or asserting
non-None: e.g., if pg_collection is None create a default with the expected .cp
attribute (or raise a clear error), then pass pg_collection.cp into
Qwen3VLMultimodalRotaryEmbedding; ensure the change touches the place where
rotary_pos_emb is assigned and uses pg_collection only after the guard.
In `@src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py`:
- Around line 86-95: The deepstack_merger_list constructs
Qwen3VLVisionPatchMerger instances without passing the block's tensor-parallel
group, causing linear layers (linear_fc1/linear_fc2) inside those mergers to use
the wrong TP group; update the Qwen3VLVisionPatchMerger construction inside
deepstack_merger_list to include tp_group=self.tp_group (same pattern used in
vision_model where mergers pass tp_group=self.tp_group) so each merger uses the
block's TP group for its linear layers, ensuring consistent tensor-parallel
behavior across config.deepstack_visual_indexes.
- Around line 429-433: The call to sharded_state_dict_default inside
Qwen3VLVisionTransformerBlock's sharded_state_dict loop will raise NameError and
the identity check uses non-PEP8 syntax; fix by importing or providing a
fallback implementation of sharded_state_dict_default (used by
sharded_state_dict when iterating named_children) and replace "not module is
self.layers" / "not module is self.deepstack_merger_list" with "module is not
self.layers" / "module is not self.deepstack_merger_list" respectively; ensure
the imported or fallback sharded_state_dict_default signature matches (module,
prefix, sharded_offsets, metadata) so sharded_state_dict can call it safely.
- Around line 363-435: The method TransformerBlock.sharded_state_dict calls an
undefined helper sharded_state_dict_default and also bypasses parent logic; fix
by first invoking super().sharded_state_dict(prefix, sharded_offsets, metadata)
and merging its result, replace the undefined call by importing/using the
correct helper function (or the module-level utility that actually exists in the
codebase) instead of sharded_state_dict_default when iterating named_children
(reference the symbol sharded_state_dict_default and ensure it matches the real
exported name), and keep using self.pp_group.rank() but verify pp_group
implements rank() where assigned; in short: call super(), swap the undefined
helper for the correct imported helper, and merge results rather than fully
overriding parent behavior.
In `@src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py`:
- Around line 366-372: In split_data_cp_rank, move the None check for the input
tensor before any attribute access so val is returned early if None (i.e., check
`if val is None: return val` before using `val.shape`), and update the cp_rank
annotation from `cp_rank: int = None` to use an explicit Optional type
(`cp_rank: Optional[int] = None`) to comply with PEP 484; keep the rest of the
logic (cp_size assert and cp_rank defaulting via
mpu.get_context_parallel_rank()) unchanged.
- Around line 349-350: Replace the unreachable assert with a real exception: in
the else branch that currently does "assert False, f'should not have
{token_id=}'", raise a ValueError instead (e.g. "raise ValueError(f'should not
have token_id={token_id}')") so the error is not stripped by python -O and
clearly reports the invalid token_id; update the branch where token_id is
handled accordingly.
In `@src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py`:
- Around line 42-50: The constructor currently defaults pg_collection to None
but accesses self.pg_collection.tp in __init__, which can raise AttributeError;
update the __init__ of the class to validate pg_collection (the pg_collection
parameter and self.pg_collection) before accessing .tp — e.g., if pg_collection
is None raise a clear ValueError stating "pg_collection is required" (or
alternatively assign self.tp_group = None when pg_collection is None), then set
self.pg_collection = pg_collection and self.tp_group = self.pg_collection.tp
only after the null check so the attribute access is safe.
In `@src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py`:
- Around line 146-157: Qwen3VLModelProvider (and Qwen3ModelProvider) reference
attributes add_encoder and add_decoder when constructing Qwen3VLModel but those
attributes are never defined; fix by adding dataclass fields add_encoder: bool =
False and add_decoder: bool = False to the provider class (or its parent) so
they're initialized, or alternatively pass explicit boolean literals into the
Qwen3VLModel constructor where add_encoder/add_decoder are currently used
(mimicking nemotron_vl_provider.py); update both the Qwen3VLModelProvider and
the MoE variant locations that reference add_encoder/add_decoder to use the new
fields or hardcoded values.
In `@src/megatron/bridge/training/qwen3vl_step.py`:
- Around line 136-140: The code currently calls
batch.get("visual_inputs").normalized_for_model() without checking for None;
update the block around where "visual_inputs" is handled so you first assign
visual = batch.get("visual_inputs") and only call visual.normalized_for_model()
if visual is not None, otherwise set multi_modal_inputs to None (or skip
creating it) so downstream uses of multi_modal_inputs are guarded; modify the
logic in qwen3vl_step.py where multi_modal_inputs is created to reference the
new visual variable and handle the None case safely.
In `@src/megatron/bridge/training/train.py`:
- Around line 739-740: Guard the last-stage loss reduction by checking whether
cfg.model.p2p_communicator is not None before accessing .is_pp_last_stage; if it
is None (PP size 1), call is_pp_last_stage(pg_collection.pp) instead. Update the
conditional that currently reads cfg.model.p2p_communicator.is_pp_last_stage to
first test cfg.model.p2p_communicator and then fall back to
is_pp_last_stage(pg_collection.pp) so the loss averaging path only runs when the
correct determination of last stage succeeds.
In `@src/megatron/bridge/training/utils/packed_seq_utils.py`:
- Around line 73-74: Move the line "from megatron.core import mpu" up into the
module import block with the other imports (remove the "// Copied from ..."
comment entirely), and ensure the import is colocated with the top-level imports
used by this file; then refactor to avoid duplication by either importing the
existing implementations of preprocess_packed_seqs and postprocess_packed_seqs
from the shared implementation in the qwen_vl utils module or extract a new
shared utility module and have both locations import from it (preserve the
use_fp8_padding behavior when consolidating).
- Around line 130-157: The loop unconditionally uses attention_mask[i] causing a
crash when attention_mask is None; update the loop in the pre_process branch to
guard use of attention_mask: compute d = input_ids[i, attention_mask[i]] if
attention_mask is not None, otherwise set d = input_ids[i, :seqlen] for the
cp_size <= 1 branch (use seqlen from seqlens_in_batch_cpu[i]) and d =
input_ids[i, :seqlen_padded_i] for the cp_size > 1 branch (use seqlen_padded_i
from seqlens_in_batch_padded_cpu[i]) before the subsequent slicing that writes
into input_ids_rmpad so indexing never tries to subscript None; adjust
references around input_ids_rmpad, seqlens_in_batch_cpu,
seqlens_in_batch_padded_cpu, cu_seqlens_padded_cpu, cp_size and cp_rank.
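The `attention_mask` guard requested in the last item can be illustrated without tensors. The sketch below uses plain Python lists as a stand-in for the tensor indexing in `preprocess_packed_seqs`; `select_valid_tokens` is a hypothetical helper, and the real code would index `input_ids[i]` with a boolean mask or a slice instead.

```python
from __future__ import annotations


def select_valid_tokens(row: list[int], mask: list[bool] | None, seqlen: int) -> list[int]:
    """Pick the unpadded tokens of one sample.

    With a mask, keep the positions marked True (the list analogue of
    input_ids[i, attention_mask[i]]). Without one, fall back to the
    first `seqlen` tokens (the analogue of input_ids[i, :seqlen]).
    """
    if mask is not None:
        return [tok for tok, keep in zip(row, mask) if keep]
    return row[:seqlen]
```

The point of the guard is only that the `None` case never reaches the subscript; the subsequent slicing into the packed buffer is unchanged.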
🟡 Minor comments (12)
src/megatron/bridge/training/vlm_step.py-124-125 (1)
124-125: ⚠️ Potential issue | 🟡 Minor
Hardcoded `is_first = True` / `is_last = True` bypasses pipeline-stage detection.
This forces all ranks to load labels/loss_mask and visual inputs, regardless of their actual PP stage. If this is intentional for the multi-module VLM distributed training path, please add a comment explaining why pipeline-stage gating was removed (and whether the original `is_pp_first_stage` / `is_pp_last_stage` imports on line 21 are now dead).
examples/recipes/decentralized_pg/pretrain_qwen3_vl_simple.py-41-43 (1)
41-43: ⚠️ Potential issue | 🟡 Minor
Misleading comment: "Qwen3 4B" vs actual 30B config.
Line 42 says "Get the standard Qwen3 4B pretrain config" but line 43 uses `qwen3_vl_30b_a3b_pretrain_config`.
Fix
-    # Get the standard Qwen3 4B pretrain config with overrides
+    # Get the standard Qwen3-VL 30B MoE pretrain config with overrides
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py-265-307 (1)
265-307: ⚠️ Potential issue | 🟡 Minor
Silence unused `cu_seqlens` warning and make RoPE-fusion failure explicit.
- `apply_rotary_pos_emb_thd_absolute(..., cu_seqlens, ...)` doesn’t use `cu_seqlens` (ruff ARG001). Rename to `_cu_seqlens` (or add a targeted noqa) to keep the interface but avoid lint noise.
- `assert not config.apply_rope_fusion` will be stripped under `python -O`; prefer a real exception (`NotImplementedError`/`ValueError`) if this is a hard constraint.
As per coding guidelines, "When a feature is not supported (such as audio embeddings), raise an explicit error (e.g., NotImplementedError) instead of silently ignoring the input to fail fast with a clear message." Based on learnings: "when a feature is not supported ... raise an explicit error (e.g., NotImplementedError) instead of silently ignoring".
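The point about `python -O` can be checked with the standard library alone. The sketch below compiles a failing `assert` at a given optimization level (level 1 corresponds to `-O`) and contrasts it with an explicit guard. `check_rope_fusion` and `assert_survives` are hypothetical helpers written for illustration, not code from the PR.

```python
def check_rope_fusion(apply_rope_fusion: bool) -> None:
    """Explicit guard: raises regardless of interpreter optimization flags."""
    if apply_rope_fusion:
        raise NotImplementedError("apply_rope_fusion is not supported on this path")


def assert_survives(optimize: int) -> bool:
    """Return True if a failing assert still raises at the given optimize level."""
    code = compile("assert False, 'unsupported'", "<demo>", "exec", optimize=optimize)
    try:
        exec(code)
    except AssertionError:
        return True
    return False
```

Under these assumptions, `assert_survives(0)` reports that the assert fires, `assert_survives(1)` reports that it was compiled away, and `check_rope_fusion(True)` raises in both modes.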
src/megatron/bridge/training/qwen3vl_step.py-41-47 (1)
41-47: ⚠️ Potential issue | 🟡 Minor
Silence unused-argument warnings and remove the unused variable.
Ruff flags ARG001/F841 in this function.
🛠️ Proposed fix
     batch = next(data_iterator)
+    _ = use_mtp
+    _ = is_first_pp_stage
+    _ = is_last_pp_stage
@@
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
Also applies to: 224-224
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py-16-18 (1)
16-18: ⚠️ Potential issue | 🟡 Minor
Align type hints with repo conventions and drop the unused import.
Use built-in generics/union types and remove the unused `mpu` import to satisfy lint and style rules.
As per coding guidelines, use built-in generics (list, dict, tuple) instead of typing equivalents, and use 'T | None' for nullable types instead of 'Optional[T]'.
🛠️ Proposed fix
-from typing import List, Dict
+
@@
-from megatron.core import InferenceParams, mpu, tensor_parallel
+from megatron.core import InferenceParams, tensor_parallel
@@
-    def set_input_tensor(self, input_tensor: List[Dict[str, torch.Tensor]]):
+    def set_input_tensor(self, input_tensor: list[dict[str, torch.Tensor]]):
@@
-        cp_img_num: list[int] = None,
-        images_padded: list[bool] = None,
+        cp_img_num: list[int] | None = None,
+        images_padded: list[bool] | None = None,
Also applies to: 195-200, 300-301
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/attention.py-40-41 (1)
40-41: ⚠️ Potential issue | 🟡 Minor
Fix lint issues: unused parameter and f-string without placeholders.
Ruff flags the unused `rotary_pos_cos_sin` argument and the `f` prefix on a static string.
🛠️ Proposed fix
-        rotary_pos_cos_sin: Optional[Tensor] = None,
+        _rotary_pos_cos_sin: Optional[Tensor] = None,
@@
-            raise ValueError(f"CUDA graphs must use flash decode with static batching!")
+            raise ValueError("CUDA graphs must use flash decode with static batching!")
Also applies to: 126-127
src/megatron/bridge/training/train.py-39-40 (1)
39-40: ⚠️ Potential issue | 🟡 Minor
Remove unused `MultiModulePipelineCommunicator` import (F401).
Lint currently flags this as unused.
🛠️ Proposed fix
-from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py-152-155 (1)
152-155: ⚠️ Potential issue | 🟡 Minor
Use `print_rank_0` for model logging.
This avoids duplicate logs across ranks during distributed training.
As per coding guidelines, use 'print_rank_0' for logging in model bridge to avoid duplicate output across ranks.
🛠️ Proposed fix
-from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig as Qwen3VLConfigHF
+from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig as Qwen3VLConfigHF
+from megatron.bridge.utils.common_utils import print_rank_0
@@
-            print(f"rank {torch.distributed.get_rank()} use hf vision model")
+            print_rank_0(f"rank {torch.distributed.get_rank()} use hf vision model")
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py-242-242 (1)
242-242: ⚠️ Potential issue | 🟡 Minor
Typo: "flaaten" → "flattened".
-    assert input_ids.dim() == 1, "input_ids should be flaaten"
+    assert input_ids.dim() == 1, "input_ids should be flattened"
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py-1-24 (1)
1-24: ⚠️ Potential issue | 🟡 Minor
Missing NVIDIA copyright header and unused imports.
This file is missing the required NVIDIA copyright header. Additionally, static analysis flags several unused imports: `dataclass`, `Union`, `MultimodalProjector`, `MegatronModule`, `build_module`, and `get_tensor_model_parallel_group_if_none`.
Proposed fix
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
-from dataclasses import dataclass
-from typing import Optional, Union
+from __future__ import annotations

 import torch
 from megatron.core import InferenceParams
 from megatron.core.models.common.vision_module.vision_module import VisionModule
-from megatron.core.models.vision.multimodal_projector import MultimodalProjector
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.transformer.enums import ModelType
-from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.spec_utils import ModuleSpec, build_module
-from megatron.core.utils import get_tensor_model_parallel_group_if_none
+from megatron.core.transformer.spec_utils import ModuleSpec
 from torch import nn
 from torch.nn import functional as F
As per coding guidelines: "Add NVIDIA copyright header to all Python files" and "Use `T | None` for nullable types instead of `Optional[T]`".
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py-197-245 (1)
197-245: ⚠️ Potential issue | 🟡 Minor
Docstring args/return mismatch with actual signature and return value.
- Docstring refers to parameter `x` (Line 208) but the actual parameter is `hidden_states`.
- Docstring mentions `packed_seq_params` (Line 210) which is not in the function signature (it's computed internally on Line 235).
- The return type annotation says `torch.Tensor` (Line 203) but the function returns a tuple `(hidden_states, deepstack_feature_lists)` on Line 245.
Proposed fix
 def forward(
     self,
     hidden_states: Optional[torch.Tensor],
     grid_thw: torch.Tensor,
     inference_params: Optional[InferenceParams] = None,
     extra_block_kwargs: dict = None,
-) -> torch.Tensor:
+) -> tuple[torch.Tensor, list]:
     """Forward function of the Qwen3 Vision Model. This function passes the input tensors
     through the embedding layer and then the transformer.

     Args:
-        x (torch.Tensor): input image/video data of shape [n_tokens, n_dims]
+        hidden_states (torch.Tensor): input image/video data of shape [n_tokens, n_dims]
         grid_thw (torch.Tensor): the size tensor indicates grid size of each image/frame
-        packed_seq_params (PackedSeqParams): parameters to build attention mask in the backend
+        inference_params (InferenceParams, optional): inference parameters
+        extra_block_kwargs (dict, optional): additional keyword arguments for the decoder block

     Returns:
-        x (torch.Tensor): output after final transformer block of shape [b, s, h].
+        tuple: (hidden_states, deepstack_feature_lists)
     """
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py-666-679 (1)
666-679: ⚠️ Potential issue | 🟡 Minor
`cu_seqlens_padded` is passed undivided to `PackedSeqParams`, but this is currently a latent issue since the returned `packed_seq_params` is never used.
While the code structure does create a mismatch when `cp_size > 1` (the buffer is sized as `sum(seqlens_in_batch_padded_cpu) // cp_size` but `cu_seqlens_padded` reflects full, undivided cumulative lengths), this is not a critical runtime bug in the current codebase. All calls to `preprocess_packed_seqs()` discard the returned `packed_seq_params` with `_` (lines 437, 456, 467, 523 in model.py), and the parameter passed to the language model is always `None` (line 536).
The data itself is correctly divided during packing (indices are adjusted: `start_idx = cu_seqlens_padded_cpu[i] // cp_size` on line 650). However, if `packed_seq_params` were ever used downstream with CP enabled, the mismatch would cause incorrect indexing. Consider either: (1) adjusting `cu_seqlens_padded` and `max_seqlen_in_batch` for the per-rank view when `cp_size > 1`, or (2) clarifying whether this function is intended to support CP splitting and documenting the limitation.
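A small worked example makes the mismatch concrete. It assumes, as the `// cp_size` indexing on line 650 implies, that every padded length is divisible by `cp_size`. `per_rank_cu_seqlens` is a hypothetical helper illustrating option (1), not code from the PR, and the lengths are made-up values.

```python
from __future__ import annotations

from itertools import accumulate


def per_rank_cu_seqlens(padded_lens: list[int], cp_size: int) -> list[int]:
    """Cumulative sequence offsets as seen by one context-parallel rank."""
    if any(n % cp_size for n in padded_lens):
        raise ValueError("padded lengths must be divisible by cp_size")
    return [0, *accumulate(n // cp_size for n in padded_lens)]


padded_lens = [8, 16]  # illustrative padded sequence lengths
cp_size = 2

full_cu = [0, *accumulate(padded_lens)]               # undivided offsets
rank_cu = per_rank_cu_seqlens(padded_lens, cp_size)   # per-rank offsets
rank_buffer_len = sum(padded_lens) // cp_size         # packed buffer size per rank

# The undivided offsets run past the per-rank buffer; the per-rank ones fit it.
print(full_cu[-1], rank_cu[-1], rank_buffer_len)
```

Here the undivided offsets end at 24 while each rank's buffer holds 12 tokens, so only the per-rank offsets index that buffer correctly.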
🧹 Nitpick comments (17)
scripts/performance/argument_parser.py (1)
148-148: Consider whether `qwen3vl` belongs as a `--domain` choice or should be handled via `--model_family_name`.
The existing domain values (`llm`, `vlm`) are generic categories, whereas `qwen3vl` is model-specific. This sets a precedent where each new VLM variant could require its own domain entry. If the Qwen3-VL training path diverges significantly enough from the general `vlm` path to justify a separate domain, this is fine, but if the differences are minor, routing through `--model_family_name` (or a sub-option) would keep the domain list stable.
examples/recipes/qwen_vl/finetune_qwen_vl.py (1)
93-93: Use built-intupleinstead oftyping.Tuplefor Python 3.10+.Line 118 uses
Tuple[argparse.Namespace, list[str]], mixingtyping.Tuplewith built-inlist. Per coding guidelines, prefer built-in generics.Proposed fix
-from typing import TupleAnd on line 118:
-def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]: +def parse_cli_args() -> tuple[argparse.Namespace, list[str]]:As per coding guidelines: "Use built-in generics (list, dict, tuple) instead of typing equivalents".
src/megatron/bridge/training/utils/packed_seq_utils.py (2)
77-78: Missing type hint for use_fp8_padding parameter.
Fix
-def preprocess_packed_seqs(
-    input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True, use_fp8_padding=False
-) -> tuple[torch.Tensor, PackedSeqParams]:
+def preprocess_packed_seqs(
+    input_ids: torch.Tensor, attention_mask: torch.Tensor | None, pre_process: bool = True, use_fp8_padding: bool = False
+) -> tuple[torch.Tensor, PackedSeqParams]:
As per coding guidelines: "Use type hints for function arguments and return types" and "Use 'T | None' for nullable types instead of 'Optional[T]'".
196-197: Minor: Use unpacking instead of list concatenation.
Per Ruff RUF005, prefer [batch_size, seq_len, *list(output.shape[2:])].
Fix
-    shape = [batch_size, seq_len] + list(output.shape[2:])  # 1,packed, dim -> batch_size, seq_len, dim
+    shape = [batch_size, seq_len, *output.shape[2:]]  # 1,packed, dim -> batch_size, seq_len, dim
src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py (2)
39-39: Remove unused import get_vision_model_config.
Flake8 confirms get_vision_model_config is imported but never used in this file.
Fix
-from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.transformer_config import get_vision_model_config
139-157: Vision spec construction is duplicated between dense and MoE providers.
Lines 139-144 (dense provider) and lines 307-312 (MoE provider) construct identical vision_transformer_layer_spec and vision_patch_merger_spec objects. Consider extracting a shared helper method (e.g., on a common base class or as a module-level function) to keep these in sync.
src/megatron/bridge/training/initialize.py (1)
589-648: Dist-train PG split: cache rank-membership checks + validate world-size partitioning.
This block calls is_rank_in_pg(...) multiple times for the same collections; please cache results to avoid repeated get_process_group_ranks() calls during init.
Also recommend adding a sanity check that model_config.vision_world_size + model_config.language_world_size == torch.distributed.get_world_size() (or get_world_size_safe()), otherwise you can end up with “neither in vision nor language” ranks that pass the current logic until a later failure.
As per coding guidelines, "Follow the existing code style and conventions as documented in CODING_GUIDELINES.md".
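A minimal sketch of the suggested world-size check, assuming the two sizes are plain integers by the time it runs. The function name check_world_size_partition is hypothetical and does not come from the PR; in the real init path the third argument would be whatever torch.distributed.get_world_size() reports.

```python
def check_world_size_partition(
    vision_world_size: int, language_world_size: int, world_size: int
) -> None:
    """Raise early if the vision and language ranks do not cover the whole job.

    Hypothetical helper: the names mirror the config fields quoted in the
    review, but this function is not part of the PR.
    """
    total = vision_world_size + language_world_size
    if total != world_size:
        raise ValueError(
            f"vision_world_size ({vision_world_size}) + language_world_size "
            f"({language_world_size}) = {total}, but the job has {world_size} ranks"
        )


# Example: 2 vision ranks plus 6 language ranks on an 8-rank job passes silently.
check_world_size_partition(2, 6, 8)
```

Failing here gives a clear message at init time instead of a later error on a rank that belongs to neither module.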
src/megatron/bridge/training/config.py (1)
1318-1413: Add invariant checks for use_dist_train so dp-size computation can’t silently go wrong.
When use_dist_train=True, dp-size is derived from model.language_world_size. That’s fine if model.tensor_model_parallel_size, pipeline_model_parallel_size, and context_parallel_size are also the language-side values at this point. If those fields still reflect a “combined” or vision-side topology, dp-size will be incorrect.
Suggestion: add a small validation in set_data_parallel_size() (or model.finalize()) that language_world_size % (tp*pp*cp) == 0 with the exact values used here, and raise a clear error if not.
As per coding guidelines, "Be explicit about required vs optional fields in configuration objects; do not add arbitrary defaults".
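The divisibility check suggested above could look like the following sketch. It assumes dp-size is the language world size divided by tp*pp*cp, which is how the comment describes it; language_data_parallel_size is a hypothetical name, not a function in the PR.

```python
def language_data_parallel_size(language_world_size: int, tp: int, pp: int, cp: int) -> int:
    """Derive the language-side data-parallel size, failing loudly on a bad topology.

    Hypothetical helper illustrating the invariant
    language_world_size % (tp * pp * cp) == 0.
    """
    model_parallel_size = tp * pp * cp
    if language_world_size % model_parallel_size != 0:
        raise ValueError(
            f"language_world_size ({language_world_size}) is not divisible by "
            f"tp*pp*cp ({tp}*{pp}*{cp} = {model_parallel_size})"
        )
    return language_world_size // model_parallel_size


# Example: 8 language ranks with tp=2, pp=2, cp=1 leaves 2 data-parallel replicas.
dp_size = language_data_parallel_size(8, tp=2, pp=2, cp=1)
print(dp_size)
```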
src/megatron/bridge/models/gpt_provider.py (1)
200-210: Use T | None for the new nullable config fields.
Align new fields with the repo’s nullable type-hint convention.
As per coding guidelines, Use 'T | None' for nullable types instead of 'Optional[T]'.
🛠️ Proposed fix
-    vision_model_type: Optional[str] = None
+    vision_model_type: str | None = None
@@
-    dist_train_vision_chunk_size: Optional[int] = 1
-    vision_world_size: Optional[int] = None
-    language_world_size: Optional[int] = None
-    vision_tensor_model_parallel_size: Optional[int] = None
-    vision_pipeline_model_parallel_size: Optional[int] = None
-    vision_context_parallel_size: Optional[int] = None
-    vision_expert_tensor_parallel_size: Optional[int] = None
-    vision_expert_model_parallel_size: Optional[int] = None
+    dist_train_vision_chunk_size: int | None = 1
+    vision_world_size: int | None = None
+    language_world_size: int | None = None
+    vision_tensor_model_parallel_size: int | None = None
+    vision_pipeline_model_parallel_size: int | None = None
+    vision_context_parallel_size: int | None = None
+    vision_expert_tensor_parallel_size: int | None = None
+    vision_expert_model_parallel_size: int | None = None
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py (4)
16-17: Use modern union type syntax per coding guidelines.
The coding guidelines require T | None instead of Optional[T] and X | Y instead of Union[X, Y] (Python 3.10+). These typing imports are used throughout the file (lines 92–94, 103, etc.).
Suggested fix
-from typing import Optional, Union
+from __future__ import annotations
Then replace all Optional[X] with X | None and Union[X, Y] with X | Y throughout the file.
569-596: AllGatherVisionEmbeddings.backward: start_idx is an int when cp_rank == 0 but a tensor otherwise, which may cause slicing issues.
On Line 593, start_idx is the Python int 0 when cp_rank == 0, but for every other rank torch.cat(seqlens_on_cp_ranks[:cp_rank]).sum() returns a 0-dim tensor, not a Python int. Then on Line 595, grad_output[start_idx:end_idx] uses tensor indices. While this works in practice, mixing tensor and int index types is fragile. More importantly, ctx.save_for_backward stores all seqlens_on_cp_ranks tensors. If these are large lists, this is fine, but verify the tensors are detached and on the correct device.
Safer start_idx computation
-    start_idx = torch.cat(seqlens_on_cp_ranks[:cp_rank]).sum() if cp_rank != 0 else 0
+    start_idx = int(torch.cat(seqlens_on_cp_ranks[:cp_rank]).sum().item()) if cp_rank != 0 else 0
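The index arithmetic can be shown without tensors. The sketch below is an assumption-laden stand-in: slice_bounds is a hypothetical helper, and nested lists of ints replace the per-rank seqlens_on_cp_ranks tensors, so both bounds are always Python ints regardless of rank.

```python
def slice_bounds(seqlens_on_cp_ranks: list[list[int]], cp_rank: int) -> tuple[int, int]:
    """Return (start_idx, end_idx) of this rank's slice in the gathered output.

    Hypothetical pure-Python analogue of the start/end computation discussed
    above; summing an empty prefix yields 0, so rank 0 needs no special case.
    """
    start_idx = sum(sum(lens) for lens in seqlens_on_cp_ranks[:cp_rank])
    end_idx = start_idx + sum(seqlens_on_cp_ranks[cp_rank])
    return start_idx, end_idx


# Made-up per-rank sequence lengths for three CP ranks.
lens_per_rank = [[3, 2], [4], [1, 1]]
bounds = [slice_bounds(lens_per_rank, rank) for rank in range(3)]
print(bounds)
```

Keeping both bounds as ints means the same slicing code path runs on every rank.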
700-701: postprocess_packed_seqs: redundant .cpu() call may trigger an extra D2H sync.
attention_mask.sum(dim=1, dtype=torch.int32).cpu().tolist() on Line 701 explicitly moves to CPU. If attention_mask is already on GPU, .tolist() alone will perform the D2H transfer. The extra .cpu() call is harmless but redundant. More importantly, this is a synchronization point; consider batching it with other D2H transfers above if performance matters here.
290-292: Magic default token IDs should reference the config, not be hardcoded.
image_token_id: int = 151655 and video_token_id: int = 151656 are hardcoded defaults in reorganize_inputs. These IDs are model-specific and already defined in Qwen3VLTransformerConfig. Hardcoding them here risks silent mismatch if the config changes.
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py (2)
134-195: fast_pos_embed_interpolate: potential performance concern with .tolist() in inner loop.
Lines 170–171 call .tolist() and .extend() inside the per-image loop, converting GPU tensors to Python lists element by element. For a large number of images/grids, this creates many small D2H synchronizations and Python list operations. Consider accumulating the index and weight tensors on GPU and converting once after the loop.
Also, the unused loop variable t on Line 140 (flagged by Ruff B007) can be replaced with _ to signal intent.
239-244: torch.split followed by torch.cat is a no-op.
Lines 242–244 split hidden_states into chunks and immediately concatenate them back. Unless there's a side effect or a planned insertion between split and cat, this can be removed.
-    split_sizes = (grid_thw.prod(-1) // self.spatial_merge_size**2).tolist()
-    hidden_states = torch.split(hidden_states, split_sizes)
-    hidden_states = torch.cat(hidden_states, dim=0)
If this is a placeholder for future logic, add a comment explaining it.
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py (2)
64-64: Remove or document commented-out code.
Lines 64 and 76 contain commented-out model_comm_pgs parameters without any explanation. As per coding guidelines, commented-out code should include a comment describing its purpose and why it's commented out, or be removed before merging.
-    # model_comm_pgs: ModelCommProcessGroups = None,
...
-    # model_comm_pgs=model_comm_pgs,
Also applies to: 76-76
336-339: O(n) index lookup self.deepstack_visual_indexes.index(l_no) on every layer forward pass.
Both in the checkpointed path (Line 134) and the non-checkpointed path (Line 337), .index(l_no) performs a linear scan of deepstack_visual_indexes for every matching layer. Consider converting to a dict mapping layer_no → deepstack_idx in __init__ for O(1) lookup.
Proposed optimization in __init__
 self.deepstack_visual_indexes = config.deepstack_visual_indexes
+self._deepstack_index_map = {
+    l_no: idx for idx, l_no in enumerate(config.deepstack_visual_indexes)
+}
Then use self._deepstack_index_map[l_no] instead of self.deepstack_visual_indexes.index(l_no).
3382cce to 2f55a42
Is there anyone who can help review this PR?
Maybe @yaoyu-33?
fbe1dec to e718101
/ok to test
@shifangx, there was an error processing your request: See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/1/
e9c4b6b to dd94a1d
8f55085 to 8a4d000
/ok to test 8a4d000
@tomlifu can you look at this PR since you are also working on qwen vl?
/ok to test e010b5f
e010b5f to b231ca2
/ok to test b231ca2
b231ca2 to 6ddef72
/ok to test 6ddef72
6ddef72 to b7293f6
/ok to test b7293f6
b7293f6 to 5bfd35f
/ok to test 5bfd35f
…init, rename input setter
5bfd35f to e3613f2
e3613f2 to d8d3fbd
/ok to test ac0dcf0
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
What does this PR do?
This PR adds DistTrain support for Qwen3-VL in Megatron-Bridge.
The vision and language submodules get separate process groups and communication grids, wired through a MultiModulePipelineCommunicator.
The following script is an example of training a proxy model with dist_train.
Changelog
GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Additional Information