Skip to content

Add absorbed-mla#3193

Merged
hxbai merged 5 commits into
NVIDIA:devfrom
kunlunl:absorbed_mla
Feb 13, 2026
Merged

Add absorbed-mla#3193
hxbai merged 5 commits into
NVIDIA:devfrom
kunlunl:absorbed_mla

Conversation

@kunlunl

@kunlunl kunlunl commented Feb 2, 2026

Copy link
Copy Markdown
Contributor

What does this PR do ?

PR for main #3198

Implement MLA with matrix absorption:

  • Absorbs K's up projection into Q: Q' = Q @ K_up_proj^T
  • Applies V's up projection after core attention
  • Core attention operates in MQA form with KV being single-head.

The absorption is mathematically equivalent to standard MLA but enables MQA-style attention which
can be more efficient for certain attention variants.


1. TL;DR

  • What: This PR adds a new AbsorbedMLASelfAttention class that implements MLA with matrix absorption optimization, where K's up-projection is absorbed into Q before core attention, enabling MQA-style computation.
  • Why: Matrix absorption allows MLA to operate in a more memory-efficient MQA form where KV becomes single-head, enabling more efficient implementation for certain attention variants.

2. Big Picture

graph TB
    subgraph "Before: Standard MLA"
        A1[hidden_states] --> B1[Q down proj]
        A1 --> C1[KV down proj]
        B1 --> D1[Q up proj]
        C1 --> E1[KV up proj<br/>joint K+V]
        D1 --> F1[RoPE on Q]
        E1 --> G1[RoPE on K]
        E1 --> H1[V extraction]
        F1 --> I1[Core Attention<br/>MHA: n heads for Q,K,V]
        G1 --> I1
        H1 --> I1
        I1 --> J1[linear_proj]
    end
    
    subgraph "After: Absorbed MLA"
        A2[hidden_states] --> B2[Q down proj]
        A2 --> C2[KV down proj]
        B2 --> D2[Q up proj]
        C2 --> E2[K up weight<br/>NOT applied to KV]
        D2 --> F2["Q absorbed = Q @ K_up^T"]
        E2 --> F2
        F2 --> G2[RoPE on Q_absorbed]
        C2 --> H2[Compressed KV<br/>single head]
        H2 --> I2[RoPE on K_pos]
        G2 --> J2["Core Attention<br/>MQA: n heads Q, 1 head KV"]
        I2 --> J2
        J2 --> K2["V up proj AFTER attention"]
        K2 --> L2[linear_proj]
    end
    
    style F2 fill:#90EE90
    style K2 fill:#90EE90
    style J2 fill:#FFD700
Loading

3. Design Rationale

3.1 Problem Background

Standard MLA (Multi-Latent Attention) as implemented in MLASelfAttention works as follows:

# Standard MLA flow in MLASelfAttention.get_query_key_value_tensors()
kv, _ = self.linear_kv_up_proj(kv_compressed)  # [s, b, n*(qk_head_dim + v_head_dim)]
kv = kv.view(..., self.num_attention_heads_per_partition, qk_head_dim + v_head_dim)
k_no_pe, value = torch.split(kv, [qk_head_dim, v_head_dim], dim=-1)
# Core attention receives: Q[s,b,n,d], K[s,b,n,d], V[s,b,n,d] - all multi-head

Limitation: The KV up-projection expands to all n attention heads before core attention, which cannot leverage specialized MQA kernels that are optimized for single-head KV.

3.2 Solution Approach

Matrix Absorption exploits the mathematical equivalence:

Standard: Attention(Q, K @ K_up_proj, V @ V_up_proj)
Absorbed: Attention(Q @ K_up_proj^T, compressed_K, compressed_V) @ V_up_proj

Key insight: Instead of projecting K and V up to multi-head before attention, we can:

  1. Absorb K_up_proj into Q: Q' = Q @ K_up_proj^T (still multi-head)
  2. Keep KV compressed during attention (single-head)
  3. Apply V_up_proj after attention

