Skip to content

Fix MTP recompute crash with packed sequences#4593

Merged
ericharper merged 6 commits into
NVIDIA:mainfrom
BestJuly:lit/fix_mtp_thd_recompute_main
May 21, 2026
Merged

Fix MTP recompute crash with packed sequences#4593
ericharper merged 6 commits into
NVIDIA:mainfrom
BestJuly:lit/fix_mtp_thd_recompute_main

Conversation

@BestJuly

@BestJuly BestJuly commented May 2, 2026

Copy link
Copy Markdown
Contributor

Summary

This PR bundles two related fixes in MultiTokenPredictionLayer._checkpointed_forward, both addressed by aligning MTP's recompute plumbing with transformer_block._checkpointed_forward.

Fix 1 — TypeError crash on packed sequences + full recompute

MultiTokenPredictionLayer._checkpointed_forward forwards every kwarg positionally to tensor_parallel.checkpoint. CheckpointFunction.forward calls ctx.save_for_backward(*args), which only accepts tensors and None — so any non-tensor kwarg (today: packed_seq_params: PackedSeqParams; tomorrow potentially inference_params: InferenceContext) triggers:

TypeError: save_for_backward can only save variables, but argument N is of type PackedSeqParams

This breaks THD packed sequences + decoder full activation recompute (--recompute-granularity full) whenever MTP is enabled (--mtp-num-layers >= 1).

TransformerBlock._checkpointed_forward and attention.py:_checkpointed_attention_forward have always avoided this by capturing non-tensor objects via Python closure and only forwarding tensor / None args. MTP's _checkpointed_forward (introduced in !3330 "feat(MoE): support CP and recompute for MTP") didn't follow that convention.

