Skip to content

DeepSeek V3.2 support#2440

Merged
ericharper merged 15 commits into
NVIDIA:mainfrom
kunlunl:kunlunl/deepseek_v3.2_main
Jan 16, 2026
Merged

DeepSeek V3.2 support#2440
ericharper merged 15 commits into
NVIDIA:mainfrom
kunlunl:kunlunl/deepseek_v3.2_main

Conversation

@kunlunl

@kunlunl kunlunl commented Dec 1, 2025

Copy link
Copy Markdown
Contributor

DeepSeek V3.2 Sparse Attention Support

dev PR: #2154

1. TL;DR

What: This PR adds support for DeepSeek V3.2-style sparse attention (DSA) to Megatron-LM, enabling models to use learned attention sparsity patterns via a lightweight indexer module.

Why: Dense attention has O(n²) complexity which becomes computationally prohibitive for long sequences. DSA reduces this by learning to predict which key-value pairs are most relevant for each query, allowing the model to attend to only top-k tokens instead of all tokens.

Impact: Users can now train models with DeepSeek V3.2's sparse attention mechanism in Megatron-LM, which combines Multi-Latent Attention with a trainable sparse indexer.


2. Big Picture

2.1 Before vs After Architecture

graph TB
    subgraph "Before: Standard Dense Attention"
        A1[Hidden States] --> B1[QKV Projection]
        B1 --> C1[Dense Attention]
        C1 --> D1[Output Projection]
    end
    
    subgraph "After: DeepSeek Sparse Attention"
        A2[Hidden States] --> B2[QKV Projection]
        A2 --> E2[DSA Indexer]
        B2 --> E2
        E2 --> F2[Top-K Selection]
        B2 --> G2[Sparse Attention]
        F2 --> G2
        G2 --> D2[Output Projection]
        E2 --> H2[Indexer Loss]
        H2 --> D2
    end
    
    style E2 fill:#90EE90
    style F2 fill:#90EE90
    style H2 fill:#FFD700
Loading

Key Changes:

  • NEW: DSA Indexer module that learns to predict important tokens
  • NEW: Sparse attention module that only computes attention for top-k tokens
  • NEW: KL divergence auxiliary loss to train the indexer
  • MODIFIED: Multi-Latent Attention to support DSA variant

2.2 Change Scope Summary

Category Files Description
New Core Module megatron/core/transformer/experimental_attention_variant/dsa.py DSA indexer, DSA sparse attention module, loss computation
New Spec File megatron/core/models/gpt/experimental_attention_variant_module_specs.py Module specs for attention variants
New Test tests/unit_tests/transformer/test_attention_variant_dsa.py Comprehensive DSA unit tests
Modified Core megatron/core/transformer/multi_latent_attention.py MLA integration with DSA
Modified Config megatron/core/transformer/transformer_config.py Added DSA config parameters
Modified Args megatron/training/arguments.py CLI arguments for DSA
Modified Training megatron/training/training.py Loss logging for indexer
Modified Specs megatron/core/models/gpt/gpt_layer_specs.py Renamed linear_attention → experimental_attention_variant
Modified Builder gpt_builders.py Updated to use new attention variant system

3. Key Design Points

Core Abstractions Introduced:

  1. DSAIndexer: Computes index scores to identify top-k most relevant tokens

    • Input: Hidden states x [seqlen, batch, hidden_size] + compressed query qr [seqlen, batch, q_lora_rank]
    • Output: Top-k indices [batch, seqlen, index_topk]
    • Uses its own small transformer-like architecture with Q/K projections + RoPE + Hadamard rotation
  2. DSAttention: Sparse attention mechanism using indexer outputs

    • Wraps DSAIndexer and applies sparse attention kernel
    • Attaches KL divergence loss to train indexer
  3. DSAIndexerLossAutoScaler: Custom autograd function

    • Allows indexer loss to backpropagate independently of main loss
    • Scales indexer loss gradient separately

Interface Contracts:

