
[bug] Qwen3.5 dense small models (2B/4B) crash due to shadow embedding breaking tied embeddings #3112

@jQizhang

Description


Problem

Summary

Qwen3.5 dense models with tie_word_embeddings=True (e.g. Qwen3.5-4B) crash during the forward pass:

```
AttributeError: 'function' object has no attribute 'word_embeddings'
```

Models with tie_word_embeddings=False (e.g. Qwen3.5-9B) work fine.

Root Cause

In Qwen3VLGPTModel.forward() (text_model.py:183-192), when both MTP and sequence_parallel are enabled, self.embedding is temporarily replaced with a plain function:

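```python
# Excerpt from Qwen3VLGPTModel.forward() (not standalone)
if self.mtp_process and self.config.sequence_parallel:
    _original_embedding = self.embedding

    def _sp_scatter_embedding(input_ids, position_ids):
        # Run the real embedding, then scatter its output along the sequence dimension
        out = _original_embedding(input_ids=input_ids, position_ids=position_ids)
        return tensor_parallel.scatter_to_sequence_parallel_region(out)

    # Writing to __dict__ bypasses nn.Module.__setattr__; the instance attribute
    # now shadows the registered "embedding" submodule.
    self.__dict__["embedding"] = _sp_scatter_embedding  # module → function
```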

Then _postprocess() is called while self.embedding is still the shadow function. For models with share_embeddings_and_output_weights=True, _postprocess calls shared_embedding_or_output_weight(), which accesses self.embedding.word_embeddings.weight. The function has no word_embeddings attribute, so the lookup raises AttributeError.

Both 4B and 9B have MTP enabled (mtp_num_hidden_layers=1), so the shadow is applied for both. The 9B model survives because share_embeddings_and_output_weights=False skips the problematic code path.
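The crash therefore needs three conditions at once. The helper below is a hypothetical sketch of that reasoning, not code from Megatron-Bridge or NeMo-RL. It assumes tie_word_embeddings maps onto share_embeddings_and_output_weights, treats MTP as enabled when mtp_num_hidden_layers > 0, and ignores pipeline-stage details such as which rank has mtp_process set.

```python
def hits_tied_shadow_bug(
    tie_word_embeddings: bool,
    mtp_num_hidden_layers: int,
    sequence_parallel: bool,
) -> bool:
    """Hypothetical helper: True when _postprocess would read
    word_embeddings through the shadow function."""
    # The shadow is installed only when MTP and sequence parallelism are both on.
    shadow_installed = mtp_num_hidden_layers > 0 and sequence_parallel
    # Only tied models reach shared_embedding_or_output_weight() in _postprocess.
    return shadow_installed and tie_word_embeddings


# Configurations from this report, assuming sequence_parallel=True as in the repro
print(hits_tied_shadow_bug(True, 1, True))   # Qwen3.5-4B: True
print(hits_tied_shadow_bug(False, 1, True))  # Qwen3.5-9B: False
```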

Stack Trace

```
text_model.py:194    → self._postprocess(...)
gpt_model.py:614     → output_weight = self.shared_embedding_or_output_weight()
language_module.py:321 → return self.embedding.word_embeddings.weight
                         AttributeError: 'function' object has no attribute 'word_embeddings'
```
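The call chain in the trace can be reproduced without Megatron or PyTorch. The sketch below is a minimal, hypothetical stand-in: `ModuleLike`, `Embedding`, `TiedModel` and `scatter_to_sp_region` are illustrative names, not real classes. `ModuleLike` mimics how `torch.nn.Module` resolves submodules (stored in `_modules`, reached only through `__getattr__`, which Python consults after the instance `__dict__`), so writing a function into `__dict__` shadows the registered embedding. The try/finally restore is an assumption of the sketch; the report only says the replacement is temporary.

```python
from types import SimpleNamespace


class ModuleLike:
    """Hypothetical stand-in for torch.nn.Module submodule lookup."""

    def __init__(self):
        self._modules = {}  # registered children, not in the instance __dict__

    def __getattr__(self, name):
        # Only called when normal lookup (instance __dict__, class) fails.
        modules = self.__dict__["_modules"]
        if name in modules:
            return modules[name]
        raise AttributeError(name)


class Embedding:
    """Hypothetical embedding: callable, with a tied word_embeddings.weight."""

    def __init__(self):
        self.word_embeddings = SimpleNamespace(weight="tied-weight")

    def __call__(self, input_ids, position_ids):
        return [f"emb({i},{p})" for i, p in zip(input_ids, position_ids)]


def scatter_to_sp_region(x):
    # Hypothetical stand-in for scatter_to_sequence_parallel_region: keep rank 0's half.
    return x[: len(x) // 2]


class TiedModel(ModuleLike):
    def __init__(self, apply_fix):
        super().__init__()
        self._modules["embedding"] = Embedding()
        self.apply_fix = apply_fix

    def shared_embedding_or_output_weight(self):
        # Mirrors language_module.py: assumes self.embedding is the real module.
        return self.embedding.word_embeddings.weight

    def forward(self, input_ids, position_ids):
        original = self.embedding  # resolved via __getattr__ -> _modules

        def sp_scatter_embedding(input_ids, position_ids):
            out = original(input_ids=input_ids, position_ids=position_ids)
            return scatter_to_sp_region(out)

        if self.apply_fix:
            sp_scatter_embedding.word_embeddings = original.word_embeddings
        self.__dict__["embedding"] = sp_scatter_embedding  # shadow the submodule
        try:
            hidden = self.embedding(input_ids, position_ids)
            # Like _postprocess, this runs while the shadow is still installed.
            return hidden, self.shared_embedding_or_output_weight()
        finally:
            del self.__dict__["embedding"]  # un-shadow


def run(apply_fix):
    try:
        return TiedModel(apply_fix).forward([1, 2, 3, 4], [0, 1, 2, 3])
    except AttributeError as exc:
        return str(exc)


print(run(apply_fix=False))  # 'function' object has no attribute 'word_embeddings'
print(run(apply_fix=True))   # (['emb(1,0)', 'emb(2,1)'], 'tied-weight')
```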

Suggested Fix

Preserve word_embeddings on the shadow function (one-line fix):

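```python
# Keep the tied embedding table reachable while the shadow is installed,
# so shared_embedding_or_output_weight() can still read word_embeddings.weight
_sp_scatter_embedding.word_embeddings = _original_embedding.word_embeddings
self.__dict__["embedding"] = _sp_scatter_embedding
```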
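The one-line fix forwards only word_embeddings; any other attribute of the embedding module read while the shadow is installed would fail the same way. A more general alternative, sketched below as a hypothetical (it is neither the fix proposed here nor Megatron-Bridge code), wraps the original embedding in a small callable whose `__getattr__` delegates missing attributes to the wrapped module. `SPScatterEmbedding`, `Embedding` and `scatter_to_sp_region` are illustrative names; the scatter function stands in for `tensor_parallel.scatter_to_sequence_parallel_region`.

```python
from types import SimpleNamespace


def scatter_to_sp_region(x):
    # Hypothetical stand-in for scatter_to_sequence_parallel_region: keep rank 0's half.
    return x[: len(x) // 2]


class SPScatterEmbedding:
    """Hypothetical callable wrapper: scatters the embedding output and
    delegates all other attribute reads to the wrapped embedding."""

    def __init__(self, embedding, scatter):
        self._embedding = embedding
        self._scatter = scatter

    def __call__(self, input_ids, position_ids):
        out = self._embedding(input_ids=input_ids, position_ids=position_ids)
        return self._scatter(out)

    def __getattr__(self, name):
        # Only reached when normal lookup fails (e.g. word_embeddings).
        # Read _embedding via __dict__ to avoid recursion before __init__ runs.
        try:
            embedding = self.__dict__["_embedding"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(embedding, name)


class Embedding:
    """Hypothetical embedding with a tied table and an extra attribute."""

    def __init__(self):
        self.word_embeddings = SimpleNamespace(weight="tied-weight")
        self.position_embeddings = None

    def __call__(self, input_ids, position_ids):
        return [f"emb({i},{p})" for i, p in zip(input_ids, position_ids)]


wrapped = SPScatterEmbedding(Embedding(), scatter_to_sp_region)
print(wrapped([1, 2, 3, 4], [0, 1, 2, 3]))  # ['emb(1,0)', 'emb(2,1)']
print(wrapped.word_embeddings.weight)       # tied-weight
```

In the model, such a wrapper would be installed through `self.__dict__["embedding"]` just like the function, so it is still not registered as a child module; the trade-off is one extra class in exchange for not having to enumerate which attributes downstream code reads.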

Affected Models

Any Qwen3.5 dense model where tie_word_embeddings=True and MTP is enabled, when run with sequence_parallel=True. Currently this includes Qwen3.5-4B and Qwen3.5-2B.

Minimal repro

Run GRPO training with `Qwen/Qwen3.5-4B-Base` using the Megatron-Bridge backend with tensor_parallel=4 and `sequence_parallel=True` in NeMo-RL.

Expected behavior

Qwen3.5 dense models with tie_word_embeddings=True (2B/4B) should run without error.

Affected area

area:model

Regression?

Not sure

Environment

No response

Logs

Metadata

Labels: bug (Something isn't working), needs-triage (New item needs classification and ownership), qa_rcca_done
