Fix MTP recompute crash with packed sequences#4593
Merged
Conversation
MultiTokenPredictionLayer._checkpointed_forward forwarded every kwarg positionally to tensor_parallel.checkpoint. CheckpointFunction's save_for_backward only accepts tensors and None, so a non-tensor kwarg like packed_seq_params (PackedSeqParams) raised: TypeError: save_for_backward can only save variables, but argument N is of type PackedSeqParams This made THD packed sequences + decoder full recompute (--recompute-granularity full) unusable whenever MTP was enabled. Mirror the closure pattern already used in transformer_block._checkpointed_forward and attention.py: take explicit named parameters; capture non-tensor objects (attention_bias, inference_params, packed_seq_params) in a closure called custom_forward; only forward tensor / None arguments positionally to (te_checkpoint | tensor_parallel.checkpoint). No semantic change on existing paths where all kwargs are tensor / None (the closure is empty and call is identical). Add unit test test_packed_sequences_with_full_recompute that fails without this fix (same TypeError) and passes with it. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
4 tasks
The previous refactor unconditionally pushed every kwarg through the closure pattern. That uncovered a pre-existing latent issue on the FP8 + recompute path: ``_proj_and_transformer_layer`` enters its ``fp8_autocast`` *inside* the function body, so by the time TE's ``_CheckpointFunction.forward`` reads ``is_fp8_enabled()`` to set ``_FP8_ACTIVATION_RECOMPUTE_ENABLED``, FP8 is not yet active. Phase 1 therefore never stashes the FP8 buffer position key, and phase 2's ``get_old_fp8_meta_tensors_for_recompute`` raises ``KeyError: global_fp8_buffer_pos_fwd_recompute``. Historically this never surfaced because, with the OLD kwargs-passing pattern, TE's ``_CheckpointFunction`` only tracks positional args via ``save_for_backward`` — kwargs tensors are stashed on ``ctx.kwargs`` unwired from autograd. As a result MTP's checkpoint backward never ran and the FP8 buffer lookup never executed. To fix the original ``PackedSeqParams`` save_for_backward crash without disturbing the FP8 path, take the OLD route on FP8 (forward ``_proj_and_transformer_layer`` directly with kwargs into ``te_checkpoint``, which stashes non-tensor args natively) and use the new closure pattern only on the non-FP8 path (where ``tensor_parallel.checkpoint`` requires tensor / None args). This restores ``test_fp8_support[True]`` while keeping ``test_packed_sequences_with_full_recompute`` green. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…r_block Two related fixes wrapped in one cleanup that matches transformer_block._checkpointed_forward's outer/inner quantization context design: 1. Pass differentiable activations (hidden_states, decoder_input) and all tensor / None args positionally to te_checkpoint, capturing the non-tensor objects (attention_bias, inference_params, packed_seq_params) via a custom_forward closure. TE's reentrant checkpoint only tracks positional tensor inputs as checkpoint inputs; kwarg tensors are dropped from the recompute backward path, silently detaching MTP activation gradients under fp8/mxfp8 + full recompute. Fixes the same regression as NVIDIA#4766 on top of the packed-sequence crash fix in this PR. 2. Enter the outer quantization context (get_fp8_context) only when fp8 + delayed scaling is active. TE's _CheckpointFunction.forward samples FP8GlobalStateManager.is_fp8_enabled() at entry to gate the phase-1 amax-buffer stash keyed by 'global_fp8_buffer_pos_fwd_recompute' used by phase-2 backward. Only delayed-scaling fp8 stashes/looks up that buffer (per TE's copy/get_old_fp8_meta_tensors_for_recompute); MXFP8 BlockScaling, Float8CurrentScaling, and NVFP4BlockScaling all treat it as a noop. Non-delayed quantized recipes rely on the inner context already entered inside _proj_and_transformer_layer. 3. Route fp4 through te_checkpoint as well (matches transformer_block.py:526). fp4 (NVFP4BlockScaling) is implemented on top of TE's fp8_autocast (see fp4_utils.get_fp4_context), so its recompute path needs te_checkpoint for correct handling. MTP _proj_and_transformer_layer itself still treats fp4 as a noop pending numerical validation; this change only prepares the recompute plumbing. Co-Authored-By: Zhongbo Zhu <zhongboz@nvidia.com> Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
a2c7582 to
51114ff
Compare
5 tasks
zhongbozhu
approved these changes
May 14, 2026
santhnm2
approved these changes
May 15, 2026
ericharper
approved these changes
May 15, 2026
Contributor
|
/ok to test 51114ff |
kvareddy
approved these changes
May 21, 2026
|
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/26208774108 |
copy-pr-bot Bot
pushed a commit
that referenced
this pull request
May 26, 2026
MultiTokenPredictionLayer.forward calls self._checkpointed_forward(
padding_mask=padding_mask, ...) (multi_token_prediction.py:1305), but
_checkpointed_forward and its inner custom_forward never accepted
padding_mask. With recompute_granularity == 'full' and self.training,
this raised:
TypeError: MultiTokenPredictionLayer._checkpointed_forward() got
an unexpected keyword argument 'padding_mask'
at multi_token_prediction.py:1301. The kwarg was introduced in #2645
on the call site; the _checkpointed_forward refactor in #4593 dropped
padding_mask from the recompute path.
Add padding_mask:
* to _checkpointed_forward's signature
* to custom_forward's signature so it flows into _proj_and_transformer_layer
* positionally to te_checkpoint and tensor_parallel.checkpoint, matching the
other tensor / None args (padding_mask is a rolled tensor, not a non-tensor
closure-captured arg like attention_bias)
* to the recompute_method == 'block' fallback that also calls
_proj_and_transformer_layer directly
Also remove the @pytest.mark.flaky_in_dev markers from
test_forward_backward, test_fp8_support, and test_packed_sequences_with_full_recompute,
which were added in #4931 to mask this exact failure.
Closes #4933
Signed-off-by: oliver könig <okoenig@nvidia.com>
BestJuly
pushed a commit
to BestJuly/Megatron-LM
that referenced
this pull request
May 26, 2026
MultiTokenPredictionLayer.forward calls self._checkpointed_forward(
padding_mask=padding_mask, ...) (multi_token_prediction.py:1305), but
_checkpointed_forward and its inner custom_forward never accepted
padding_mask. With recompute_granularity == 'full' and self.training,
this raised:
TypeError: MultiTokenPredictionLayer._checkpointed_forward() got
an unexpected keyword argument 'padding_mask'
at multi_token_prediction.py:1301. The kwarg was introduced in NVIDIA#2645
on the call site; the _checkpointed_forward refactor in NVIDIA#4593 dropped
padding_mask from the recompute path.
Add padding_mask:
* to _checkpointed_forward's signature
* to custom_forward's signature so it flows into _proj_and_transformer_layer
* positionally to te_checkpoint and tensor_parallel.checkpoint, matching the
other tensor / None args (padding_mask is a rolled tensor, not a non-tensor
closure-captured arg like attention_bias)
* to the recompute_method == 'block' fallback that also calls
_proj_and_transformer_layer directly
Also remove the @pytest.mark.flaky_in_dev markers from
test_forward_backward, test_fp8_support, and test_packed_sequences_with_full_recompute,
which were added in NVIDIA#4931 to mask this exact failure.
Closes NVIDIA#4933
Signed-off-by: oliver könig <okoenig@nvidia.com>
janEbert
pushed a commit
to janEbert/Megatron-LM
that referenced
this pull request
Jun 2, 2026
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
This was referenced Jun 5, 2026
mathemakitten
pushed a commit
to mathemakitten/Megatron-LM
that referenced
this pull request
Jun 12, 2026
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR bundles two related fixes in
MultiTokenPredictionLayer._checkpointed_forward, both addressed by aligning MTP's recompute plumbing withtransformer_block._checkpointed_forward.Fix 1 —
TypeErrorcrash on packed sequences + full recomputeMultiTokenPredictionLayer._checkpointed_forwardforwards every kwarg positionally totensor_parallel.checkpoint.CheckpointFunction.forwardcallsctx.save_for_backward(*args), which only accepts tensors andNone— so any non-tensor kwarg (today:packed_seq_params: PackedSeqParams; tomorrow potentiallyinference_params: InferenceContext) triggers:This breaks THD packed sequences + decoder full activation recompute (
--recompute-granularity full) whenever MTP is enabled (--mtp-num-layers >= 1).TransformerBlock._checkpointed_forwardandattention.py:_checkpointed_attention_forwardhave always avoided this by capturing non-tensor objects via Python closure and only forwarding tensor /Noneargs. MTP's_checkpointed_forward(introduced in !3330 "feat(MoE): support CP and recompute for MTP") didn't follow that convention.Fix 2 — silent grad-drop / FP8 recompute crash (same as #4766)
This PR also incorporates the fix from #4766. Two distinct problems on the FP8 / MXFP8 + MTP + full recompute path were reported there:
te_checkpointonly tracks positional tensor inputs as checkpoint inputs. The pre-existing MTP path passedhidden_states/decoder_inputas kwargs tote_checkpoint, so those tensors were not represented in the recompute backward path — gradients for MTP internals were silently detached, causing missing DDP grad-ready hooks and MTP loss divergence vs BF16._proj_and_transformer_layerenters itsfp8_autocastinside the function body, so by the time TE's_CheckpointFunction.forwardreadsFP8GlobalStateManager.is_fp8_enabled()to gate the phase-1 amax-buffer stash (keyed byglobal_fp8_buffer_pos_fwd_recompute), FP8 is not yet active. Phase 1 skips the stash and phase 2 raisesKeyError: global_fp8_buffer_pos_fwd_recompute. (Historically this was masked because TE never ran backward through kwarg tensors — fixing fix 2a above exposes this.)Changes (mirroring
transformer_block._checkpointed_forward)_checkpointed_forwardnow takes named parameters that match_proj_and_transformer_layer's signature.custom_forwardclosure capturesattention_bias,inference_params,packed_seq_paramsfrom the enclosing scope.Nonearguments are forwarded positionally totensor_parallel.checkpoint/te_checkpoint. This applies to both the non-FP8 and FP8 paths, so differentiable activations are correctly tracked in the recompute backward.get_fp8_context(self.config)is entered aroundte_checkpointonly whenfp8 + delayed scalingis active (per TE'scopy/get_old_fp8_meta_tensors_for_recompute, only delayed scaling actually uses the recompute amax buffer; MXFP8 BlockScaling, Float8CurrentScaling, NVFP4BlockScaling are all noops there). Mirrorstransformer_block'souter_quantization_contextdecision.te_checkpoint; non-quantized stays ontensor_parallel.checkpoint. Mirrorstransformer_block.py:526. MTP's internal fp4 path is still a noop pending numerical validation; this change only prepares the recompute plumbing.forward_funcfirst argument.No semantic change on the existing non-FP8 path where all kwargs are tensor /
None— the closure is empty and the call is identical.Companion PR for
dev: #4592E2E testing:

Note that
maindoes not yet carry the SFT-mock-data plumbing (MockSFTDataset,--sft-mock-dataset-config-json) that lives ondev, so onmainwe exercise the bug only via the unit test (which still fully hits the affected code path).Test plan
tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_packed_sequences_with_full_recomputeTypeError: save_for_backward can only save variables, but argument 10 is of type PackedSeqParams.main_grad.test_constructor_local[1],test_packed_sequences[1-1],test_roll_tensor_with_packed_sequences[1],test_fp8_support[False],test_fp8_support[True](the last one exercises the FP8 + full-recompute path that [Dev] Fix MTP layer recompute #4766 reported).🤖 Generated with Claude Code