# DSAIndexer.forward
def forward(x, qr, mask=None, packed_seq_params=None) -> topk_indices
    """
    x: [seqlen, batch, hidden_size] - Main hidden states (DETACHED)
    qr: [seqlen, batch, q_lora_rank] - Compressed query (DETACHED)
    mask: [batch, seqlen, seqlen] - Attention mask (FP32 with -inf for masked positions)
    
    Returns: [batch, seqlen, index_topk] - Indices of top-k tokens to attend to
    """

# DSAttention.forward
def forward(query, key, value, x, qr, attention_mask, ...) -> output
    """
    query: [sq, b, np, hn] - Full query tensor from MLA
    key: [sk, b, np, hn] - Full key tensor from MLA
    value: [sk, b, np, hnv] - Full value tensor from MLA
    x: [sq, b, hidden_size] - Original hidden states for indexer
    qr: [sq, b, q_lora_rank] - Compressed query for indexer
    
    Returns: [sq, b, hidden_size] - Attention output with indexer loss attached
    """

Important Invariants:

  • Indexer inputs (x, qr) are always detached - gradients don't flow back to main model
  • Indexer loss is attached via DSAIndexerLossAutoScaler.apply() - backpropagates separately
  • Top-k selection uses masked index scores (causal mask applied before topk)
  • DSA currently requires multi_latent_attention=True and context_parallel_size=1

4. Execution Path Deep Dive

4.1 Entry Point

DSA is triggered when creating a GPT model with --experimental-attention-variant dsa flag:

# Entry: gpt_builders.py::gpt_builder()
def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None):
    # ...
    linear_attention_variants = ["gated_delta_net"]
    if args.num_experts or args.experimental_attention_variant in linear_attention_variants:
        transformer_layer_spec = get_gpt_decoder_block_spec(...)  # Uses MoE path
    elif:
        # ...
    else:
        transformer_layer_spec = _get_transformer_layer_spec(
            # ...
            experimental_attention_variant=args.experimental_attention_variant,  # 'dsa'
            # ...
        )

4.2 Data Flow

graph TD
    A["Input: hidden_states<br/>[sq, b, hidden]"] --> B["MLA Q Compression<br/>linear_q_proj→linear_q_down_proj"]
    A --> C["MLA KV Compression<br/>linear_kv_down_proj"]
    
    B --> D["q_compressed<br/>[sq, b, q_lora_rank]"]
    C --> E["kv_compressed<br/>[sq, b, kv_lora_rank]"]
    
    D --> F["MLA Q Upsampling<br/>linear_q_up_proj + RoPE"]
    E --> G["MLA KV Upsampling<br/>linear_kv_up_proj + RoPE"]
    
    F --> H["query<br/>[sq, b, np, hn]"]
    G --> I["key<br/>[sk, b, np, hn]"]
    G --> J["value<br/>[sk, b, np, hnv]"]
    
    A --> K["x.detach()"]
    D --> L["q_compressed.detach()"]
    
    K --> M["DSAIndexer"]
    L --> M
    M --> N["Indexer Q Proj<br/>linear_wq_b<br/>[sq, b, index_n_heads, index_head_dim]"]
    K --> O["Indexer K Proj<br/>linear_wk + k_norm<br/>[sk, b, index_head_dim]"]
    K --> P["Indexer Weights<br/>linear_weights_proj<br/>[sq, b, index_n_heads]"]
    
    N --> Q["Apply RoPE"]
    O --> R["Apply RoPE"]
    Q --> S["rotate_activation<br/>&#40;Hadamard transform&#41;"]
    R --> T["rotate_activation<br/>&#40;Hadamard transform&#41;"]
    
    S --> U["Index Scores<br/>q @ k^T → ReLU → weighted sum<br/>[b, sq, sk]"]
    T --> U
    P --> U
    
    U --> V["TopK Selection<br/>[b, sq, index_topk]"]
    
    H --> W["Sparse Attention"]
    I --> W
    J --> W
    V --> W
    
    W --> X["attention_output<br/>[sq, b, hidden]"]
    
    U --> Y["KL Divergence Loss<br/>KL&#40;true_attn || index_scores&#41;"]
    V --> Y
    H --> Y
    I --> Y
    
    Y --> Z["indexer_loss<br/>scalar"]
    
    X --> AA["DSAIndexerLossAutoScaler.apply"]
    Z --> AA
    AA --> AB["Final Output<br/>&#40;with loss attached&#41;"]
    
    style K fill:#FFE4B5
    style L fill:#FFE4B5
    style M fill:#90EE90
    style U fill:#87CEEB
    style V fill:#87CEEB
    style W fill:#FFD700
    style Y fill:#FF6347
    style Z fill:#FF6347
    style AA fill:#DDA0DD
Loading

5. Module Relationships

classDiagram
    class TransformerConfig {
        +int num_layers
        +int hidden_size
        +str experimental_attention_variant
    }
    
    class MLATransformerConfig {
        +int q_lora_rank
        +int kv_lora_rank
        +int dsa_indexer_n_heads
        +int dsa_indexer_head_dim
        +int dsa_indexer_topk
        +float dsa_indexer_loss_coeff
    }
    
    class Attention {
        <<abstract>>
        +forward()*
    }
    
    class MultiLatentAttention {
        +get_query_key_value_tensors()
        +forward()
    }
    
    class MLASelfAttention {
        +linear_q_proj
        +linear_kv_down_proj
        +core_attention
        +get_query_key_value_tensors(return_compressed_tensors)
    }
    
    class DSAttention {
        +indexer: DSAIndexer
        +softmax_scale: float
        +forward(query, key, value, x, qr, ...)
    }
    
    class DSAIndexer {
        +linear_wq_b
        +linear_wk
        +k_norm
        +linear_weights_proj
        +rotary_pos_emb
        +forward(x, qr, mask)
        +forward_with_scores(x, qr, mask)
        -_apply_rope()
        -_compute_index_scores()
    }
    
    class DSAIndexerSubmodules {
        +linear_wq_b: ModuleSpec
        +linear_wk: ModuleSpec
        +k_norm: ModuleSpec
        +linear_weights_proj: ModuleSpec
    }
    
    class DSAttentionSubmodules {
        +indexer: ModuleSpec
    }
    
    class MLASelfAttentionSubmodules {
        +core_attention: ModuleSpec
        +linear_q_proj
        +linear_kv_down_proj
        +q_layernorm
        +kv_layernorm
    }
    
    class RotaryEmbedding {
        +forward(seq_len)
    }
    
    class DSAIndexerLossAutoScaler {
        <<autograd.Function>>
        +forward(output, loss)$
        +backward(grad_output)$
        +set_loss_scale(scale)$
        +main_loss_backward_scale$
    }
    
    class DSAIndexerLossLoggingHelper {
        +save_loss_to_tracker()$
        +reduce_loss_in_tracker()$
        +track_indexer_metrics()$
        +tracker: dict$
    }
    
    TransformerConfig <|-- MLATransformerConfig : extends
    Attention <|-- MultiLatentAttention : extends
    MultiLatentAttention <|-- MLASelfAttention : extends
    Attention <|-- DSAttention : implements (core_attention)
    
    MLASelfAttention --> DSAttention : uses as core_attention
    MLASelfAttention --> MLASelfAttentionSubmodules : configured by
    DSAttention --> DSAIndexer : contains
    DSAttention --> DSAttentionSubmodules : configured by
    DSAIndexer --> DSAIndexerSubmodules : configured by
    DSAIndexer --> RotaryEmbedding : uses
    DSAttention ..> DSAIndexerLossAutoScaler : uses
    DSAttention ..> DSAIndexerLossLoggingHelper : logs to
    
    MLASelfAttention ..> MLATransformerConfig : reads config
    DSAttention ..> MLATransformerConfig : reads config
    DSAIndexer ..> MLATransformerConfig : reads config
Loading

Key Relationships:

  1. Composition:

    • MLASelfAttention contains DSAttention as its core_attention module
    • DSAttention contains DSAIndexer for computing sparse indices
  2. Utility Classes:

    • DSAIndexerLossAutoScaler: Custom autograd for loss attachment
    • DSAIndexerLossLoggingHelper: Singleton for collecting losses across layers

New Dependencies Introduced:

  • fast_hadamard_transform (optional): For Hadamard rotation activation
  • Fallback: Mock implementation in tests
  • Production: Uses optimized CUDA kernel

6. Examples

6.1 Configuration Parameters

CLI Arguments Example (added in arguments.py):

--experimental-attention-variant dsa          # Enable DSA (DeepSeek Sparse Attention)
--dsa-indexer-n-heads 8                       # Number of indexer heads (default: num-attention-heads)
--dsa-indexer-head-dim 64                     # Dimension per indexer head (default: kv-channels)
--dsa-indexer-topk 32                         # Top-k tokens to select per query
--dsa-indexer-loss-coeff 1.0                # Coefficient for KL divergence loss (0 = disabled)
--dsa-indexer-use-sparse-loss                 # Use sparse KL loss (only on top-k positions)

TransformerConfig Example:

config = MLATransformerConfig(
    # ... standard MLA params ...
    experimental_attention_variant='dsa',      # 'dsa' | 'gated_delta_net' | None
    dsa_indexer_n_heads=8,                    # Must divide by TP size
    dsa_indexer_head_dim=64,                  # Typically same as kv_channels
    dsa_indexer_topk=32,                      # k in O(n·k) complexity
    dsa_indexer_loss_coeff=1.0,             # Typical range: 0.0001 - 0.01
    dsa_indexer_use_sparse_loss=False,        # True = sparse, False = dense KL loss
)

6.2 Example Usage

Training a GPT model with DSA:

python pretrain_gpt.py \
    --num-layers 32 \
    --hidden-size 4096 \
    --num-attention-heads 32 \
    --seq-length 8192 \
    \
    # Enable Multi-Latent Attention (required for DSA)
    --multi-latent-attention \
    --q-lora-rank 512 \
    --kv-lora-rank 512 \
    --qk-head-dim 128 \
    --qk-pos-emb-head-dim 64 \
    --v-head-dim 128 \
    \
    # Enable DeepSeek Sparse Attention
    --experimental-attention-variant dsa \
    --dsa-indexer-n-heads 16 \
    --dsa-indexer-head-dim 128 \
    --dsa-indexer-topk 256 \
    --dsa-indexer-loss-coeff 0.001 \
    \
    # Standard training args
    --micro-batch-size 1 \
    --global-batch-size 512 \
    --lr 1.0e-4 \
    --train-iters 100000 \
    --lr-decay-iters 100000 \
    --lr-decay-style cosine \
    --min-lr 1.0e-5 \
    --weight-decay 0.1 \
    --clip-grad 1.0 \
    --bf16

Expected Behavior:

  • Each layer will use sparse attention with top-256 tokens (instead of full 8192)
  • Indexer loss will be logged to TensorBoard as indexer loss

Further Reading

Signed-off-by: kunlunl <kunlunl@nvidia.com>
@kunlunl kunlunl requested review from a team as code owners December 1, 2025 09:21
@copy-pr-bot

copy-pr-bot Bot commented Dec 1, 2025

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@kunlunl kunlunl changed the title Add support for DSA DeepSeek V3.2 support Dec 1, 2025
@fzyzcjy

fzyzcjy commented Dec 2, 2025

Copy link
Copy Markdown

Hi, may I know the estimated time for this to be merged? Thanks!

@yanring yanring added module: moe Expert Review [deprecated] Apply this label to indicate that your PR is ready for expert review. labels Dec 3, 2025
####################
# attention variant
####################
experimental_attention_variant: Optional[str] = None

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, thanks for the PR! I have a suggestion regarding the DSA-related arguments. Could we consider adding them directly to the MLATransformerConfigclass? This would provide a more integrated and centralized configuration approach.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure whether all future attention variants will be built on top of MLA...

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree this should be in MLATransformerConfig. The experimental_attention_variant thing should not be used going forward so we shouldn't need to worry about any that aren't built on top of MLA.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is another linear attention variant #1989 that has not been merged, which doesn't use MLA, we though it's better to not modify args each time we want to add or remove an attention variant...

@Wineheart-Taro

Copy link
Copy Markdown

Hi, I have a question regarding the order of the RoPE-applied dimensions for Q and K in the DSA implementation. I’ve described the details in this issue. Could you please help clarify this? Thank you! #2733

@snowmanwwg snowmanwwg added the dev2main: mbridge dev to main: this PR is needed in main for mbridge label Jan 6, 2026
@Phlip79 Phlip79 requested a review from deepakn94 January 9, 2026 19:23
@Phlip79

Phlip79 commented Jan 9, 2026

Copy link
Copy Markdown
Member

/ok to test 1a1522e

@Phlip79

Phlip79 commented Jan 9, 2026

Copy link
Copy Markdown
Member

/ok to test 31f4120

@Phlip79

Phlip79 commented Jan 15, 2026

Copy link
Copy Markdown
Member

/ok to test b8f9d0a

@Phlip79

Phlip79 commented Jan 15, 2026

Copy link
Copy Markdown
Member

/ok to test 15f1fac

@kunlunl

kunlunl commented Jan 16, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 271f29b

@kunlunl

kunlunl commented Jan 16, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 30b970e

@kunlunl

kunlunl commented Jan 16, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 6f0f976

@Meta-YZ

Meta-YZ commented Jan 28, 2026

Copy link
Copy Markdown

Hello, it's an excellent job. Could you please tell me if there are any one-click run scripts and/or recommend the configuration of the number of Gpus and parallel configuration

@kunlunl @Phlip79

@Phlip79

Phlip79 commented Jan 28, 2026

Copy link
Copy Markdown
Member

Hello, it's an excellent job. Could you please tell me if there are any one-click run scripts and/or recommend the configuration of the number of Gpus and parallel configuration

@kunlunl @Phlip79

I would recommend using Megatron-Bridge, a Megatron-Core "extension". Check out the list of supported models and recipes. Unfortunately, v3.2 is not yet supported. We have a PR but it's currently blocked. The v3 recipe might be helpful.

@xhjhggybz

Copy link
Copy Markdown

Hi, thank you very much for your great work.
However, I’ve encountered an issue where the current DSA implementation consumes a large amount of GPU memory when the sequence length is long.
I’m wondering whether there are any planned or ongoing efforts to optimize the memory usage for long-sequence scenarios.

Thanks a lot in advance for your help.

@kunlunl

kunlunl commented Jan 29, 2026

Copy link
Copy Markdown
Contributor Author

Hi, thank you very much for your great work. However, I’ve encountered an issue where the current DSA implementation consumes a large amount of GPU memory when the sequence length is long. I’m wondering whether there are any planned or ongoing efforts to optimize the memory usage for long-sequence scenarios.

Thanks a lot in advance for your help.

Yes. The large memory footprint comes from the unfused DSA and indexer, which generate many seq^2 tensors. We have ongoing PRs to integrate fused kernels to replace the unfused pytorch implementation, but it's still WIP and the fused kernel can only run in specific shape.
Here are three PRs:

@xhjhggybz

Copy link
Copy Markdown

Hi, thank you very much for your great work. However, I’ve encountered an issue where the current DSA implementation consumes a large amount of GPU memory when the sequence length is long. I’m wondering whether there are any planned or ongoing efforts to optimize the memory usage for long-sequence scenarios.
Thanks a lot in advance for your help.

Yes. The large memory footprint comes from the unfused DSA and indexer, which generate many seq^2 tensors. We have ongoing PRs to integrate fused kernels to replace the unfused pytorch implementation, but it's still WIP and the fused kernel can only run in specific shape. Here are three PRs:

Thanks a lot for the detailed explanation and for sharing the related PRs — this is very helpful!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

complexity: medium dev2main: mbridge dev to main: this PR is needed in main for mbridge Expert Review [deprecated] Apply this label to indicate that your PR is ready for expert review. Final Review PR is in the "final review" stage module: moe

Projects

None yet

Development

Successfully merging this pull request may close these issues.