3.3 Key Design Points

New Classes and Responsibilities:

  1. AbsorbedMLASelfAttentionSubmodules (dataclass):

    • Same as MLASelfAttentionSubmodules but with separate linear_k_up_proj and linear_v_up_proj instead of fused linear_kv_up_proj
  2. AbsorbedMLASelfAttention (class):

    • Extends base Attention class
    • Key differences from MLASelfAttention:
      • Stores K and V up-projection weights separately
      • Implements absorption in get_query_key_value_tensors()
      • Applies V up-projection after core attention in forward()

Interface Contracts:

# Core attention in AbsorbedMLA receives different shapes:
core_attention(
    q_absorbed,    # [s, b, n, kv_lora_rank + qk_pos_emb_head_dim] - multi-head
    kv_compressed, # [s, b, 1, kv_lora_rank + qk_pos_emb_head_dim] - SINGLE head
    v=None,        # V is not passed; absorbed impl applies V_up after attention
    ...
)

4. Execution Path Deep Dive

4.1 Entry Point

The absorbed MLA is triggered when a transformer layer is configured with AbsorbedMLASelfAttention:

# In model spec configuration (not shown in this PR, but expected usage):
layer_spec = TransformerLayerSpec(
    self_attention=ModuleSpec(
        module=AbsorbedMLASelfAttention,
        submodules=AbsorbedMLASelfAttentionSubmodules(
            linear_k_up_proj=...,
            linear_v_up_proj=...,
            # ... other submodules
        )
    )
)

4.2 Call Chain Visualization

sequenceDiagram
    participant TL as TransformerLayer
    participant ABS as AbsorbedMLASelfAttention
    participant QKV as get_query_key_value_tensors()
    participant CA as core_attention
    participant VP as V up-projection
    participant LP as linear_proj
    
    TL->>ABS: forward(hidden_states, attention_mask)
    ABS->>QKV: get_query_key_value_tensors(hidden_states)
    
    Note over QKV: Q down proj → layernorm → Q up proj
    Note over QKV: KV down proj → layernorm
    Note over QKV: K_up_weight absorbed into Q
    Note over QKV: Apply RoPE
    
    QKV-->>ABS: q_absorbed, kv_compressed, q_compressed
    
    ABS->>CA: core_attention(q_absorbed, kv_compressed, v=None)
    Note over CA: MQA-style: Q multi-head, KV single-head
    CA-->>ABS: attn_out [s, b, n * kv_lora_rank]
    
    ABS->>VP: einsum("...nc,ndc->...nd", attn_out, v_up_weight)
    Note over VP: Project to full value dimension
    VP-->>ABS: attn_out [s, b, n * v_head_dim]
    
    ABS->>LP: linear_proj(attn_out)
    LP-->>ABS: output [s, b, hidden_size]
    
    ABS-->>TL: output, bias
Loading

4.3 Data Flow

