
[main] feat(moe): Support apply wd to qk layernorm for Qwen3-Next (4/4) #2753

Merged
chtruong814 merged 3 commits into NVIDIA:main from yuzhongw-nvidia:qwen3next_wd
Jan 16, 2026

Conversation

@yuzhongw-nvidia

@yuzhongw-nvidia yuzhongw-nvidia commented Dec 24, 2025

Contributor

What does this PR do?

MR to dev.

Design doc

Qwen3-Next functionality PRs.

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers' reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@yuzhongw-nvidia yuzhongw-nvidia requested review from a team as code owners December 24, 2025 06:04
@copy-pr-bot

copy-pr-bot Bot commented Dec 24, 2025


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@yuzhongw-nvidia

Contributor Author

/ok to test 5e4fd18

@yuzhongw-nvidia

Contributor Author

/ok to test 5afcfdc

@yuzhongw-nvidia

Contributor Author

/ok to test 3a29a80

@yuzhongw-nvidia

Contributor Author

/ok to test a34da99

@yuzhongw-nvidia

Contributor Author

/ok to test 11c9076

@Phlip79

Phlip79 commented Jan 14, 2026

Member

/ok to test ac0fb11

)
combined_override[key] = value

# Overrides that force overrides.

Contributor

What does this mean?

end_wd: float
wd_mult: float

_force_override: bool = False

Contributor

Why does this have an underscore at the front?

Contributor

#2968 would get rid of this part of the logic, but the answer is that this current implementation adds wd overrides twice, once setting those qk_layernorm layers to wd=0, then later setting them to wd=1. Setting things twice breaks the old logic, so the work-around was to add a "force override" concept. The underscore at the beginning tells the optimizer loop to not include this key as a thing to override in the parameter group. I think the proposal I added in the PR above accomplishes the goal in a cleaner way.
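The work-around described above can be sketched as follows. This is a hypothetical illustration and not the code from this PR: the function name `combine_overrides` and the conflict check are assumptions, and only the `_force_override` flag and the `combined_override[key] = value` assignment are taken from the diff context. It shows why the flag carries a leading underscore: it steers the merge but is never copied into the resulting parameter-group settings.

```python
from typing import Any, Dict, List


def combine_overrides(overrides: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge, in order, the override dicts that match one parameter.

    Hypothetical sketch: keys starting with "_" are control flags and are
    never emitted as optimizer settings.
    """
    combined_override: Dict[str, Any] = {}
    for override in overrides:
        force = override.get("_force_override", False)
        for key, value in override.items():
            if key.startswith("_"):
                # Control flag: steers the merge, not an optimizer setting.
                continue
            already_set = key in combined_override
            if already_set and combined_override[key] != value and not force:
                # Without the force flag, setting a key twice is an error.
                raise ValueError(f"conflicting overrides for {key!r}")
            combined_override[key] = value
    return combined_override
```

Under these assumptions, merging `{"wd_mult": 0.0}` followed by `{"wd_mult": 1.0, "_force_override": True}` yields `{"wd_mult": 1.0}`, while the same pair without the flag raises.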

help='Dropout probability for hidden state transformer.')
group.add_argument('--weight-decay', type=float, default=0.01,
help='Weight decay coefficient for L2 regularization.')
group.add_argument('--apply-wd-to-qk-layernorm', action='store_true',

Contributor

I'm not sure I understand what this option does. In particular, "as a special case"? What's the general case then?

Contributor

Generally, all len==1 parameters (e.g., layernorm weights) or bias terms get added to the wd=0 group. This says "do not add the q or k layernorm weights to the wd=0 group, leave them as wd=1".
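A minimal sketch of the general case and the special case, in plain Python. It is an illustration only: `FakeParam`, `split_by_weight_decay`, and the substring match on `q_layernorm` / `k_layernorm` are assumptions made for the example, and "len==1" is read here as a one-dimensional shape.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class FakeParam:
    """Hypothetical stand-in for a tensor; only the shape matters here."""
    shape: Tuple[int, ...]


def split_by_weight_decay(
    named_params: Dict[str, FakeParam],
    apply_wd_to_qk_layernorm: bool = False,
) -> Tuple[List[str], List[str]]:
    """Return (decayed_names, no_decay_names)."""
    decayed: List[str] = []
    no_decay: List[str] = []
    for name, param in named_params.items():
        is_bias = name.endswith(".bias")
        # General case: biases and 1-D parameters go to the wd=0 group.
        skip_wd = is_bias or len(param.shape) == 1
        # Special case: q/k layernorm weights keep weight decay if requested.
        is_qk_ln_weight = (
            "q_layernorm" in name or "k_layernorm" in name
        ) and not is_bias
        if apply_wd_to_qk_layernorm and is_qk_ln_weight:
            skip_wd = False
        (no_decay if skip_wd else decayed).append(name)
    return decayed, no_decay
```

With the flag off, a `q_layernorm.weight` of shape `(128,)` lands in the no-decay list like any other layernorm weight. With the flag on, it moves to the decayed list while other layernorm weights and all biases stay where they were.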

Comment thread megatron/training/arguments.py Outdated
@deepakn94 deepakn94 requested a review from jstjohn January 15, 2026 05:41

@jstjohn jstjohn left a comment

Contributor

I'm nervous about how this approach of a force_override will interact with multiple different kinds of merging. For example, it seems to work well for the case of only overriding wd, but what if we also want to override LR (e.g., decoupled lr) and that has partial overlap with parameters that need to be wd unskipped?

Here is an alternative approach, please feel free to cherry pick the commit over: #2968, (e57b2f5)

The new design works by adding a new kind of predicate that handles the tuple of param,name. That is sufficient for modifying the weight decay skip rule with your filter for qk_layernorm. @FDecaYed came up with this idea which would have simplified his matching rule in megatron bridge for this same problem.
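The predicate idea can be sketched roughly as below. This is not the code from #2968: `Param`, `ParamPredicate`, `default_skip_wd`, `skip_wd_except_qk_layernorm`, and `build_wd_mults` are hypothetical names, and the name-based filter is an assumed way to recognize the q/k layernorm weights. The point is only that the skip rule becomes a single function of `(param, name)`, so no parameter has its weight decay set twice.

```python
from typing import Callable, Dict, NamedTuple, Tuple


class Param(NamedTuple):
    """Hypothetical stand-in for a tensor; only the shape is used."""
    shape: Tuple[int, ...]


# A predicate sees both the parameter and its name.
ParamPredicate = Callable[[Param, str], bool]


def default_skip_wd(param: Param, name: str) -> bool:
    """General rule: biases and 1-D parameters skip weight decay."""
    return name.endswith(".bias") or len(param.shape) == 1


def skip_wd_except_qk_layernorm(param: Param, name: str) -> bool:
    """Same rule, but q/k layernorm weights are never skipped."""
    is_qk_ln = "q_layernorm" in name or "k_layernorm" in name
    if is_qk_ln and not name.endswith(".bias"):
        return False
    return default_skip_wd(param, name)


def build_wd_mults(
    named_params: Dict[str, Param], skip_wd: ParamPredicate
) -> Dict[str, float]:
    """Assign wd_mult once per parameter: 0.0 if skipped, else 1.0."""
    return {
        name: 0.0 if skip_wd(param, name) else 1.0
        for name, param in named_params.items()
    }
```

Swapping `default_skip_wd` for `skip_wd_except_qk_layernorm` changes the multiplier only for the q/k layernorm weights, and each parameter is still assigned exactly once.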

yuzhongw-nvidia and others added 2 commits January 15, 2026 20:05
…upport Qwen3 weight decay

Signed-off-by: John St. John <jstjohn@nvidia.com>
@yuzhongw-nvidia

Contributor Author

I'm nervous about how this approach of a force_override will interact with multiple different kinds of merging. For example, it seems to work well for the case of only overriding wd, but what if we also want to override LR (e.g., decoupled lr) and that has partial overlap with parameters that need to be wd unskipped?

Here is an alternative approach, please feel free to cherry pick the commit over: #2968, (e57b2f5)

The new design works by adding a new kind of predicate that handles the tuple of param,name. That is sufficient for modifying the weight decay skip rule with your filter for qk_layernorm. @FDecaYed came up with this idea which would have simplified his matching rule in megatron bridge for this same problem.

Thanks @jstjohn and @FDecaYed for your help. Your implementation is much cleaner, so I cherry-picked your changes.

Hi @deepakn94, could you please take a look at the current version?

@yuzhongw-nvidia

Contributor Author

/ok to test fc64403

@jstjohn jstjohn left a comment

Contributor

Thank you!

@deepakn94 deepakn94 left a comment

Contributor

Looks great, thank you.

@deepakn94

Contributor

/ok to test 2ababd2

@chtruong814

Contributor

Fast merging since the functional tests on main were passing. We had some issues with newer tests we were onboarding. This one should have merged earlier.


Labels

  • complexity: medium
  • dev2main: mbridge (dev to main: this PR is needed in main for mbridge)
  • Expert Review ([deprecated] Apply this label to indicate that your PR is ready for expert review.)
  • Final Review (PR is in the "final review" stage)
  • module: moe

8 participants