Fix 2 — silent grad-drop / FP8 recompute crash (same as #4766)

This PR also incorporates the fix from #4766. Two distinct problems on the FP8 / MXFP8 + MTP + full recompute path were reported there:

  • TE's reentrant te_checkpoint only tracks positional tensor inputs as checkpoint inputs. The pre-existing MTP path passed hidden_states / decoder_input as kwargs to te_checkpoint, so those tensors were not represented in the recompute backward path — gradients for MTP internals were silently detached, causing missing DDP grad-ready hooks and MTP loss divergence vs BF16.
  • _proj_and_transformer_layer enters its fp8_autocast inside the function body, so by the time TE's _CheckpointFunction.forward reads FP8GlobalStateManager.is_fp8_enabled() to gate the phase-1 amax-buffer stash (keyed by global_fp8_buffer_pos_fwd_recompute), FP8 is not yet active. Phase 1 skips the stash and phase 2 raises KeyError: global_fp8_buffer_pos_fwd_recompute. (Historically this was masked because TE never ran backward through kwarg tensors — fixing fix 2a above exposes this.)

Changes (mirroring transformer_block._checkpointed_forward)

  • _checkpointed_forward now takes named parameters that match _proj_and_transformer_layer's signature.
  • A custom_forward closure captures attention_bias, inference_params, packed_seq_params from the enclosing scope.
  • Only tensor / None arguments are forwarded positionally to tensor_parallel.checkpoint / te_checkpoint. This applies to both the non-FP8 and FP8 paths, so differentiable activations are correctly tracked in the recompute backward.
  • An outer get_fp8_context(self.config) is entered around te_checkpoint only when fp8 + delayed scaling is active (per TE's copy/get_old_fp8_meta_tensors_for_recompute, only delayed scaling actually uses the recompute amax buffer; MXFP8 BlockScaling, Float8CurrentScaling, NVFP4BlockScaling are all noops there). Mirrors transformer_block's outer_quantization_context decision.
  • Quantized recompute (fp8 or fp4) now routes through te_checkpoint; non-quantized stays on tensor_parallel.checkpoint. Mirrors transformer_block.py:526. MTP's internal fp4 path is still a noop pending numerical validation; this change only prepares the recompute plumbing.
  • The single call site drops the now-unused forward_func first argument.

No semantic change on the existing non-FP8 path where all kwargs are tensor / None — the closure is empty and the call is identical.

Companion PR for dev: #4592

E2E testing:
image

Note that main does not yet carry the SFT-mock-data plumbing (MockSFTDataset, --sft-mock-dataset-config-json) that lives on dev, so on main we exercise the bug only via the unit test (which still fully hits the affected code path).

Test plan

  • Added tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_packed_sequences_with_full_recompute
    • Without the fix: fails with the exact TypeError: save_for_backward can only save variables, but argument 10 is of type PackedSeqParams.
    • With the fix: forward + backward complete; all params get main_grad.
  • Existing MTP tests still pass on a single-rank smoke run on main: test_constructor_local[1], test_packed_sequences[1-1], test_roll_tensor_with_packed_sequences[1], test_fp8_support[False], test_fp8_support[True] (the last one exercises the FP8 + full-recompute path that [Dev] Fix MTP layer recompute  #4766 reported).
  • E2E FP8/MXFP8 + MTP + full recompute closure validation referenced in [Dev] Fix MTP layer recompute  #4766 (Qwen3.5 MXFP8 vs BF16 MTP loss gap closes; see the chart attached on that PR).

🤖 Generated with Claude Code

MultiTokenPredictionLayer._checkpointed_forward forwarded every kwarg
positionally to tensor_parallel.checkpoint. CheckpointFunction's
save_for_backward only accepts tensors and None, so a non-tensor
kwarg like packed_seq_params (PackedSeqParams) raised:

  TypeError: save_for_backward can only save variables, but argument
  N is of type PackedSeqParams

This made THD packed sequences + decoder full recompute
(--recompute-granularity full) unusable whenever MTP was enabled.

Mirror the closure pattern already used in
transformer_block._checkpointed_forward and attention.py: take
explicit named parameters; capture non-tensor objects (attention_bias,
inference_params, packed_seq_params) in a closure called
custom_forward; only forward tensor / None arguments positionally to
(te_checkpoint | tensor_parallel.checkpoint).

No semantic change on existing paths where all kwargs are tensor /
None (the closure is empty and call is identical).

Add unit test test_packed_sequences_with_full_recompute that fails
without this fix (same TypeError) and passes with it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@copy-pr-bot

copy-pr-bot Bot commented May 2, 2026

Copy link
Copy Markdown

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@BestJuly BestJuly marked this pull request as ready for review May 2, 2026 23:58
@BestJuly BestJuly requested review from a team as code owners May 2, 2026 23:58
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team May 2, 2026 23:59
The previous refactor unconditionally pushed every kwarg through the
closure pattern. That uncovered a pre-existing latent issue on the
FP8 + recompute path: ``_proj_and_transformer_layer`` enters its
``fp8_autocast`` *inside* the function body, so by the time TE's
``_CheckpointFunction.forward`` reads ``is_fp8_enabled()`` to set
``_FP8_ACTIVATION_RECOMPUTE_ENABLED``, FP8 is not yet active. Phase 1
therefore never stashes the FP8 buffer position key, and phase 2's
``get_old_fp8_meta_tensors_for_recompute`` raises
``KeyError: global_fp8_buffer_pos_fwd_recompute``.

Historically this never surfaced because, with the OLD kwargs-passing
pattern, TE's ``_CheckpointFunction`` only tracks positional args via
``save_for_backward`` — kwargs tensors are stashed on ``ctx.kwargs``
unwired from autograd. As a result MTP's checkpoint backward never ran
and the FP8 buffer lookup never executed.

To fix the original ``PackedSeqParams`` save_for_backward crash without
disturbing the FP8 path, take the OLD route on FP8 (forward
``_proj_and_transformer_layer`` directly with kwargs into
``te_checkpoint``, which stashes non-tensor args natively) and use the
new closure pattern only on the non-FP8 path (where
``tensor_parallel.checkpoint`` requires tensor / None args).

This restores ``test_fp8_support[True]`` while keeping
``test_packed_sequences_with_full_recompute`` green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…r_block

Two related fixes wrapped in one cleanup that matches
transformer_block._checkpointed_forward's outer/inner quantization
context design:

1. Pass differentiable activations (hidden_states, decoder_input) and
   all tensor / None args positionally to te_checkpoint, capturing the
   non-tensor objects (attention_bias, inference_params, packed_seq_params)
   via a custom_forward closure. TE's reentrant checkpoint only tracks
   positional tensor inputs as checkpoint inputs; kwarg tensors are
   dropped from the recompute backward path, silently detaching MTP
   activation gradients under fp8/mxfp8 + full recompute. Fixes the same
   regression as NVIDIA#4766 on top of the packed-sequence crash fix in this
   PR.

2. Enter the outer quantization context (get_fp8_context) only when
   fp8 + delayed scaling is active. TE's _CheckpointFunction.forward
   samples FP8GlobalStateManager.is_fp8_enabled() at entry to gate the
   phase-1 amax-buffer stash keyed by 'global_fp8_buffer_pos_fwd_recompute'
   used by phase-2 backward. Only delayed-scaling fp8 stashes/looks up
   that buffer (per TE's copy/get_old_fp8_meta_tensors_for_recompute);
   MXFP8 BlockScaling, Float8CurrentScaling, and NVFP4BlockScaling all
   treat it as a noop. Non-delayed quantized recipes rely on the inner
   context already entered inside _proj_and_transformer_layer.

3. Route fp4 through te_checkpoint as well (matches
   transformer_block.py:526). fp4 (NVFP4BlockScaling) is implemented on
   top of TE's fp8_autocast (see fp4_utils.get_fp4_context), so its
   recompute path needs te_checkpoint for correct handling. MTP
   _proj_and_transformer_layer itself still treats fp4 as a noop pending
   numerical validation; this change only prepares the recompute plumbing.

Co-Authored-By: Zhongbo Zhu <zhongboz@nvidia.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@BestJuly BestJuly force-pushed the lit/fix_mtp_thd_recompute_main branch from a2c7582 to 51114ff Compare May 13, 2026 02:51
@zhongbozhu zhongbozhu mentioned this pull request May 13, 2026
5 tasks
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the Final Review PR is in the "final review" stage label May 15, 2026
@ericharper

Copy link
Copy Markdown
Contributor

/ok to test 51114ff

@ericharper ericharper enabled auto-merge May 15, 2026 20:10
@svcnvidia-nemo-ci svcnvidia-nemo-ci added Approved All necessary approvals have been made and removed Final Review PR is in the "final review" stage labels May 21, 2026
@ericharper ericharper added this pull request to the merge queue May 21, 2026
@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/26208774108

Merged via the queue into NVIDIA:main with commit 2b77d32 May 21, 2026
79 of 82 checks passed
copy-pr-bot Bot pushed a commit that referenced this pull request May 26, 2026
MultiTokenPredictionLayer.forward calls self._checkpointed_forward(
padding_mask=padding_mask, ...) (multi_token_prediction.py:1305), but
_checkpointed_forward and its inner custom_forward never accepted
padding_mask. With recompute_granularity == 'full' and self.training,
this raised:

    TypeError: MultiTokenPredictionLayer._checkpointed_forward() got
    an unexpected keyword argument 'padding_mask'

at multi_token_prediction.py:1301. The kwarg was introduced in #2645
on the call site; the _checkpointed_forward refactor in #4593 dropped
padding_mask from the recompute path.

Add padding_mask:
  * to _checkpointed_forward's signature
  * to custom_forward's signature so it flows into _proj_and_transformer_layer
  * positionally to te_checkpoint and tensor_parallel.checkpoint, matching the
    other tensor / None args (padding_mask is a rolled tensor, not a non-tensor
    closure-captured arg like attention_bias)
  * to the recompute_method == 'block' fallback that also calls
    _proj_and_transformer_layer directly

Also remove the @pytest.mark.flaky_in_dev markers from
test_forward_backward, test_fp8_support, and test_packed_sequences_with_full_recompute,
which were added in #4931 to mask this exact failure.

Closes #4933

Signed-off-by: oliver könig <okoenig@nvidia.com>
BestJuly pushed a commit to BestJuly/Megatron-LM that referenced this pull request May 26, 2026
MultiTokenPredictionLayer.forward calls self._checkpointed_forward(
padding_mask=padding_mask, ...) (multi_token_prediction.py:1305), but
_checkpointed_forward and its inner custom_forward never accepted
padding_mask. With recompute_granularity == 'full' and self.training,
this raised:

    TypeError: MultiTokenPredictionLayer._checkpointed_forward() got
    an unexpected keyword argument 'padding_mask'

at multi_token_prediction.py:1301. The kwarg was introduced in NVIDIA#2645
on the call site; the _checkpointed_forward refactor in NVIDIA#4593 dropped
padding_mask from the recompute path.

Add padding_mask:
  * to _checkpointed_forward's signature
  * to custom_forward's signature so it flows into _proj_and_transformer_layer
  * positionally to te_checkpoint and tensor_parallel.checkpoint, matching the
    other tensor / None args (padding_mask is a rolled tensor, not a non-tensor
    closure-captured arg like attention_bias)
  * to the recompute_method == 'block' fallback that also calls
    _proj_and_transformer_layer directly

Also remove the @pytest.mark.flaky_in_dev markers from
test_forward_backward, test_fp8_support, and test_packed_sequences_with_full_recompute,
which were added in NVIDIA#4931 to mask this exact failure.

Closes NVIDIA#4933

Signed-off-by: oliver könig <okoenig@nvidia.com>
janEbert pushed a commit to janEbert/Megatron-LM that referenced this pull request Jun 2, 2026
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
mathemakitten pushed a commit to mathemakitten/Megatron-LM that referenced this pull request Jun 12, 2026
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

26.06 Approved All necessary approvals have been made complexity: low

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants