Fix MLA dynamic inference decode flag #4902
Signed-off-by: Chen Cui <chcui@nvidia.com>
/claude strict-review
/ok to test aa8b233
Strict review passed — no significant issues found. LGTM.
CRITICAL: 0 | IMPORTANT: 0 | SUGGESTION: 0
The fix is correct: flash_decode_and_prefill() (defined at attention.py:829) requires is_decode_only as its last positional argument. The regular Attention path already passes it (attention.py:1304), but the MultiLatentAttention path omitted it — causing a TypeError whenever cache_mla_latents was enabled with dynamic batching. This one-line addition aligns the MLA call site with both the function signature and the existing non-MLA call site.
Minimal risk — only enables a previously broken code path.
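The failure mode described in this review can be reproduced in isolation. The sketch below is illustrative only: the function body, the q/k/v parameters, and FakeInferenceContext are hypothetical stand-ins, not the real Megatron signature or classes. It only assumes what the review states, namely that is_decode_only is a required trailing positional argument.

```python
# Hypothetical stand-in for flash_decode_and_prefill(); the real function
# takes more arguments. Only the required trailing is_decode_only argument
# is taken from the review above.
def flash_decode_and_prefill(q, k, v, is_decode_only):
    return "decode-only" if is_decode_only else "mixed"


class FakeInferenceContext:
    """Hypothetical context exposing only is_decode_only()."""

    def __init__(self, decode_only):
        self._decode_only = decode_only

    def is_decode_only(self):
        return self._decode_only


def mla_call_before_fix(ctx, q, k, v):
    # Omits the last positional argument, so Python raises TypeError
    # when this line executes.
    return flash_decode_and_prefill(q, k, v)


def mla_call_after_fix(ctx, q, k, v):
    # Matches the signature by forwarding the context's decode flag.
    return flash_decode_and_prefill(q, k, v, ctx.is_decode_only())
```

Calling mla_call_before_fix() raises a TypeError about the missing is_decode_only argument, while mla_call_after_fix() returns normally. The error surfaces only when the call runs, which is consistent with the path being broken only when cache_mla_latents is enabled with dynamic batching.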
@cuichenx Can we add a unit test for MLA that exercises this code path?
Summary
Fix the dynamic inference MLA path to pass inference_context.is_decode_only() into flash_decode_and_prefill().

The regular attention dynamic path already passes this argument, but MultiLatentAttention omitted it after the flash_decode_and_prefill() signature gained is_decode_only.

Fixes #4901.
Testing
python -m py_compile megatron/core/transformer/multi_latent_attention.py
git diff --check

Runtime validation for MLA dynamic inference with cached MLA latents is pending.
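For context on why runtime validation still matters: py_compile checks only that a file byte-compiles, so it does not detect a call that passes too few arguments. A minimal sketch of that limitation follows, using a throwaway snippet whose names (kernel, caller) are hypothetical and unrelated to the Megatron code.

```python
import os
import py_compile
import tempfile

# Throwaway source with a call that is one positional argument short.
# The names here are hypothetical and only illustrate the limitation.
SOURCE = '''
def kernel(q, k, v, is_decode_only):
    return is_decode_only

def caller():
    return kernel(1, 2, 3)
'''


def compiles_but_fails():
    """Byte-compile SOURCE, then run it and return the TypeError message."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "snippet.py")
        with open(path, "w") as handle:
            handle.write(SOURCE)
        # Succeeds: the snippet is syntactically valid, and doraise=True
        # would raise PyCompileError only on a compile-time error.
        py_compile.compile(
            path, cfile=os.path.join(tmp, "snippet.pyc"), doraise=True
        )

    namespace = {}
    exec(compile(SOURCE, "snippet.py", "exec"), namespace)
    try:
        namespace["caller"]()
    except TypeError as exc:
        # The arity mismatch appears only when the call executes.
        return str(exc)
    return None
```

Here compiles_but_fails() returns the TypeError message rather than None, showing that a clean py_compile run is compatible with a missing-argument bug at the call site. A unit test that exercises the MLA dynamic path, as requested above, would close that gap.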