graph TD
    A["hidden_states<br/>[s, b, hidden_size]"] -->|Q down proj| B["q_compressed<br/>[s, b, q_lora_rank]"]
    A -->|KV down proj| C["kv_combined<br/>[s, b, kv_lora_rank + qk_pos_emb]"]
    
    B -->|layernorm| D["q_compressed_norm<br/>[s, b, q_lora_rank]"]
    C -->|split| E["kv_compressed<br/>[s, b, kv_lora_rank]"]
    C -->|split| F["k_pos_emb<br/>[s, b, qk_pos_emb]"]
    
    E -->|layernorm| G["kv_compressed_norm<br/>[s, b, kv_lora_rank]"]
    
    D -->|Q up proj| H["q<br/>[s, b, n, qk_head_dim + qk_pos_emb]"]
    
    H -->|split| I["q_no_pe<br/>[s, b, n, qk_head_dim]"]
    H -->|split| J["q_pos_emb<br/>[s, b, n, qk_pos_emb]"]
    
    I -->|"einsum(q, k_up_weight)"| K["q_absorbed<br/>[s, b, n, kv_lora_rank]"]
    
    J -->|apply RoPE| L["q_pos_emb_rope<br/>[s, b, n, qk_pos_emb]"]
    F -->|apply RoPE| M["k_pos_emb_rope<br/>[s, b, 1, qk_pos_emb]"]
    
    K -->|concat| N["q_final<br/>[s, b, n, kv_lora_rank + qk_pos_emb]"]
    L --> N
    
    G -->|concat| O["kv_final<br/>[s, b, 1, kv_lora_rank + qk_pos_emb]"]
    M --> O
    
    N -->|core_attention| P["attn_out<br/>[s, b, n, kv_lora_rank]"]
    O --> P
    
    P -->|"einsum(out, v_up_weight)"| Q["projected<br/>[s, b, n, v_head_dim]"]
    
    Q -->|reshape + linear_proj| R["output<br/>[s, b, hidden_size]"]
    
    style K fill:#90EE90
    style Q fill:#90EE90
Loading

4.4 Core Code Walkthrough

4.4.1 K Absorption into Q

def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb):
    # Step 1: Apply Q up projection
    # Why: Transform compressed Q to full attention head dimension
    if self.config.q_lora_rank is not None:
        q, _ = self.linear_q_up_proj(q_compressed)  # [num_tokens, n * q_head_dim]
    else:
        q, _ = self.linear_q_proj(q_compressed)
    
    # Step 2: Reshape Q for per-head processing
    # q: [num_tokens, n, qk_head_dim + qk_pos_emb_head_dim]
    q = q.view(*q.size()[:-1], self.num_attention_heads_per_partition, self.q_head_dim)
    
    # Step 3: Prepare K up-projection weight for absorption
    # Key insight: We don't apply this to KV, but absorb into Q instead
    # k_up_weight shape: [n, qk_head_dim, kv_lora_rank]
    k_up_weight = self.linear_k_up_proj.weight.view(
        self.num_attention_heads_per_partition,
        self.config.qk_head_dim,
        self.config.kv_lora_rank,
    )
    
    # Step 4: Split Q into non-positional and positional parts
    q_no_pe, q_pos_emb = torch.split(
        q, [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], dim=-1
    )
    
    # Step 5: THE ABSORPTION - multiply Q by K's up-projection weight
    # Mathematical equivalence: Q @ K = (Q @ K_up^T) @ K_compressed
    # q_absorbed: [num_tokens, n, kv_lora_rank]
    q_absorbed = torch.einsum("...nd,ndk->...nk", q_no_pe, k_up_weight)
    q_absorbed = q_absorbed.contiguous()
    
    # Step 6: Apply RoPE to positional components
    q_pos_emb = apply_rotary_pos_emb(q_pos_emb, rotary_pos_emb, ...)
    k_pos_emb = apply_rotary_pos_emb(k_pos_emb, rotary_pos_emb, ...)
    
    # Step 7: Combine absorbed Q with positional embeddings
    # Final Q: [num_tokens, n, kv_lora_rank + qk_pos_emb_head_dim]
    q_absorbed = torch.cat([q_absorbed, q_pos_emb], dim=-1)
    # KV remains single-head: [num_tokens, 1, kv_lora_rank + qk_pos_emb_head_dim]
    kv_compressed = torch.cat([kv_compressed, k_pos_emb], dim=-1)
    
    return q_absorbed, kv_compressed

4.4.2 V Up-Projection After Attention

def forward(self, hidden_states, attention_mask, ...):
    # ... QKV computation ...
    
    # Core attention with MQA-style inputs
    core_attn_out = self.core_attention(
        q_absorbed,      # [s, b, n, kv_lora_rank + qk_pos_emb]
        kv_compressed,   # [s, b, 1, kv_lora_rank + qk_pos_emb] - SINGLE HEAD
        v=None,          # V is not provided; we apply V_up after attention
        ...
    )
    
    # ==================================
    # Apply V up projection AFTER core attention
    # ==================================
    # v_up_weight shape: [n, v_head_dim, kv_lora_rank]
    v_up_weight = self.linear_v_up_proj.weight.view(
        self.num_attention_heads_per_partition, 
        self.config.v_head_dim, 
        self.config.kv_lora_rank
    )
    
    # Reshape attention output for einsum
    # core_attn_out: [s, b, n * kv_lora_rank] -> [s, b, n, kv_lora_rank]
    core_attn_out = core_attn_out.view(
        *core_attn_out.shape[:-1],
        self.num_attention_heads_per_partition,
        self.config.kv_lora_rank,
    )
    
    # Apply V up-projection: [s, b, n, kv_lora_rank] @ [n, v_head_dim, kv_lora_rank]^T
    # Result: [s, b, n, v_head_dim]
    core_attn_out = torch.einsum("...nc,ndc->...nd", core_attn_out, v_up_weight)
    core_attn_out = core_attn_out.contiguous()
    
    # Flatten for linear_proj: [s, b, n * v_head_dim]
    core_attn_out = core_attn_out.view(*core_attn_out.shape[:-2], -1)
    
    # Output projection
    output, bias = self.linear_proj(core_attn_out)
    return output, bias

4.4.3 Checkpoint Loading with Weight Conversion

def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
    """Handle loading from checkpoints with combined KV up projection weights.
    
    Standard MLA stores: linear_kv_up_proj.weight with interleaved K,V per head
        Layout: [head0_K, head0_V, head1_K, head1_V, ...]
    
    AbsorbedMLA needs: separate linear_k_up_proj and linear_v_up_proj
        K: [head0_K, head1_K, ...]
        V: [head0_V, head1_V, ...]
    """
    combined_key = f'{prefix}linear_kv_up_proj.weight'
    k_up_key = f'{prefix}linear_k_up_proj.weight'
    v_up_key = f'{prefix}linear_v_up_proj.weight'
    
    if combined_key in state_dict:
        combined_weight = state_dict[combined_key]
        
        # Split with proper per-head de-interleaving
        k_weight, v_weight = self._split_kv_weights(combined_weight)
        
        state_dict[k_up_key] = k_weight
        state_dict[v_up_key] = v_weight
        del state_dict[combined_key]
    
    super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

5. Risks & Edge Cases

Risk Category Specific Concern Mitigation/Status
Correctness Mathematical equivalence to standard MLA ✅ Tested via cosine similarity > 0.9999 in test_absorbed_mla.py
Correctness Gradient equivalence during training ✅ Tested gradient comparison in unit tests
Correctness TP+CP combination correctness ✅ Tested with tp_cp = [[1,1], [2,1], [1,2], [2,2]]
Correctness Sequence packing (THD format) ✅ Tested with qkv_format = ['sbhd', 'thd']
Compatibility Checkpoint loading from standard MLA ✅ Handled via _load_from_state_dict() weight conversion
Performance Extra einsum operations ⚠️ Needs profiling; may be offset by MQA kernel benefits
Feature Gap Inference not supported assert inference_context is None
Feature Gap FP8/FP4 not supported assert not quantization
Feature Gap cache_mla_latents not supported assert not self.cache_mla_latents
Feature Gap QK clipping not implemented raise NotImplementedError("clip_qk")
Edge Case column_parallel vs duplicated down_proj ✅ Both tested via down_proj_use_column_parallel param

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]
Loading

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge-conflict are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch The proposed review process for `dev` branch is under active discussion.

MRs are mergable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@kunlunl kunlunl requested review from a team as code owners February 2, 2026 07:38
@copy-pr-bot

