
μP: Maximal Update Parameterization #3058

Merged: BoxiangW merged 12 commits into NVIDIA:main from plugyawn:feature/mup on Feb 26, 2026


@plugyawn

@plugyawn plugyawn commented Jan 23, 2026

Contributor

What does this PR do ?

Adds support for Maximal Update Parameterization (μP) for optimal hyperparameter transfer across model widths.
Addresses issue #2824 opened by @sbhavani.

The idea is to train several deep, narrow models (i.e., reduced hidden_size) to find the optimal HPs, and then transfer those HPs to wide models (i.e., large hidden_size).

Also implemented are automatic initialization scaling (σ / √width_mult for hidden layers) and automatic LR scaling (lr / width_mult for hidden layers; Adam only, not SGD). Embedding/output layers use the base LR (no scaling), as in the original TP-V paper.
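A minimal sketch of the scaling rules described above, for illustration only. The class `MuPScaling` and its method names are hypothetical and are not the API added by this PR. Leaving the init of non-hidden layers unchanged is an assumption, since the description only specifies the hidden-layer rule.

```python
import math
from dataclasses import dataclass


@dataclass
class MuPScaling:
    """Hypothetical helper mirroring the rules in the PR description."""

    base_hidden_size: int
    hidden_size: int

    @property
    def width_mult(self) -> float:
        # Ratio of the target width to the base (proxy) width.
        return self.hidden_size / self.base_hidden_size

    def init_std(self, base_std: float, kind: str) -> float:
        # Hidden layers: sigma / sqrt(width_mult).
        # Assumption: other layers keep the base std in this sketch.
        if kind == "hidden":
            return base_std / math.sqrt(self.width_mult)
        return base_std

    def lr(self, base_lr: float, kind: str, optimizer: str = "adam") -> float:
        # Hidden layers under Adam: lr / width_mult.
        # Embedding/output layers, and SGD, keep the base LR.
        if kind == "hidden" and optimizer == "adam":
            return base_lr / self.width_mult
        return base_lr


scaling = MuPScaling(base_hidden_size=128, hidden_size=1024)
print(scaling.width_mult)                  # 8.0
print(scaling.lr(1e-3, "hidden"))          # base LR divided by 8
print(scaling.lr(1e-3, "embedding"))       # base LR, unscaled
print(scaling.init_std(0.02, "hidden"))    # 0.02 / sqrt(8)
```

With `base_hidden_size` equal to `hidden_size`, `width_mult` is 1 and every rule reduces to the unscaled base values, which is a convenient sanity check.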

References:

Tagging the @mcore-oncall

Functional tests and documentation in progress, unit tests added.
Some doubts: in param_and_grad_buffer, I added an is_embedding_parameter even though an is_embedding_or_output_parameter already exists. In TP-V, the fan-in of the output layer is interpreted as infinite-width, unlike the embedding layer, which has the fixed vocabulary (both according to the paper and mutransformers). There seem to be conflicting views on the case of tied embeddings (where the embedding and output layers share weights; see this discussion).

Some plots:

The following plots show the current behavior. In the first image, please note that MuP is on a different Y-axis scale than SP. In the second, I believe training longer would make it much clearer that MuP shares a single optimal LR across widths (it's currently at 500 steps). I only have access to an A100 at the moment, so these are character-level transformers trained on enwiki8.

image image image

Experiment details:

| Parameter | Default | Paper Reference / Notes |
| --- | --- | --- |
| `--widths` | 128,256,512,1024,2048,4096,8192 | MuP paper Fig. 1 |
| `--base-hidden-size` | 128 | Base model for width_mult |
| `--num-layers` | 4 | Transformer depth |
| `--lr-sweep-steps` | 500 | Steps per (width, LR) run |
| `--num-seeds` | 3 | - |
| `--seq-len` (for LR sweep) | 128 | - |
| `--batch_size` | 4 | - |
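As a rough sketch, the sweep defined by the defaults above can be enumerated as one run per (width, LR, seed) combination. The LR grid here is an assumption (11 power-of-two values from 2^-14 to 2^-4); the actual grid used for the plots is not stated in this description.

```python
from itertools import product

# Defaults taken from the experiment table.
widths = [128, 256, 512, 1024, 2048, 4096, 8192]
base_hidden_size = 128
num_seeds = 3

# Assumption: a hypothetical power-of-two LR grid, 2^-14 ... 2^-4.
learning_rates = [2.0 ** exponent for exponent in range(-14, -3)]

runs = [
    {
        "hidden_size": width,
        "width_mult": width / base_hidden_size,
        "lr": lr,
        "seed": seed,
    }
    for width, lr, seed in product(widths, learning_rates, range(num_seeds))
]

print(len(runs))  # widths x LRs x seeds
print(sorted({run["width_mult"] for run in runs}))
```

Each run would then be trained for `--lr-sweep-steps` steps, and the final loss averaged over seeds for every (width, LR) pair.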

The plotting code is available on the feature/mup-implementation branch of my fork.

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact the @mcore-oncall.

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests [incoming]
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers' reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@plugyawn plugyawn requested review from a team as code owners January 23, 2026 20:25
@copy-pr-bot

copy-pr-bot Bot commented Jan 23, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@ko3n1g ko3n1g requested a review from a team January 23, 2026 20:25
@plugyawn

Contributor Author

Note that this is just width-MuP (the original paper). There's also a new depth-MuP (which would help with not having to train "skinny" models with low-width/high-depth for transfer). A new paper, Complete(d)-P also exists, that I've not entirely gone through.

@plugyawn plugyawn changed the title μP: Maximal Update Parameterization [Draft] μP: Maximal Update Parameterization Jan 23, 2026
@plugyawn

Contributor Author

Hi! @sbhavani, could you take a look?

@sbhavani

Contributor

@plugyawn thanks for the contribution! Please bear with us as this will take some time to review since it touches a lot of areas in core

@plugyawn

plugyawn commented Jan 26, 2026

Contributor Author

Thank you, @sbhavani and the team!

@KyroChi KyroChi left a comment


Overall I think this is a good start. I highly recommend you put these changes through Claude or ChatGPT to fix several docstring inconsistencies where the content of the docstring does not match the actual behavior of the code.

Some other things that may be problems:

  1. You don't actually plot the minima in the figure of val loss vs lr on the right. Sometimes if not properly implemented you can in fact see the optimum shifting, but we won't be able to tell from this plot since we don't see the minima.
  2. You should plot normalized logits in the coordinate checks since unnormalized logits can hide subtle bugs, like an m^{1/4} dependency or something. When normalized we expect roughly horizontal lines.
  3. The output multipliers are not automatically set. This seems potentially dangerous, as the default behavior will have outputs which are m times larger than we would expect for muP. I left a comment about this. I want to make sure it was set for your experiments that you plot above. I couldn't confirm from my brief perusal of your fork.
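Point 2 above can be made quantitative with a small sketch: normalize a coordinate-check statistic by its base-width value, then fit the slope of log(statistic) against log(width). A slope near 0 corresponds to the "roughly horizontal lines" expected under muP, while a residual dependency such as m^{1/4} shows up as a slope near 0.25. The data below is synthetic, the helper names are hypothetical, and reading "normalized" as "divided by the base-width value" is an assumption.

```python
import math


def normalize_to_base(values):
    """Divide every entry by the first (base-width) entry."""
    base = values[0]
    return [value / base for value in values]


def fit_width_exponent(widths, values):
    """Least-squares slope of log(value) against log(width)."""
    xs = [math.log(width) for width in widths]
    ys = [math.log(value) for value in values]
    x_mean = sum(xs) / len(xs)
    y_mean = sum(ys) / len(ys)
    covariance = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
    variance = sum((x - x_mean) ** 2 for x in xs)
    return covariance / variance


widths = [128, 256, 512, 1024, 2048]

# Synthetic statistics: one with a hidden m^(1/4) dependency, one flat.
leaky = [0.5 * (width / 128) ** 0.25 for width in widths]
flat = [0.5 for _ in widths]

print(fit_width_exponent(widths, normalize_to_base(leaky)))  # approximately 0.25
print(fit_width_exponent(widths, normalize_to_base(flat)))   # 0.0
```

On real coordinate-check data the statistic would be something like the mean absolute logit at a fixed training step, measured once per width.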

@Skylion007 🫡

Comment thread megatron/core/transformer/transformer_config.py
@plugyawn

plugyawn commented Jan 29, 2026

Contributor Author
image

> You don't actually plot the minima in the figure of val loss vs lr on the right. Sometimes if not properly implemented you can in fact see the optimum shifting, but we won't be able to tell from this plot since we don't see the minima.

Plotted with the optimum marked. It's on a log scale, so the shifted minima are:

| Width | SP opt LR | MuP opt LR |
| --- | --- | --- |
| 128 | 9.77e-04 | 9.77e-04 |
| 256 | 4.88e-04 | 9.77e-04 |
| 512 | 2.44e-04 | 9.77e-04 |
| 1024 | 1.22e-04 | 9.77e-04 |
| 2048 | 6.10e-05 | 1.95e-03 |

Also, I used 16 LRs instead of 11, since the minima weren't clear.
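A quick arithmetic check on the numbers above: the SP optimum halves every time the width doubles, so multiplying it by width_mult (width / 128) should recover a roughly constant value. The values are copied from the table; the variable names are just for this sketch.

```python
base_width = 128

# (width, SP optimal LR, MuP optimal LR), copied from the table above.
rows = [
    (128, 9.77e-04, 9.77e-04),
    (256, 4.88e-04, 9.77e-04),
    (512, 2.44e-04, 9.77e-04),
    (1024, 1.22e-04, 9.77e-04),
    (2048, 6.10e-05, 1.95e-03),
]

# Rescale the SP optimum by width_mult; under a 1/width shift this is constant.
sp_rescaled = [sp_lr * (width / base_width) for width, sp_lr, _ in rows]

for (width, sp_lr, mup_lr), rescaled in zip(rows, sp_rescaled):
    print(f"width={width:5d}  SP*width_mult={rescaled:.2e}  MuP={mup_lr:.2e}")
```

The rescaled SP column stays within rounding of the base-width optimum, which matches the lr / width_mult rule applied to hidden layers. The MuP column is constant except at width 2048, where the optimum moves by one grid step.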

The shift at 2048 is worrying... but the minimum loss looks close.

Getting back with coordinate checks soon.

@chtruong814 chtruong814 removed the needs-follow-up Issue needs follow-up label Jan 29, 2026
@plugyawn

plugyawn commented Jan 29, 2026

Contributor Author
image Normalized logits. T is gradient steps after training.

@KyroChi

KyroChi commented Jan 29, 2026


I don't think that the optimum shift for the 2048 model is concerning for two reasons:

  1. The loss is almost identical, we expect mup to only hold on average, which means that occasionally the empirical optimum will shift a little bit between runs. The curvature of the lr vs. loss curves usually decreases as the model size increases which only adds to this issue.
  2. The 2048 models are almost certainly undertrained which will ALWAYS favor a larger learning rate.

Regarding this latter point, transfer is really only expected to occur at fixed TPP or something, but due to the asymptotic properties of mup we can usually just get away with optimally training our largest model and overtraining smaller models to demonstrate mu-transfer. You see only 4 × 128 × 500 = 256,000 tokens during training, which is ~1 TPP for the smallest model and ~0.003 TPP for the largest model 😝 In my experience you usually need at least 2TPP to get good transfer plots, so in some sense this is already better than I would expect!
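The token budget and tokens-per-parameter (TPP) figures can be sketched as below. The token count follows directly from batch size × sequence length × steps. The parameter count uses the common 12 · L · d² approximation for a transformer's non-embedding weights, which is an assumption of this sketch and not necessarily how the figures above were derived, so the results agree with them only to within an order of magnitude.

```python
def tokens_seen(batch_size: int, seq_len: int, steps: int) -> int:
    """Total tokens processed during one training run."""
    return batch_size * seq_len * steps


def approx_params(num_layers: int, hidden_size: int) -> int:
    # Assumption: ~12 * L * d^2 non-embedding parameters, embeddings ignored.
    return 12 * num_layers * hidden_size ** 2


tokens = tokens_seen(batch_size=4, seq_len=128, steps=500)
print(tokens)  # 256000

for hidden_size in (128, 2048):
    tpp = tokens / approx_params(num_layers=4, hidden_size=hidden_size)
    print(f"hidden_size={hidden_size}: roughly {tpp:.4f} tokens per parameter")
```

Because the parameter count grows with the square of the width while the token budget is fixed, TPP drops by a factor of 256 between widths 128 and 2048.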

Regarding the residual power dependency in the logits plot: this could be because Megatron's default optimizer Adam eps is 10^{-8}, which is actually rather high for mup. Since your models are pretty small you can probably get away with setting this to 10^{-12} or even 10^{-15} and see if the coordinate check flattens. This is a quirk of all mup implementations unfortunately and is unlikely to indicate a bug IMO. See also this paper.
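To illustrate why eps matters here: for a gradient g that stays constant across steps, Adam's bias-corrected moments are m = g and v = g², so the update magnitude relative to the LR is |g| / (|g| + eps). Once |g| is comparable to eps, the update is visibly damped. The gradient magnitudes below are illustrative assumptions, not measurements from these runs, and `adam_step_fraction` is a hypothetical helper.

```python
def adam_step_fraction(grad_magnitude: float, eps: float) -> float:
    """|update| / lr for Adam with a steady gradient: |g| / (sqrt(g^2) + eps)."""
    return grad_magnitude / (grad_magnitude + eps)


# Hypothetical per-coordinate gradient magnitudes, shrinking as width grows.
grad_magnitudes = [1e-4, 1e-6, 1e-8]

for eps in (1e-8, 1e-12):
    fractions = [adam_step_fraction(g, eps) for g in grad_magnitudes]
    formatted = ", ".join(f"{fraction:.4f}" for fraction in fractions)
    print(f"eps={eps:.0e}: {formatted}")
```

With eps = 1e-8 the smallest gradient in this example gets only half of the nominal step, whereas with eps = 1e-12 all three are essentially undamped. A width-dependent damping of this kind is one way a residual power dependency could appear in a coordinate check.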

@BoxiangW

Contributor

Hi @plugyawn, thanks for your contribution here. I was actually implementing this on my branch, but since you have all the experiment results ready, we can try to merge your PR!

Comment thread megatron/core/optimizer/__init__.py Outdated
Comment thread megatron/core/transformer/transformer_config.py
Comment thread megatron/core/transformer/transformer_config.py Outdated
@plugyawn

Contributor Author

> I don't think that the optimum shift for the 2048 model is concerning for two reasons:
>
>   1. The loss is almost identical, we expect mup to only hold on average, which means that occasionally the empirical optimum will shift a little bit between runs. The curvature of the lr vs. loss curves usually decreases as the model size increases which only adds to this issue.
>   2. The 2048 models are almost certainly undertrained which will ALWAYS favor a larger learning rate.

That makes sense! Thank you!

> Regarding this latter point, transfer is really only expected to occur at fixed TPP or something, but due to the asymptotic properties of mup we can usually just get away with optimally training our largest model and overtraining smaller models to demonstrate mu-transfer. You see only 4 × 128 × 500 = 256,000 tokens during training, which is ~1 TPP for the smallest model and ~0.003 TPP for the largest model 😝 In my experience you usually need at least 2TPP to get good transfer plots, so in some sense this is already better than I would expect!

I did not know we had good heuristics for when transfer happens! That makes sense!
I had some intuition, of course, that it must take some time to saturate, but I thought 256,000 should be close to enough... given the smallness of the vocabulary?

I was also reminded of https://arxiv.org/abs/2501.16975 and their scaling laws over vocabulary (min. loss decreases with log vocab), although that might be unrelated.

> Regarding the residual power dependency in the logits plot: this could be because Megatron's default optimizer Adam eps is 10^{-8}, which is actually rather high for mup. Since your models are pretty small you can probably get away with setting this to 10^{-12} or even 10^{-15} and see if the coordinate check flattens. This is a quirk of all mup implementations unfortunately and is unlikely to indicate a bug IMO. See also this paper.

That makes sense!

Hahaha and also thanks this sent me into a rabbit hole, learned quite a few things!

Comment thread megatron/core/transformer/transformer_config.py
@chtruong814 chtruong814 added the needs-follow-up Issue needs follow-up label Feb 1, 2026
@janEbert

Contributor

Thank you for the endless endurance and spirit! Let's get this merged. :)

@janEbert

Contributor

/ok to test a66ae5b

@janEbert

Contributor

Think you just need to rebase onto main and apply tools/autoformat.sh.

@plugyawn

Contributor Author

Reran the plots for SGD as well, to be sure:
image

Autoformat's done, too!

> Thank you for the endless endurance and spirit! Let's get this merged. :)

It was very fun!

@BoxiangW BoxiangW enabled auto-merge February 26, 2026 18:25
@BoxiangW

Contributor

/ok to test b1cd2d8

@BoxiangW

Contributor

Just triggered one last CI run; it will be merged if everything passes. Thanks @plugyawn!

@svcnvidia-nemo-ci


🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/22463189540

Merged via the queue into NVIDIA:main with commit 310082a Feb 26, 2026
51 checks passed
BoxiangW pushed a commit to BoxiangW/Megatron-LM that referenced this pull request Mar 4, 2026
mehraakash added a commit to mehraakash/Megatron-Bridge that referenced this pull request Mar 5, 2026
…-NeMo#3058)

Apply per-parameter-class LR/eps scaling in setup_optimizer when
use_mup=True on the model config. Mirrors the get_mup_config_overrides
call added to MCore's setup_model_and_optimizer in NVIDIA/Megatron-LM#3058.

The μP config fields (use_mup, mup_base_hidden_size, mup_width_mult, etc.)
are already present via MCoreTransformerConfig inheritance — no model config
changes needed.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
mehraakash added a commit to NVIDIA-NeMo/Megatron-Bridge that referenced this pull request Mar 10, 2026
Apply per-parameter-class LR/eps scaling in setup_optimizer when
use_mup=True on the model config. Mirrors the get_mup_config_overrides
call added to MCore's setup_model_and_optimizer in NVIDIA/Megatron-LM#3058.

The μP config fields (use_mup, mup_base_hidden_size, mup_width_mult, etc.)
are already present via MCoreTransformerConfig inheritance — no model config
changes needed.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
ilml added a commit to ilml/Megatron-LM that referenced this pull request Mar 20, 2026
…A#3058)

New files:
  - tests/unit_tests/transformer/test_mup.py
ilml added a commit to ilml/Megatron-LM that referenced this pull request Mar 20, 2026
These test files import from existing modules that are modified in Phase 2:
- test_rmsnorm_residual_fusion.py: imports TEFusedResidualRMSNorm (added in NVIDIA#3384)
- test_mup.py: imports get_mup_config_overrides (added in NVIDIA#3058)
- test_multimodule_schedules.py: imports MultiModuleProcessGroupCollection (added in NVIDIA#3129)

They will be re-added in Phase 2 when the corresponding code changes land.

Made-with: Cursor
yangbofun pushed a commit to xlm-research/Megatron-LM that referenced this pull request May 22, 2026

Labels

community-request Final Review PR is in the "final review" stage
