[PEFT, ckpts] feat: modelopt for LoRA & deepseek arch #3612

Merged
cuichenx merged 4 commits into NVIDIA-NeMo:main from HollowMan6:modelopt
May 5, 2026

@HollowMan6

Member

What does this PR do?

Because the ModelOpt MoE spec disables moe_grouped_gemm (see https://github.com/NVIDIA/Megatron-LM/blob/12f18dafbf9ea1a947f06c7aecde0208c0ada161/megatron/core/post_training/modelopt/gpt/model_specs.py#L146), additional parameter mappings are needed here. ModelOpt linear layers should also be recognized correctly for LoRA.

Changelog

  • Add the additional mappings needed for the DeepSeek architecture when moe_grouped_gemm is disabled.
  • Support wrapping ModelOpt linear layers for LoRA.

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items you can still open "Draft" PR.

Additional Information

  • Related to # (issue)

Copilot AI review requested due to automatic review settings April 30, 2026 21:13
@copy-pr-bot

copy-pr-bot Bot commented Apr 30, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@HollowMan6

Member Author

/claude review

Copilot AI left a comment

Pull request overview

This PR extends Megatron-Bridge’s PEFT (LoRA) and DeepSeek checkpoint-conversion support to better handle ModelOpt and DeepSeek MoE naming differences (notably when moe_grouped_gemm is disabled).

Changes:

  • Add detection + adapter-attribute handling for ModelOpt’s local Megatron Linear so it can be LoRA-wrapped correctly.
  • Route ModelOpt Linear modules away from the nn.Linear fast-path in both LoRA and CanonicalLoRA transforms.
  • Extend DeepSeek parameter mappings to cover local_experts naming and fix MTP mapping wildcard replacement to only target the intended wildcard groups.
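
The routing described in the second bullet can be pictured with a minimal sketch. It assumes that ModelOpt's Linear is a subclass of nn.Linear, which would explain why it would otherwise fall into the nn.Linear fast-path. `BaseLinear`, `ModelOptLinear`, and `pick_adapter_path` are stand-ins invented here so the snippet runs without PyTorch or Megatron; they are not names from the PR.

```python
class BaseLinear:
    """Stand-in for torch.nn.Linear (assumption: illustration only)."""


class ModelOptLinear(BaseLinear):
    """Stand-in for ModelOpt's local Megatron Linear (assumed to subclass nn.Linear)."""


def pick_adapter_path(module: object) -> str:
    """Hypothetical dispatcher: test the more specific class first."""
    # The subclass check must come before the base-class check,
    # otherwise every ModelOptLinear would match the fast-path branch.
    if isinstance(module, ModelOptLinear):
        return "modelopt"
    if isinstance(module, BaseLinear):
        return "nn_linear_fast_path"
    return "other"


print(pick_adapter_path(ModelOptLinear()))
print(pick_adapter_path(BaseLinear()))
```

The order of the two `isinstance` checks is the whole point: swapping them would send the subclass down the base-class branch.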

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| src/megatron/bridge/peft/utils.py | Adds is_modelopt_linear() and a ModelOpt-specific AdapterAttributes return path. |
| src/megatron/bridge/peft/lora.py | Ensures ModelOpt Linear does not go through the nn.Linear/TE adapter path. |
| src/megatron/bridge/peft/canonical_lora.py | Same exclusion for CanonicalLoRA's nn.Linear fast-path. |
| src/megatron/bridge/models/deepseek/common.py | Adds DeepSeek local_experts mappings and corrects MTP wildcard replacement behavior. |
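
To illustrate the kind of issue the wildcard fix in common.py addresses, here is a small sketch of replacing only one wildcard group in a dotted parameter pattern. The helper `fill_wildcard` and the example pattern string are hypothetical and written for this illustration; they are not the PR's actual code or mapping names.

```python
def fill_wildcard(pattern: str, position: int, value: str) -> str:
    """Replace only the `position`-th '*' segment (0-based) of a dotted pattern."""
    parts = pattern.split(".")
    seen = 0
    for i, part in enumerate(parts):
        if part == "*":
            if seen == position:
                parts[i] = value
                return ".".join(parts)
            seen += 1
    raise ValueError(f"pattern has no wildcard at position {position}: {pattern!r}")


# Hypothetical pattern with two wildcard groups: a layer index and an expert index.
pattern = "mtp.layers.*.mlp.experts.local_experts.*.linear_fc1.weight"

# A blanket str.replace fills both groups with the same value.
print(pattern.replace("*", "0"))

# The targeted helper fills the layer index and leaves the expert wildcard intact.
print(fill_wildcard(pattern, 0, "0"))
```

With local_experts naming, a pattern can carry a second wildcard for the expert index, so a blanket replacement that was harmless for single-wildcard patterns would start overwriting the wrong group.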

Comment thread src/megatron/bridge/peft/utils.py
Comment thread src/megatron/bridge/peft/lora.py
Comment thread src/megatron/bridge/peft/canonical_lora.py
@cuichenx

cuichenx commented May 1, 2026

Contributor

/claude review #3612

Comment thread src/megatron/bridge/peft/utils.py Outdated
def is_modelopt_linear(m: nn.Module) -> bool:
    """Return whether a module is ModelOpt's local Megatron Linear."""
    cls = type(m)
    return cls.__name__ == "Linear" and cls.__module__ == "megatron.core.post_training.modelopt.layers"

Contributor

This looks a bit fragile.

You could put the import inside the check, or guard it at module level:

from megatron.core.post_training.modelopt.layers import Linear as ModelOptLinear

def is_modelopt_linear(m: nn.Module) -> bool:
    return isinstance(m, ModelOptLinear)
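
One way the two checks can diverge is with subclasses; the reviewer does not spell out the concern, so this is only one possible reading. The sketch below uses stand-in classes (`Linear`, `QuantLinear`) and hypothetical function names so that it runs without Megatron installed.

```python
class Linear:
    """Stand-in for the ModelOpt Linear class (illustration only)."""


class QuantLinear(Linear):
    """Hypothetical subclass, e.g. a wrapped or converted variant."""


def matches_by_name(m: object) -> bool:
    # Mirrors the string comparison: only the exact class name and module match.
    cls = type(m)
    return cls.__name__ == "Linear" and cls.__module__ == Linear.__module__


def matches_by_isinstance(m: object) -> bool:
    # isinstance follows the inheritance chain, so subclasses match too.
    return isinstance(m, Linear)


print(matches_by_name(QuantLinear()), matches_by_isinstance(QuantLinear()))
```

The name-based version also breaks silently if the class is renamed or moved, whereas a failed import is visible immediately.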

Member Author

Thanks for reviewing. I just changed this to use safe_import_from, similar to the imports at the top of the file (e.g. TEColumnParallelLinear).
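
A guarded import of this kind might look roughly like the sketch below. `try_import_from` is a hypothetical stand-in written with `importlib`; it is not the real `safe_import_from`, whose exact signature and return value are not shown in this thread. The module path is the one quoted in the review above.

```python
import importlib
from typing import Any, Optional, Tuple


def try_import_from(module_name: str, symbol: str) -> Tuple[Optional[Any], bool]:
    """Hypothetical helper: return (symbol, True) if importable, else (None, False)."""
    try:
        module = importlib.import_module(module_name)
        return getattr(module, symbol), True
    except (ImportError, AttributeError):
        return None, False


# Resolved once at import time; stays (None, False) when Megatron is not installed.
ModelOptLinear, HAVE_MODELOPT_LINEAR = try_import_from(
    "megatron.core.post_training.modelopt.layers", "Linear"
)


def is_modelopt_linear(m: object) -> bool:
    """Return whether a module is ModelOpt's local Megatron Linear."""
    # Short-circuits to False when the optional dependency is missing.
    return HAVE_MODELOPT_LINEAR and isinstance(m, ModelOptLinear)
```

This keeps the `isinstance` check the reviewer asked for while letting the file import cleanly in environments where the optional ModelOpt layers are unavailable.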

@yaoyu-33 yaoyu-33 added the waiting-on-customer Waiting on the original author to respond label May 3, 2026
@HollowMan6 HollowMan6 requested a review from yaoyu-33 May 3, 2026 07:58
HollowMan6 added 3 commits May 4, 2026 22:54
As MoE with modelopt sets `moe_grouped_gemm` disabled,
https://github.com/NVIDIA/Megatron-LM/blob/12f18dafbf9ea1a947f06c7aecde0208c0ada161/megatron/core/post_training/modelopt/gpt/model_specs.py#L146
additional mappings are needed here.

Also, modelopt linear layers should be correctly
recognized for lora.

Signed-off-by: Hollow Man <hollowman@opensuse.org>
Signed-off-by: Hollow Man <hollowman@opensuse.org>
Signed-off-by: Hollow Man <hollowman@opensuse.org>
@HollowMan6

Member Author

/ok to test dfd1c12

Signed-off-by: Hollow Man <hollowman@opensuse.org>

@cuichenx cuichenx left a comment

Contributor

LGTM

@cuichenx cuichenx merged commit f6041a2 into NVIDIA-NeMo:main May 5, 2026
86 checks passed
@HollowMan6 HollowMan6 deleted the modelopt branch May 5, 2026 18:45
gautham-kollu pushed a commit that referenced this pull request May 12, 2026
Signed-off-by: Hollow Man <hollowman@opensuse.org>
vasunvidia pushed a commit to vasunvidia/Megatron-Bridge that referenced this pull request Jun 10, 2026
Signed-off-by: Hollow Man <hollowman@opensuse.org>
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Labels

waiting-on-customer Waiting on the original author to respond

4 participants