Adding code for Flextron #4429

Merged
sheliang-nv merged 38 commits into NVIDIA:main from sheliang-nv:shel/flex_merge
May 1, 2026

Conversation

@sheliang-nv

Contributor

What does this PR do?

This PR lands Flextron (also known as Nemotron Elastic / Star Elastic) into Megatron-LM. Flextron is a post-training method that converts a single parent LLM into a nested family of submodels at different parameter budgets — all produced from one training run, all sharing a single checkpoint. A learnable router maps a user-specified budget to per-axis architectural decisions (embedding width, attention heads, Mamba heads, MoE experts, FFN channels); smaller submodels are strict subsets of larger ones via importance-ranked contiguous slicing, and all variants are trained jointly with knowledge distillation from the frozen parent.
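
The nesting property described above can be illustrated with a minimal sketch. It assumes per-channel importance scores are already available (how Flextron computes them is not covered in this description), and the helper names `rank_by_importance` and `slice_for_budget` are hypothetical, not part of `megatron/elastification/`.

```python
from typing import List, Sequence


def rank_by_importance(scores: Sequence[float]) -> List[int]:
    """Return channel indices ordered from most to least important."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)


def slice_for_budget(ranked: Sequence[int], keep: int) -> List[int]:
    """Keep the first `keep` entries of the ranked order (a contiguous prefix)."""
    if not 0 <= keep <= len(ranked):
        raise ValueError("keep must be between 0 and the number of channels")
    return list(ranked[:keep])


# Hypothetical importance scores for four FFN channels.
scores = [0.2, 0.9, 0.1, 0.7]
ranked = rank_by_importance(scores)

small = slice_for_budget(ranked, keep=2)
large = slice_for_budget(ranked, keep=3)

# Both are prefixes of the same ranking, so the smaller submodel's
# channels are always contained in the larger one's.
assert large[: len(small)] == small
```

If the parent weights are permuted once by this ranking, each budget then corresponds to a contiguous slice of the same tensor, which is what lets all variants share one checkpoint.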

Flextron has been used to produce the elastic variants shipped with Nemotron Nano v2 (12B → 9B + 6B) and Nemotron Nano v3 (30B/3.6A MoE → 23B/2.8A + 12B/2.0A). Until now the implementation has lived on private dev branches. This PR consolidates that work into main so it can be open-sourced and maintained alongside the rest of the Megatron-LM post-training surface.

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact the @mcore-oncall.

Files at a glance

  • megatron/elastification/ — new module (manager, hooks, router, budget math, config).
  • pretrain_mamba_flex.py — training entry point with per-microbatch budget sampling.
  • megatron/core/distributed/finalize_model_grads.py — all-reduces router grads across PP ranks, gated on config.flextron.
  • megatron/post_training/model_builder.py — teacher-config overrides so KD teachers don't carry the router.
  • tests/unit_tests/elastification/ — 10 test files.
  • tests/functional_tests/test_cases/hybrid/hybrid_flextron_nightly_*/ + tests/test_utils/recipes/h100/flextron.yaml — nightly functional test.
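
As a rough illustration of per-microbatch budget sampling, the sketch below draws one budget per microbatch. The budget values, the uniform distribution, and the function name `sample_budgets` are assumptions made for illustration; the actual logic in `pretrain_mamba_flex.py` may differ.

```python
import random
from typing import List, Sequence


def sample_budgets(
    budgets: Sequence[float], num_microbatches: int, rng: random.Random
) -> List[float]:
    """Draw one budget per microbatch, uniformly at random (an assumption)."""
    if not budgets:
        raise ValueError("at least one budget is required")
    return [rng.choice(budgets) for _ in range(num_microbatches)]


# Hypothetical budgets, expressed as fractions of the parent's parameter count.
budgets = [1.0, 0.75, 0.5]
per_microbatch = sample_budgets(budgets, num_microbatches=8, rng=random.Random(0))
```

Passing an explicit `random.Random` keeps the draw reproducible, which matters if every rank is expected to see the same budget for a given microbatch.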

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

@sheliang-nv sheliang-nv requested review from a team as code owners April 22, 2026 17:00
@copy-pr-bot

copy-pr-bot Bot commented Apr 22, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@svcnvidia-nemo-ci svcnvidia-nemo-ci marked this pull request as draft April 22, 2026 17:00
@github-actions

Contributor

This PR has been automatically converted to draft because all PRs must start as drafts.

When you are ready for review, click Ready for Review to begin the review process. This will:

  1. Add the oncall reviewer (optional reviewer)
  2. Add required review teams based on your changes

See the contribution guide for more details.

@sheliang-nv sheliang-nv marked this pull request as ready for review April 22, 2026 17:08
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team April 22, 2026 17:09

@Phlip79 Phlip79 left a comment

Member

MambaModel has been renamed to HybridModel as of #4099. Can you please update this PR accordingly?

@sheliang-nv

Contributor Author

/ok to test 892ca2f

@svcnvidia-nemo-ci svcnvidia-nemo-ci added this to the Core 0.16 milestone Apr 22, 2026
@sheliang-nv

Contributor Author

/ok to test 8b9ffae

@sheliang-nv

Contributor Author

/ok to test 044965d

@Phlip79 Phlip79 removed the request for review from a team April 23, 2026 00:37

@ChenhanYu ChenhanYu left a comment

Contributor

@kevalmorabia97 and @AAnoosheh to review.

# Mamba
def mamba_params(mamba_nheads):
    d_inner = mamba_nheads * mamba_d_head
    ngroups = 8

Contributor

Bug: ngroups is hardcoded to 8, but the actual model uses config.mamba_num_groups which is configurable. If someone sets mamba_num_groups to a value other than 8, the parameter count estimation (and therefore the budget loss) will be silently wrong.

This should be passed in as a parameter from the caller, which has access to config.mamba_num_groups. The same hardcoded value appears in the mamba_in_proj computation on line 121.

Suggested change
- ngroups = 8
+ ngroups = 8  # TODO: pass mamba_num_groups from config instead of hardcoding
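
One way to act on this comment is to thread the group count through as an argument. The sketch below is not the PR's code: it counts only the weights of the Mamba-2 input projection, using the output width from the reference Mamba-2 layout (`2 * d_inner + 2 * ngroups * d_state + nheads`), and every name other than `mamba_num_groups` is a placeholder.

```python
def mamba_in_proj_params(
    hidden_size: int,
    mamba_nheads: int,
    mamba_d_head: int,
    mamba_d_state: int,
    mamba_num_groups: int,
) -> int:
    """Weight count of the Mamba-2 input projection (bias excluded)."""
    d_inner = mamba_nheads * mamba_d_head
    # The output packs z and x (d_inner each), B and C (ngroups * d_state each),
    # and dt (one value per head).
    out_features = 2 * d_inner + 2 * mamba_num_groups * mamba_d_state + mamba_nheads
    return hidden_size * out_features


# Placeholder sizes; only the group count changes between the two calls.
with_8_groups = mamba_in_proj_params(1024, 32, 64, 128, mamba_num_groups=8)
with_1_group = mamba_in_proj_params(1024, 32, 64, 128, mamba_num_groups=1)
```

With these placeholder sizes the two estimates differ by `1024 * 2 * 7 * 128` weights, which is the kind of drift a hardcoded group count would hide from the budget loss.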


import random

import numpy as np

Contributor

Nit: numpy is imported but never used in this file.

Suggested change
- import numpy as np

Comment on lines +278 to +314
def _allreduce_router_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce router grads.

    Reduce grads across all the pp stages to ensure that parameters of the router stay in sync.
    """

    if parallel_state.get_pipeline_model_parallel_world_size() > 1:
        grads_dict: Dict[str, List[torch.Tensor]] = {}
        for model_chunk in model:
            for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
                if param.requires_grad and getattr(param, 'flextron_router_pp_sync', False):
                    grad = param.main_grad
                    if name in grads_dict:
                        # Add all the virtual PP rank's gradients to
                        # the first local virtual PP rank.
                        grads_dict[name][0].add_(grad)
                        # Append to the end for later update after cross-rank reduce.
                        grads_dict[name].append(grad)
                    else:
                        grads_dict[name] = [grad]

        if grads_dict:
            # All-reduce the gradient on the first VPP rank.
            grads = [param_grad[0] for _, param_grad in grads_dict.items()]
            coalesced = _flatten_dense_tensors(grads)
            torch.distributed.all_reduce(
                coalesced, group=parallel_state.get_pipeline_model_parallel_group()
            )
            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)

            # Update the gradients on other VPP ranks.
            for grads in grads_dict.values():
                for grad in grads[1:]:
                    grad.copy_(grads[0])

Contributor

This new function modifies a core distributed file but has no unit test coverage. A test verifying the all-reduce behavior (especially the VPP gradient aggregation logic in lines 291-313) would help prevent regressions, since bugs here would silently produce incorrect router gradients across pipeline stages.
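
A test along these lines could start from a single-process model of the aggregation. The sketch below is an assumption about how such a test might be structured, not code from the PR: it uses plain Python lists in place of `main_grad` tensors and a hypothetical `fake_all_reduce` in place of `torch.distributed.all_reduce`, so it checks only the bookkeeping (sum local VPP chunks, reduce across ranks, copy back).

```python
from typing import Dict, List

Grad = List[float]


def collect_local(chunks: List[Dict[str, Grad]]) -> Dict[str, List[Grad]]:
    """Sum same-named grads from later VPP chunks into the first chunk's buffer."""
    grads_dict: Dict[str, List[Grad]] = {}
    for chunk in chunks:
        for name, grad in chunk.items():
            if name in grads_dict:
                first = grads_dict[name][0]
                for i, value in enumerate(grad):
                    first[i] += value
                grads_dict[name].append(grad)
            else:
                grads_dict[name] = [grad]
    return grads_dict


def fake_all_reduce(per_rank: List[Dict[str, List[Grad]]]) -> None:
    """Stand-in for a sum all-reduce over PP ranks, applied to each first buffer."""
    for name in per_rank[0]:
        width = len(per_rank[0][name][0])
        total = [sum(rank[name][0][i] for rank in per_rank) for i in range(width)]
        for rank in per_rank:
            rank[name][0][:] = total  # in-place, like buf.copy_(synced)


def copy_to_other_chunks(grads_dict: Dict[str, List[Grad]]) -> None:
    """Overwrite the remaining local VPP buffers with the reduced value."""
    for grads in grads_dict.values():
        for grad in grads[1:]:
            grad[:] = grads[0]


# Two PP ranks, two VPP chunks each, one router parameter of width 2.
rank0 = [{"router.weight": [1.0, 2.0]}, {"router.weight": [10.0, 20.0]}]
rank1 = [{"router.weight": [100.0, 200.0]}, {"router.weight": [1000.0, 2000.0]}]

per_rank = [collect_local(rank0), collect_local(rank1)]
fake_all_reduce(per_rank)
for grads_dict in per_rank:
    copy_to_other_chunks(grads_dict)

# Every buffer on every rank should now hold the sum over all ranks and chunks.
expected = [1111.0, 2222.0]
all_buffers = [chunk["router.weight"] for rank in (rank0, rank1) for chunk in rank]
assert all(buf == expected for buf in all_buffers)
```

A real unit test would still need to exercise the tensor path and an actual process group; this model only pins down the expected values that such a test should assert.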

@claude

claude Bot commented Apr 30, 2026

Contributor

Missing test coverage for elasticity hook managers

The PR adds dedicated unit tests for FlextronMambaElasticityManager, FlextronStackElasticityManager, and FlextronTransformerLayerElasticityManager, but four other managers have no unit tests:

  • FlextronMoEElasticityManager — output masking for MoE layers
  • FlextronGroupedMLPElasticityManager — multi-hook MLP masking with FC1 intermediate masking and expert-TP-aware splitting
  • FlextronAttentionElasticityManager — QKV scaling and embedding masking for attention
  • FlextronTopKRouterElasticityManager — replaces the routing method on TopKRouter with a custom topk_softmax_with_capacity that applies expert masking

These contain non-trivial logic (especially the TopKRouter replacement and GroupedMLP's expert-tensor-parallel mask splitting). Consider adding targeted tests for at least the TopKRouter and GroupedMLP managers.
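
To make the expert-masking behavior concrete, here is a small reference that a targeted test could compare against. It is a sketch under stated assumptions: it ignores capacity entirely, applies softmax over the selected logits after top-k (the real `topk_softmax_with_capacity` may order these steps differently), and the function name `masked_topk_softmax` is hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def masked_topk_softmax(
    logits: Sequence[float], expert_mask: Sequence[bool], k: int
) -> Tuple[List[int], List[float]]:
    """Pick the top-k experts among those enabled by the mask, then normalize."""
    if len(logits) != len(expert_mask):
        raise ValueError("logits and expert_mask must have the same length")
    active = [i for i, keep in enumerate(expert_mask) if keep]
    if not 0 < k <= len(active):
        raise ValueError("k must be between 1 and the number of active experts")
    # Rank only the active experts, so a masked expert can never be selected.
    top = sorted(active, key=lambda i: logits[i], reverse=True)[:k]
    # Numerically stable softmax over the selected logits.
    peak = max(logits[i] for i in top)
    exps = [math.exp(logits[i] - peak) for i in top]
    total = sum(exps)
    return top, [e / total for e in exps]


# Expert 0 has the largest logit but is masked out for this budget.
indices, probs = masked_topk_softmax(
    logits=[5.0, 1.0, 3.0, 2.0],
    expert_mask=[False, True, True, True],
    k=2,
)
assert 0 not in indices
```

The property worth asserting in a real test is the same one checked here: whatever the budget, no token is ever routed to an expert that the mask disables.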

@sheliang-nv

Contributor Author

/claude review

@claude claude Bot left a comment

Contributor

LGTM

@svcnvidia-nemo-ci svcnvidia-nemo-ci added Approved (all necessary approvals have been made) and removed Final Review (PR is in the "final review" stage) labels Apr 30, 2026
@@ -0,0 +1,210 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

Contributor

Nit: make copyright year 2026.

@@ -0,0 +1,542 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

Contributor

Nit: copyright year 2026.

@svcnvidia-nemo-ci svcnvidia-nemo-ci added Final Review (PR is in the "final review" stage) and removed Approved (all necessary approvals have been made) labels Apr 30, 2026
@sheliang-nv

Contributor Author

https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/jobs/309002773
Link to passing internal functional test

@sheliang-nv sheliang-nv enabled auto-merge April 30, 2026 23:26
@sheliang-nv

Contributor Author

/ok to test 3292cf7

@sheliang-nv

Contributor Author

/ok to test 7e4c52e

@sheliang-nv sheliang-nv added this pull request to the merge queue May 1, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25196645789

@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks May 1, 2026
@sheliang-nv sheliang-nv added this pull request to the merge queue May 1, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25225179430

Merged via the queue into NVIDIA:main with commit 2d862fe May 1, 2026
65 of 67 checks passed
@sheliang-nv sheliang-nv deleted the shel/flex_merge branch May 1, 2026 18:41

Labels

complexity: high
Final Review (PR is in the "final review" stage)

10 participants