Create a Protocol for the MLP layer of TransformerLayer #3435

Merged
ericharper merged 4 commits into NVIDIA:main from nschank:mlplayer
May 10, 2026
Conversation

@nschank

@nschank nschank commented Feb 15, 2026

Contributor

What does this PR do?

Defines a Protocol representing the mlp submodule of TransformerLayer, and uses that instead of ModuleSpec to enable typechecking of its configuration.

  • To propagate type-checking through layerspec construction, I had to replace several layers of ModuleSpec with MlpBuilder. Nobody except a single unit test appears to introspect the ModuleSpec's contents (other than a single use of metainfo, which was easy to replace locally with some manual logic), so this should be fairly simple and safe.
  • TransformerLayer was doing some spicy internal kwarg-management based on the specific type being passed; I moved this logic into factory methods on the relevant classes themselves, and updated callers to prefer to pass that method directly instead, but the ModuleSpec-based special casing was left there for backward compatibility. Note that some of the types that were being special-cased are simply not supported by TransformerLayer at this point (they need to be wrapped in MoeLayer to support the correct forward interface), so I just removed them entirely.
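
As a rough illustration of the idea (not the PR's actual code), a builder protocol can be written with `typing.Protocol` and a `__call__` signature. Only the name `MlpBuilder` comes from the description above; `ToyConfig`, `ToyMlp`, `ToyLayerSubmodules`, and the exact parameter list are assumptions made for this sketch.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Protocol


@dataclass
class ToyConfig:
    """Hypothetical stand-in for the layer's config object."""

    ffn_hidden_size: int = 32


class ToyMlp:
    """Hypothetical stand-in for an MLP module."""

    def __init__(
        self, config: ToyConfig, *, ffn_hidden_size: int | None = None, is_mtp_layer: bool = False
    ) -> None:
        self.ffn_hidden_size = (
            config.ffn_hidden_size if ffn_hidden_size is None else ffn_hidden_size
        )
        self.is_mtp_layer = is_mtp_layer


class MlpBuilder(Protocol):
    """Structural type: any callable with this signature qualifies."""

    def __call__(
        self, config: ToyConfig, *, ffn_hidden_size: int | None = None, is_mtp_layer: bool = False
    ) -> ToyMlp: ...


@dataclass
class ToyLayerSubmodules:
    # A static type checker can flag builders whose signature does not match.
    mlp: MlpBuilder


# The class itself is a callable with the right signature, so it can be passed directly.
submodules = ToyLayerSubmodules(mlp=ToyMlp)
mlp = submodules.mlp(ToyConfig(), is_mtp_layer=True)
```

Unlike an untyped spec object, the `mlp` field here carries the full call signature, so a mismatch can surface at type-check time instead of when the layer is constructed.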

Associated design doc: Typed ModuleSpec.pdf

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers' reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@nschank nschank requested review from a team as code owners February 15, 2026 18:58
@copy-pr-bot

copy-pr-bot Bot commented Feb 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@ko3n1g ko3n1g requested a review from a team February 15, 2026 18:58
@Phlip79 Phlip79 added Expert Review [deprecated] Apply this label to indicate that your PR is ready for expert review. complexity: medium labels Feb 17, 2026
@Phlip79

Phlip79 commented Feb 17, 2026

Member

/ok to test e0cf1f9

@nschank

nschank commented Feb 18, 2026

Contributor Author

Note: This one may be a bit blocked on the get_submodules method in #3426, because I think there are a good number of tests introspecting into mlp (see #3425) - will update once available.

@chtruong814 chtruong814 added the needs-follow-up Issue needs follow-up label Feb 20, 2026
@nschank

nschank commented Mar 7, 2026

Contributor Author

I'm waiting on #3426 to resync, since I want to reuse the get_submodules thing

Comment thread megatron/core/transformer/mlp.py
@chtruong814 chtruong814 removed the needs-follow-up Issue needs follow-up label Mar 17, 2026
    ffn_hidden_size: int | None = None,
) -> MLP:
    """Helper function to build an MLP as a TransformerLayer's mlp submodule."""
    del is_mtp_layer

Contributor

Sorry, maybe dumb question... can you explain what is going on here?

Is this function taking in the arguments for a TransformerLayer and "converting" to an MLP or something?

Contributor Author

Sure thing! Not a dumb question at all, this took some thought lol. This is trying to 'decentralize' the logic that TransformerLayer is currently doing here:

https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/transformer_layer.py#L373-L395

Basically, TransformerLayer is currently trying to introspect into the submodule being constructed, and to change the arguments it passes to build_module in response. But this has several drawbacks:

  1. It's confusing: you need to look in 3 different places to understand how these modules are being constructed (the config, the interface of the submodule, and TransformerLayer's special-cased conversion between them, hidden deep in its initializer).
  2. It's circular: TransformerLayer needs to know about its own dependencies in order to construct them, hence the lazy imports.
  3. It's inflexible: Only these classes get this special treatment, so for instance if a user subclasses one of them they suddenly get a different behavior.
  4. Finally (most relevantly here), it's type-checker incompatible: the parameters are different depending on what the caller provides, so at the very least I'd need to provide an overly flexible interface in order to specify the protocol.
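
Drawback 3 can be reproduced in a few lines. The following is a toy reconstruction of type-based special-casing, not the code at the link above: `Dense`, `Experts`, `MyExperts`, and `build_by_type` are hypothetical names, and it assumes the parent compares against the exact class it was handed.

```python
class Dense:
    """Hypothetical module constructed with the default arguments."""

    def __init__(self, config: dict) -> None:
        self.config = config


class Experts:
    """Hypothetical module that needs one extra constructor argument."""

    def __init__(self, config: dict, num_experts: int) -> None:
        self.config = config
        self.num_experts = num_experts


def build_by_type(module_cls: type, config: dict):
    """Parent chooses the kwargs by inspecting the exact class it was given."""
    if module_cls is Experts:
        return module_cls(config, num_experts=config["num_experts"])
    return module_cls(config)


class MyExperts(Experts):
    """A user subclass that leaves the constructor untouched."""


config = {"num_experts": 4}
experts = build_by_type(Experts, config)  # special case applies

try:
    build_by_type(MyExperts, config)  # special case silently does not apply
    subclass_error = None
except TypeError as exc:
    subclass_error = exc
```

The subclass falls through to the default branch and is called without `num_experts`, so the behavior changes even though the user altered nothing about construction.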

So this PR is instead having TransformerLayer consistently construct its MLP submodule using exactly the same interface (and, in particular, it is the "maximal" interface that satisfies all current callers). It is then the responsibility of whoever constructs TransformerLayerSubmodules to satisfy that interface, regardless of the class they want to construct. If they want to provide a class that does not want to take all of the arguments, then that's not something TransformerLayer should be expected to fix for them - it can instead be handled by providing a callable which simply discards those arguments, and forwards the rest to the class they want to construct!

So these classmethods I added are basically that extra 'translation layer' - these classes sorta "know" they want to be provided to TransformerLayerSubmodules, so they can simply provide an extra method which satisfies the interface TransformerLayerSubmodules requires, and users can then provide MLP.as_mlp_submodule instead of just MLP. External users can easily imitate this pattern as well on their own custom classes (if they don't want to have their initializer accept all the arguments), or anyone can write such a 'translation function' for any alternative class they wish to construct too. So basically you get all the flexibility of the original TransformerLayer conversion thing, with none of the drawbacks.
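
A minimal sketch of that translation layer, under stated assumptions: `ToyMlp` and `build_layer_mlp` are hypothetical stand-ins, and the parameter list is illustrative. Only the method name `as_mlp_submodule` and the `del is_mtp_layer` idiom come from this thread.

```python
from __future__ import annotations


class ToyMlp:
    """Hypothetical module whose __init__ has no is_mtp_layer parameter."""

    def __init__(self, config: dict, ffn_hidden_size: int | None = None) -> None:
        self.ffn_hidden_size = (
            config["ffn_hidden_size"] if ffn_hidden_size is None else ffn_hidden_size
        )

    @classmethod
    def as_mlp_submodule(
        cls,
        config: dict,
        *,
        is_mtp_layer: bool = False,
        ffn_hidden_size: int | None = None,
    ) -> ToyMlp:
        """Accept everything the parent passes and forward only what applies."""
        del is_mtp_layer  # this class has no use for it
        return cls(config, ffn_hidden_size=ffn_hidden_size)


def build_layer_mlp(mlp_builder, config: dict, is_mtp_layer: bool):
    """Hypothetical parent: always calls the builder with the same arguments."""
    return mlp_builder(config, is_mtp_layer=is_mtp_layer, ffn_hidden_size=None)


class MyMlp(ToyMlp):
    """A subclass inherits the shim, so it is treated the same as its parent."""


mlp = build_layer_mlp(ToyMlp.as_mlp_submodule, {"ffn_hidden_size": 64}, is_mtp_layer=True)
sub = build_layer_mlp(MyMlp.as_mlp_submodule, {"ffn_hidden_size": 64}, is_mtp_layer=False)
```

Because the shim is a classmethod, `cls` resolves to the subclass, which is one way to avoid the subclassing surprise described in drawback 3.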


I hope that helps clarify things - any thoughts on how to make this more self-documenting? Perhaps as_transformer_layer_submodule would convey this better?

@jaredcasper jaredcasper Mar 24, 2026

Contributor

Blegh, how did we ever let in those lines 373-395... <shaking head>.

The original idea was that all MLPs have the same interface. (And indeed, anything that can be swapped out in a spec should have the same interface as the thing it's replacing.) It seems some MLPs have snuck in that have different interfaces and we get that ugly if/else block that shouldn't have passed review.

I don't get what the advantage is to having an "as_mlp_submodule" that takes the extra args then throws them away and creates the class vs just having the init function take the extra args directly and throw them away. Why the extra step? Why not just add is_mtp_layer to this class's __init__? Why should we let them provide a class that does not take all the extra arguments? If something is so different that it needs an entirely different argument list then it shouldn't be swapped in as an "MLP".

Contributor Author

> I don't get what the advantage is to having an "as_mlp_submodule" that takes the extra args then throws them away

It's not an advantage, it's a safe refactoring - I don't necessarily condone having the alternative interface for constructing these classes, but I'm trying to provide the minimal clean transition for each class currently being provided. Adding extra unused arguments in order to match a specific interface feels worse to me than providing a clean shim layer which callers can use easily.

More broadly, I think that "everyone needs to have the same interface" is technically a bit more restrictive than necessary - we need the thing passed in to have a particular interface, but I don't think it makes sense to force everyone to solely pass in unadulterated __init__ methods to call. Being able to pass subclasses which desire extra parameters (and then providing them using functools.partial) is a valuable ability, and I basically view this as the dual version of that (i.e. having a class which needs fewer/transformed parameters).
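
The `functools.partial` direction mentioned above might look like the following sketch. `ScaledMlp`, its `scale` parameter, and `build_layer_mlp` are hypothetical; the point is only that the extra argument is bound before the parent ever sees the builder.

```python
from functools import partial


class ScaledMlp:
    """Hypothetical module that wants one argument the parent does not supply."""

    def __init__(self, config: dict, *, is_mtp_layer: bool = False, scale: float = 1.0) -> None:
        self.config = config
        self.is_mtp_layer = is_mtp_layer
        self.scale = scale


def build_layer_mlp(mlp_builder, config: dict, is_mtp_layer: bool):
    """Hypothetical parent: knows nothing about `scale`."""
    return mlp_builder(config, is_mtp_layer=is_mtp_layer)


# Bind the extra argument up front; the result still matches the parent's call.
builder = partial(ScaledMlp, scale=0.5)
mlp = build_layer_mlp(builder, {}, is_mtp_layer=True)
```

A shim that drops arguments and a `partial` that adds them are both ordinary callables from the parent's point of view.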

> Why should we let them provide a class that does not take all the extra arguments? If something is so different that it needs an entirely different argument list then it shouldn't be swapped in as an "MLP".

Why should a class let someone provide an extra argument that it doesn't want to use? Not every interface needs to be perfectly met in order for something to be useful - IdentityFn isn't exactly the most choosy about how it's used, but it makes a lot of sense.

The parent module has a fixed amount of information it wants to provide, but it's not actually that helpful IMO to say that the 'constructee' should "care" about every single piece of information. If we have 5 classes that are legal MLP layers, and only one cares about is_mtp_layer, why does it make sense to require them all to specifically accept an unused parameter in their __init__? I don't really see that as cleaner than just having a documented shim layer which can let the class do its thing, while documenting the way that some other class uses this class.

This same pattern (having a callable which adapts to the appropriate interface) is something that is useful to demonstrate, so I think there's some value in putting it into the codebase somewhere. It provides flexibility (like not forcing you to update every tp_group recipient to pg_collection at once), and lets you do some useful things (like using a custom Module someone has somewhere, which was already being used by someone, which has a different interface but could be used in Megatron safely).
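
For the "custom Module someone has somewhere" case, a plain function can serve as the adapter. This is a sketch under assumptions: `ExternalFeedForward`, its `dim`/`inner_dim` parameters, the config keys, and `build_layer_mlp` are all invented for illustration.

```python
class ExternalFeedForward:
    """Hypothetical pre-existing module with its own constructor signature."""

    def __init__(self, dim: int, inner_dim: int) -> None:
        self.dim = dim
        self.inner_dim = inner_dim


def external_ffn_as_mlp(config: dict, *, is_mtp_layer: bool = False, ffn_hidden_size=None):
    """Adapter: translate the parent's arguments into the external signature."""
    del is_mtp_layer  # not meaningful to the external module
    inner_dim = config["ffn_hidden_size"] if ffn_hidden_size is None else ffn_hidden_size
    return ExternalFeedForward(dim=config["hidden_size"], inner_dim=inner_dim)


def build_layer_mlp(mlp_builder, config: dict, is_mtp_layer: bool):
    """Hypothetical parent: always calls the builder with the same arguments."""
    return mlp_builder(config, is_mtp_layer=is_mtp_layer, ffn_hidden_size=None)


ffn = build_layer_mlp(
    external_ffn_as_mlp, {"hidden_size": 16, "ffn_hidden_size": 64}, is_mtp_layer=False
)
```

Neither the parent nor the external class has to change; the renaming and dropping of arguments lives in one small, documented function.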


I would be fine with updating the other classes' __init__ methods if you wanted, but I don't think I should be expected to convert any classes to switch between tp_group and pg_collection, so at the very least there will definitely be some conversion layer here - at least temporarily. Do you have a different form you'd prefer?

Contributor Author

Note: https://github.com/NVIDIA/Megatron-LM/pull/3435/changes#diff-6745b82c932c5947fd3383c31f326639093c33d34cb5def59dc5a843d4e2ebbcR165 This is perhaps a good example of the additional flexibility that an intermediate callable can provide. Previously, in order to customize a submodule's parameters, it was necessary to actually subclass TransformerLayer and change how it was providing the parameter to the submodule; but now, you can simply intercept the parameters that TransformerLayer is providing and do whatever you want.

@chtruong814 chtruong814 added the needs-follow-up Issue needs follow-up label Mar 19, 2026
@chtruong814 chtruong814 added needs-follow-up Issue needs follow-up and removed needs-follow-up Issue needs follow-up labels Apr 17, 2026
@Phlip79 Phlip79 requested a review from a team April 20, 2026 18:26
@svcnvidia-nemo-ci svcnvidia-nemo-ci added waiting-on-maintainers Waiting on maintainers to respond and removed needs-follow-up Issue needs follow-up labels Apr 21, 2026
@ericharper

Contributor

@nschank , could you resolve conflicts?

@svcnvidia-nemo-ci svcnvidia-nemo-ci added waiting-on-customer Waiting on the original author to respond and removed waiting-on-maintainers Waiting on maintainers to respond labels May 6, 2026
@nschank

nschank commented May 8, 2026

Contributor Author

> It kind of goes to heart of what the spec should be used for and what it shouldn't.

I get this, although I'm mostly just representing what's already happening haha.

> Could you resolve conflicts

Yep! Sorry for delay, fixing up now

@nschank nschank requested a review from a team as a code owner May 8, 2026 21:10
@svcnvidia-nemo-ci svcnvidia-nemo-ci removed the Final Review PR is in the "final review" stage label May 8, 2026
@nschank

nschank commented May 8, 2026

Contributor Author

Done, although a round of unit tests may reveal some new (untyped) areas that need updating - I'll fast follow with any additional little fixes

@ericharper ericharper enabled auto-merge May 8, 2026 21:30
@Phlip79

Phlip79 commented May 8, 2026

Member

/ok to test beaef0b

@svcnvidia-nemo-ci svcnvidia-nemo-ci removed the waiting-on-customer Waiting on the original author to respond label May 9, 2026
@asolergi-nv

Contributor

/ok to test e1a7c45

@svcnvidia-nemo-ci svcnvidia-nemo-ci added the Approved All necessary approvals have been made label May 10, 2026
@ericharper ericharper added this pull request to the merge queue May 10, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25639642089

Merged via the queue into NVIDIA:main with commit 5e31514 May 10, 2026
62 checks passed
svcnvidia-nemo-ci added a commit that referenced this pull request May 12, 2026
Merges 8 commits from main into dev. Dev already contains yesterday's
sync (PR #4716) plus follow-up fixes, so this PR only carries main
commits made after that sync.

Notable changes:
- 434368c build(deps): bump nvidia-modelopt to 0.43 (#4723)
- e42e2fa ci: Major refactor of release-workflows (#4602)
- 33d47e0 [ci] fix: treat cancelled run-main-script step as failure (#4727)
- 5123f6a ci: revert bad uv.lock bump and label future bumps with
  Run functional tests (#4730)
- ad58411 Add Python-side guardrail for DeepEP IB limits (#4719)
- e93755e chore(beep boop): Bump (main) (2026-05-11)
- a2ec5c1 Revert Add Python-side guardrail for HybridEP IB limit (#4718)
- 5e31514 Create a Protocol for the MLP layer of TransformerLayer (#3435)

Kept dev's pyproject.toml, uv.lock, docker/Dockerfile.ci.dev, and
.github/CODEOWNERS (per nightly-sync skill).

Ran black + isort on changed Python files.

Labels

Approved All necessary approvals have been made community-request complexity: medium
