Skip to content

Fix MLA dynamic inference decode flag#4902

Open
cuichenx wants to merge 1 commit into
NVIDIA:mainfrom
cuichenx:chcui/fix-mla-dynamic-inference-is-decode-only
Open

Fix MLA dynamic inference decode flag#4902
cuichenx wants to merge 1 commit into
NVIDIA:mainfrom
cuichenx:chcui/fix-mla-dynamic-inference-is-decode-only

Conversation

@cuichenx

Copy link
Copy Markdown
Contributor

Summary

Fix the dynamic inference MLA path to pass inference_context.is_decode_only() into flash_decode_and_prefill().

The regular attention dynamic path already passes this argument, but MultiLatentAttention omitted it after the flash_decode_and_prefill() signature gained is_decode_only.

Fixes #4901.

Testing

  • python -m py_compile megatron/core/transformer/multi_latent_attention.py
  • git diff --check

Runtime validation for MLA dynamic inference with cached MLA latents is pending.

Signed-off-by: Chen Cui <chcui@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented May 20, 2026

Copy link
Copy Markdown

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@cuichenx cuichenx marked this pull request as ready for review May 20, 2026 23:05
@cuichenx cuichenx requested review from a team as code owners May 20, 2026 23:05
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team May 20, 2026 23:06
@cuichenx

Copy link
Copy Markdown
Contributor Author

/claude strict-review

@cuichenx

Copy link
Copy Markdown
Contributor Author

/ok to test aa8b233

@claude claude Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Strict review passed — no significant issues found. LGTM.

CRITICAL: 0 | IMPORTANT: 0 | SUGGESTION: 0

The fix is correct: flash_decode_and_prefill() (defined at attention.py:829) requires is_decode_only as its last positional argument. The regular Attention path already passes it (attention.py:1304), but the MultiLatentAttention path omitted it — causing a TypeError whenever cache_mla_latents was enabled with dynamic batching. This one-line addition aligns the MLA call site with both the function signature and the existing non-MLA call site.

Minimal risk — only enables a previously broken code path.

@santhnm2

Copy link
Copy Markdown
Contributor

@cuichenx Can we add a unit test for MLA which will exercise this code path?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Tracking] Dynamic inference: MLA flash_decode_and_prefill call missing is_decode_only

4 participants