
[Tracking] Dynamic inference: MLA flash_decode_and_prefill call missing is_decode_only #4901

@cuichenx


This is a tracking-only issue for #4697.

Summary

Dynamic inference for MLA-based models that use cached MLA latents can fail at runtime in MultiLatentAttention before token generation completes.

This appears to be an interface mismatch inside MCore. Attention.flash_decode_and_prefill() requires an is_decode_only argument. The regular attention path for dynamic inference passes it, but the MLA path for dynamic inference does not.
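The failure mode can be reproduced without MCore. The Attention class below is a hypothetical stub: it only mirrors the argument order used at the two call sites quoted in this issue (the real signature is elided), and it echoes the flag instead of running any kernel. Calling it the way the MLA path does, with nine arguments, raises the same kind of TypeError.

```python
class Attention:
    """Hypothetical stand-in for MCore's Attention; not the real implementation."""

    def flash_decode_and_prefill(
        self, q, k, v, max_seqlen_q, max_seqlen_k,
        cu_query_lengths, cu_kv_lengths, kv_lengths, block_table, is_decode_only,
    ):
        # The real method runs attention; this stub only echoes the flag it received.
        return is_decode_only


def call_like_mla_path(attn):
    # Mirrors the MLA call site: nine arguments, is_decode_only omitted.
    return attn.flash_decode_and_prefill(None, None, None, 1, 1, None, None, None, None)


error_message = None
try:
    call_like_mla_path(Attention())
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```

On Python 3.10 and later the printed message includes the class-qualified method name, matching the observed failure text.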

Affected models and code path

  • Models using Multi-Latent Attention (MLA)
  • Dynamic batching inference
  • cache_mla_latents=True
  • 64-token inference blocks, as required by the dynamic MLA latent-cache path
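For triage, the list above can be read as a conjunction: a run is on the affected path only when all four conditions hold together. The sketch below encodes that reading. InferenceSetup, its field names, and hits_affected_mla_path are hypothetical illustrations, not MCore configuration objects.

```python
from dataclasses import dataclass


@dataclass
class InferenceSetup:
    """Hypothetical summary of a run's settings; field names are illustrative only."""
    uses_mla: bool
    dynamic_batching: bool
    cache_mla_latents: bool
    block_size_tokens: int


def hits_affected_mla_path(setup: InferenceSetup) -> bool:
    # All four conditions from the list must hold at the same time.
    return (
        setup.uses_mla
        and setup.dynamic_batching
        and setup.cache_mla_latents
        and setup.block_size_tokens == 64
    )


affected = InferenceSetup(uses_mla=True, dynamic_batching=True,
                          cache_mla_latents=True, block_size_tokens=64)
no_latent_cache = InferenceSetup(uses_mla=True, dynamic_batching=True,
                                 cache_mla_latents=False, block_size_tokens=256)
print(hits_affected_mla_path(affected), hits_affected_mla_path(no_latent_cache))
```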

Observed failure

TypeError: Attention.flash_decode_and_prefill() missing 1 required positional argument: 'is_decode_only'

The failure occurs during dynamic generation after model construction and checkpoint import/load have succeeded.

Code pointers

At the MCore commit observed locally:

a3aade9d2fb6ced6e49414e24d892ea680035cb7

Attention.flash_decode_and_prefill() requires is_decode_only:

def flash_decode_and_prefill(..., block_table, is_decode_only):

The regular attention dynamic path passes the argument:

core_attn_out = self.flash_decode_and_prefill(
    q,
    k,
    v,
    max_seqlen_q,
    max_seqlen_k,
    cu_query_lengths,
    cu_kv_lengths,
    kv_lengths,
    block_table,
    inference_context.is_decode_only(),
)

The MLA dynamic path currently omits the final argument:

core_attn_out = self.flash_decode_and_prefill(
    q,
    k,
    v,
    max_seqlen_q,
    max_seqlen_k,
    cu_query_lengths,
    cu_kv_lengths,
    kv_lengths,
    block_table,
)

Suggested fix

Update the MultiLatentAttention dynamic batching path to pass inference_context.is_decode_only() into flash_decode_and_prefill():

core_attn_out = self.flash_decode_and_prefill(
    q,
    k,
    v,
    max_seqlen_q,
    max_seqlen_k,
    cu_query_lengths,
    cu_kv_lengths,
    kv_lengths,
    block_table,
    inference_context.is_decode_only(),
)
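The effect of the suggested change can be sketched with stubs. Attention, FakeInferenceContext, and the dynamic_attention method are all hypothetical stand-ins (the real MLA method name and body are not shown in this issue); the stub kernel echoes the flag so the sketch can check that prefill and decode steps each forward the value reported by the inference context.

```python
class Attention:
    """Hypothetical stand-in for MCore's Attention; the real signature is elided in this issue."""

    def flash_decode_and_prefill(
        self, q, k, v, max_seqlen_q, max_seqlen_k,
        cu_query_lengths, cu_kv_lengths, kv_lengths, block_table, is_decode_only,
    ):
        # Echo the flag so the caller can check it was forwarded.
        return is_decode_only


class FakeInferenceContext:
    """Hypothetical context exposing only is_decode_only(), as used at the call sites."""

    def __init__(self, decode_only):
        self._decode_only = decode_only

    def is_decode_only(self):
        return self._decode_only


class MultiLatentAttention(Attention):
    def dynamic_attention(self, inference_context, q, k, v, max_seqlen_q, max_seqlen_k,
                          cu_query_lengths, cu_kv_lengths, kv_lengths, block_table):
        # Fixed call: forward inference_context.is_decode_only() as the last argument,
        # matching the regular attention dynamic path.
        return self.flash_decode_and_prefill(
            q, k, v, max_seqlen_q, max_seqlen_k,
            cu_query_lengths, cu_kv_lengths, kv_lengths, block_table,
            inference_context.is_decode_only(),
        )


mla = MultiLatentAttention()
args = (None, None, None, 8, 8, None, None, None, None)
prefill_flag = mla.dynamic_attention(FakeInferenceContext(False), *args)
decode_flag = mla.dynamic_attention(FakeInferenceContext(True), *args)
print(prefill_flag, decode_flag)  # False True
```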

Validation request

Please validate dynamic inference for an MLA model with cache_mla_latents=True, covering both prefill and decode. The expected result is that generation produces tokens without hitting the flash_decode_and_prefill() signature error.
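As a cheap complement to the end-to-end run, an arity mismatch like this one can be caught without a GPU by binding arguments against the method signature. The sketch below uses a hypothetical stub in place of the real Attention class; with MCore installed, the same check could target the real method instead. Counting positional arguments is a simplification: it confirms arity, not argument names or order.

```python
import inspect


class Attention:
    """Hypothetical stand-in; not MCore's implementation."""

    def flash_decode_and_prefill(
        self, q, k, v, max_seqlen_q, max_seqlen_k,
        cu_query_lengths, cu_kv_lengths, kv_lengths, block_table, is_decode_only,
    ):
        return is_decode_only


def call_site_binds(func, n_positional):
    """Return True if calling func with n_positional positional args would bind."""
    try:
        # Signature.bind raises TypeError on missing or extra arguments
        # without actually calling func.
        inspect.signature(func).bind(*([None] * n_positional))
        return True
    except TypeError:
        return False


target = Attention().flash_decode_and_prefill  # bound method, so self is excluded
old_mla_ok = call_site_binds(target, 9)     # MLA call before the fix
fixed_mla_ok = call_site_binds(target, 10)  # after adding is_decode_only
print(old_mla_ok, fixed_mla_ok)  # False True
```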