copy-pr-bot Bot commented Feb 2, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@kunlunl kunlunl mentioned this pull request Feb 2, 2026
6 tasks
@kunlunl

kunlunl commented Feb 2, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 1374fb9

@ko3n1g ko3n1g added this to the Core 0.16 milestone Feb 2, 2026
@dimapihtar dimapihtar added complexity: high dev branch Dev branch related issues and development labels Feb 2, 2026
@kunlunl kunlunl mentioned this pull request Feb 3, 2026
6 tasks
@kunlunl

kunlunl commented Feb 3, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 2670a9b

@kunlunl

kunlunl commented Feb 3, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 3f344ac

@kunlunl

kunlunl commented Feb 5, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 472808d

layernorm_zero_centered_gamma=False,
expert_model_parallel_size=1,
tensor_model_parallel_size=tensor_model_parallel_size,
sequence_parallel=tensor_model_parallel_size > 1,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does absorbed MLA support TP > 1 without SP for now? If yes, please add at least one test case to cover it. If no, please add an assertion in TransformerConfig or AbsorbedMLA.__init__.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because this PR only adds the AbsorbedMLA but doesn't change the spec, so there is no config to enable this AbsorbedMLA directly. I added assertion in the AbsorbedMLA class instead.

if self.recompute_up_proj:
quantization = self.config.fp8 or self.config.fp4
assert not quantization, "FP8/FP4 is not supported for AbsorbedMLA"
self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput(fp8=quantization)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing

from megatron.core import tensor_parallel

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe need a UT to cover recompute. I think it is not emergent, so you can mark as a TODO.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added missing import, also added test case to guard it.


def get_absorbed_mla_submodules(
down_proj_use_column_parallel: bool, qk_layernorm: bool, rms_norm: bool
) -> AbsorbedMLASelfAttentionSubmodules:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about using the layer spec function in core.model.gpt_layer_spec to simplify the code and cover the real spec function in the UT? As it has not been supported yet, we can mark it as a TODO here.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR only adds the AbsorbedMLA but doesn't change the spec, because the current unfused dsa doesn't support absorption yet. Will introduce new dsa in the next PR, in that PR I'll do what you said.

Added TODO comments.

@yuzhongw-nvidia yuzhongw-nvidia left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your great work! Overall LGTM. I've left a few minor comments.

kv_layernorm: Union[ModuleSpec, type] = None


class AbsorbedMLASelfAttention(Attention):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there are some duplicated code with MLASelfAttention. Was this done on purpose because it's an experimental variant? Does it make sense to subclass MLASelfAttention?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's done on purpose.

We will change/optimize this class a lot in the near future, and want to iterate faster without adding risks to standard MLA which is already used in some production scenario.

We plan to merge it with the standard MLA or make it a subclass of MLA when this feature is stable.

assert (
packed_seq_params.local_cp_size is None
), "dynamic context parallel is not supported with MLA yet and is planned for future. \
Please disable dynamic context parallel."

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MLASelfAttention has this as "hybrid context parallel"
Maybe we need to cherrypick the recent changes in multi_latent_attention.py?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CP is not supported yet, and it's not easy because we need some new CP solution for the attention variant like dsa. We will add this hybrid context parallel feature when we add CP to it.

@kunlunl

kunlunl commented Feb 12, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 1339a7b

@hxbai hxbai added this pull request to the merge queue Feb 13, 2026
@ko3n1g

ko3n1g commented Feb 13, 2026

Copy link
Copy Markdown
Contributor

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/21981893832

@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks Feb 13, 2026
@hxbai hxbai added this pull request to the merge queue Feb 13, 2026
@ko3n1g

ko3n1g commented Feb 13, 2026

Copy link
Copy Markdown
Contributor

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/21990716575

Merged via the queue into NVIDIA:dev with commit 6059f36 Feb 13, 2026
74 of 78 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

complexity: high dev branch Dev branch related issues and development

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants