🐛 CI failure: MultiTokenPredictionLayer._checkpointed_forward() got unexpected kwarg 'padding_mask' #4933

@ko3n1g

Describe the bug

The CI test tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction is failing across many parametrizations with a TypeError on the recompute path:

megatron/core/transformer/multi_token_prediction.py:1301: TypeError
E   TypeError: MultiTokenPredictionLayer._checkpointed_forward() got an unexpected keyword argument 'padding_mask'

Failing nodes (all in the same job):

  • TestMultiTokenPrediction::test_forward_backward[{1,2,4}-{1,2,4}-True] (9 combinations)
  • TestMultiTokenPrediction::test_fp8_support[True]
  • TestMultiTokenPrediction::test_packed_sequences_with_full_recompute

Tag @NVIDIA/mcore-oncall to get oncall's attention to this issue.

Root cause (likely)

The call site at megatron/core/transformer/multi_token_prediction.py:1301 passes padding_mask=padding_mask into self._checkpointed_forward(...), but the method definition at megatron/core/transformer/multi_token_prediction.py:1093 does not declare a padding_mask parameter.

git blame points to two recent landings that don't compose:

  • 2d1fa8d372 (#2645, 2026-05-14, @Connor-XY) added padding_mask=padding_mask to the call site.
  • 2b77d32b1e (#4593, 2026-05-21, @BestJuly) refactored _checkpointed_forward without including padding_mask in the new signature.

The recompute branch (config.recompute_granularity == 'full' and self.training) crashes on every call.
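The mismatch described above can be sketched in isolation. This is a minimal illustration, not the Megatron-LM code: `ToyMTPLayer`, `recompute_full`, and `try_forward` are hypothetical stand-ins, and the method bodies are placeholders. It only shows why a call site that passes `padding_mask=` fails against a signature that no longer declares it, and why the non-recompute branch is unaffected.

```python
# Hypothetical, stripped-down stand-in for MultiTokenPredictionLayer.
# Names and bodies are assumptions for illustration only.
class ToyMTPLayer:
    def __init__(self, recompute_full: bool, training: bool = True):
        self.recompute_full = recompute_full
        self.training = training

    def _proj_and_transformer_layer(self, hidden_states, padding_mask=None):
        # Placeholder for the real projection + transformer layer.
        return hidden_states

    # Signature after the refactor: padding_mask is not declared.
    def _checkpointed_forward(self, hidden_states):
        return self._proj_and_transformer_layer(hidden_states)

    def forward(self, hidden_states, padding_mask=None):
        if self.recompute_full and self.training:
            # Call site still passes padding_mask, so this raises TypeError.
            return self._checkpointed_forward(
                hidden_states, padding_mask=padding_mask
            )
        return self._proj_and_transformer_layer(
            hidden_states, padding_mask=padding_mask
        )


def try_forward(layer):
    """Run forward once and report 'ok' or the TypeError message."""
    try:
        layer.forward([1, 2, 3])
        return "ok"
    except TypeError as exc:
        return str(exc)


recompute_result = try_forward(ToyMTPLayer(recompute_full=True))
plain_result = try_forward(ToyMTPLayer(recompute_full=False))
print(recompute_result)
print(plain_result)
```

In this sketch only the recompute branch raises. That would be consistent with only the `-True` parametrizations failing, assuming that parameter toggles full recompute.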

Failing run

| Field | Value |
| --- | --- |
| PR | #4931: test: enable NVTE_CUTEDSL_FUSED_GROUPED_MLP via pytest fixture (surfaced here; the PR itself does not touch MTP) |
| Run | 26290097610 |
| Job | tests/unit_tests/transformer/test_multi_token_prediction.py - latest |

Error (verbatim, abridged)

megatron/core/transformer/multi_token_prediction.py:1301: TypeError
E   TypeError: MultiTokenPredictionLayer._checkpointed_forward() got an unexpected keyword argument 'padding_mask'

FAILED tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_forward_backward[1-1-True]
FAILED tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_forward_backward[1-2-True]
FAILED tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_forward_backward[1-4-True]
FAILED tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_forward_backward[2-1-True]
FAILED tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_forward_backward[2-2-True]
FAILED tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_forward_backward[2-4-True]
FAILED tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_forward_backward[4-1-True]
FAILED tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_forward_backward[4-2-True]
FAILED tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_fp8_support[True]
FAILED tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_packed_sequences_with_full_recompute

Steps/Code to reproduce bug

Re-run the failing CI job linked above, or locally inside the dev container:

pytest tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_forward_backward

Additional context

Triaged automatically via /create-issue. Assigned to @BestJuly as the author of the refactor that dropped the padding_mask parameter. There are two possible fixes: add padding_mask back to _checkpointed_forward and forward it into _proj_and_transformer_layer, or drop the padding_mask=padding_mask argument at the call site if the recompute path doesn't need it.
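A rough sketch of the first option (restoring the parameter), plus a small signature check that could catch this class of regression without running the recompute path. `ToyMTPLayerFixed` and `accepts_kwarg` are hypothetical names, and the method bodies are placeholders. The real `_checkpointed_forward` may route arguments through a checkpointing utility, which this sketch does not model.

```python
import inspect


# Hypothetical stand-in; not the Megatron-LM implementation.
class ToyMTPLayerFixed:
    def _proj_and_transformer_layer(self, hidden_states, padding_mask=None):
        # Placeholder: echo the inputs so the forwarding is observable.
        return hidden_states, padding_mask

    # Option 1: declare padding_mask again and forward it.
    def _checkpointed_forward(self, hidden_states, padding_mask=None):
        return self._proj_and_transformer_layer(
            hidden_states, padding_mask=padding_mask
        )


def accepts_kwarg(func, name):
    """Return True if func can be called with `name` as a keyword argument."""
    params = inspect.signature(func).parameters
    if name in params:
        # Positional-only parameters cannot be passed by keyword.
        return params[name].kind is not inspect.Parameter.POSITIONAL_ONLY
    # A **kwargs catch-all also accepts the keyword.
    return any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )


layer = ToyMTPLayerFixed()
forwarded = layer._checkpointed_forward([1, 2, 3], padding_mask="mask")
has_param = accepts_kwarg(ToyMTPLayerFixed._checkpointed_forward, "padding_mask")
print(forwarded, has_param)
```

The second option (dropping the argument at the call site) needs no signature change, but it is only correct if the recompute path really does not need the mask.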
