DeepSeek V3.2 support #2440
Conversation
Signed-off-by: kunlunl <kunlunl@nvidia.com>
Hi, may I know the estimated time for this to be merged? Thanks!
####################
# attention variant
####################
experimental_attention_variant: Optional[str] = None
Hi, thanks for the PR! I have a suggestion regarding the DSA-related arguments. Could we consider adding them directly to the MLATransformerConfig class? This would provide a more integrated and centralized configuration approach.
Not sure whether all future attention variants will be built on top of MLA...
I agree this should be in MLATransformerConfig. The experimental_attention_variant thing should not be used going forward so we shouldn't need to worry about any that aren't built on top of MLA.
There is another linear attention variant, #1989, that has not been merged and doesn't use MLA. We thought it better not to modify the args each time we want to add or remove an attention variant...
Hi, I have a question regarding the order of the RoPE-applied dimensions for Q and K in the DSA implementation. I've described the details in this issue. Could you please help clarify this? Thank you! #2733
I would recommend using Megatron-Bridge, a Megatron-Core "extension". Check out the list of supported models and recipes. Unfortunately, v3.2 is not yet supported. We have a PR but it's currently blocked. The v3 recipe might be helpful.
Hi, thank you very much for your great work. Thanks a lot in advance for your help.
Yes. The large memory footprint comes from the unfused DSA and indexer, which generate many seq^2 tensors. We have ongoing PRs to integrate fused kernels that replace the unfused PyTorch implementation, but they are still WIP and the fused kernels currently run only for specific shapes.
Thanks a lot for the detailed explanation and for sharing the related PRs; this is very helpful!
DeepSeek V3.2 Sparse Attention Support
dev PR: #2154
1. TL;DR
What: This PR adds support for DeepSeek V3.2-style sparse attention (DSA) to Megatron-LM, enabling models to use learned attention sparsity patterns via a lightweight indexer module.
Why: Dense attention has O(n²) complexity in the sequence length n, which becomes computationally prohibitive for long sequences. DSA reduces this cost by learning to predict which key-value pairs are most relevant for each query, so the model attends only to the top-k tokens instead of all tokens.
Impact: Users can now train models with DeepSeek V3.2's sparse attention mechanism in Megatron-LM, which combines Multi-Latent Attention with a trainable sparse indexer.
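As a rough illustration of the saving described above, the sketch below counts the query-key pairs the main attention visits under a causal mask, with and without top-k selection. The function name is hypothetical, the sequence length and top-k values are borrowed from the example command in section 6.2, and the count ignores the indexer, which still scores every visible key for each query.

```python
def attended_pairs(seq_len: int, topk: int) -> tuple[int, int]:
    """Return (dense, sparse) counts of query-key pairs under a causal mask."""
    # Dense causal attention: query t sees keys 0..t.
    dense = seq_len * (seq_len + 1) // 2
    # Top-k attention: query t sees at most `topk` of its t + 1 visible keys.
    sparse = sum(min(t + 1, topk) for t in range(seq_len))
    return dense, sparse


dense, sparse = attended_pairs(seq_len=8192, topk=256)
print(f"dense={dense} sparse={sparse} ratio={dense / sparse:.1f}")
```

With these values the main attention touches roughly 16 times fewer pairs; the ratio grows with sequence length because the sparse count is linear in it.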
2. Big Picture
2.1 Before vs After Architecture
graph TB
    subgraph "Before: Standard Dense Attention"
        A1[Hidden States] --> B1[QKV Projection]
        B1 --> C1[Dense Attention]
        C1 --> D1[Output Projection]
    end
    subgraph "After: DeepSeek Sparse Attention"
        A2[Hidden States] --> B2[QKV Projection]
        A2 --> E2[DSA Indexer]
        B2 --> E2
        E2 --> F2[Top-K Selection]
        B2 --> G2[Sparse Attention]
        F2 --> G2
        G2 --> D2[Output Projection]
        E2 --> H2[Indexer Loss]
        H2 --> D2
    end
    style E2 fill:#90EE90
    style F2 fill:#90EE90
    style H2 fill:#FFD700
Key Changes:
2.2 Change Scope Summary
megatron/core/transformer/experimental_attention_variant/dsa.py
megatron/core/models/gpt/experimental_attention_variant_module_specs.py
tests/unit_tests/transformer/test_attention_variant_dsa.py
megatron/core/transformer/multi_latent_attention.py
megatron/core/transformer/transformer_config.py
megatron/training/arguments.py
megatron/training/training.py
megatron/core/models/gpt/gpt_layer_specs.py
gpt_builders.py
3. Key Design Points
Core Abstractions Introduced:
DSAIndexer: Computes index scores to identify the top-k most relevant tokens
  Inputs: x [seqlen, batch, hidden_size] + compressed query qr [seqlen, batch, q_lora_rank]
DSAttention: Sparse attention mechanism using indexer outputs
  Contains DSAIndexer and applies sparse attention kernel
DSAIndexerLossAutoScaler: Custom autograd function
Interface Contracts:
Important Invariants:
Indexer inputs (x, qr) are always detached - gradients don't flow back to main model
Indexer loss is attached via DSAIndexerLossAutoScaler.apply() - backpropagates separately
Requires multi_latent_attention=True and context_parallel_size=1
4. Execution Path Deep Dive
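The invariants in section 3 (detached indexer inputs, and a loss that is attached through an autograd function and backpropagated separately) can be illustrated with a small PyTorch sketch. The class below is a hypothetical stand-in modelled on the general auxiliary-loss pattern, not the PR's DSAIndexerLossAutoScaler; the toy scalars and the default scale of 1.0 are assumptions for illustration only.

```python
import torch


class IndexerLossAttacher(torch.autograd.Function):
    """Hypothetical sketch: pass `output` through and inject a gradient for `loss`."""

    main_loss_backward_scale = 1.0  # assumed default; the real scaling may differ

    @staticmethod
    def forward(ctx, output, loss):
        ctx.save_for_backward(loss)
        return output  # forward value is unchanged; the loss only rides along

    @staticmethod
    def backward(ctx, grad_output):
        (loss,) = ctx.saved_tensors
        # The attached loss receives a constant gradient, as if it had been
        # added to the main loss with weight `main_loss_backward_scale`.
        loss_grad = torch.ones_like(loss) * IndexerLossAttacher.main_loss_backward_scale
        return grad_output, loss_grad

    @staticmethod
    def set_loss_scale(scale):
        IndexerLossAttacher.main_loss_backward_scale = scale


main_w = torch.tensor(2.0, requires_grad=True)     # stands in for main-model weights
indexer_w = torch.tensor(3.0, requires_grad=True)  # stands in for indexer weights
hidden = torch.tensor(5.0)

output = main_w * hidden
# The indexer only sees a detached activation, so its loss cannot reach main_w.
indexer_loss = indexer_w * output.detach()

attached = IndexerLossAttacher.apply(output, indexer_loss)
attached.backward()
print(main_w.grad.item(), indexer_w.grad.item())
```

A single backward call on the attached output updates both sets of weights, while the detach keeps the indexer loss from altering the main model's gradient.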
4.1 Entry Point
DSA is triggered when a GPT model is created with the --experimental-attention-variant dsa flag.
4.2 Data Flow
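The diagram in this section computes index scores as q @ k^T → ReLU → weighted sum and then selects the top-k keys per query. A minimal pure-Python sketch of those two steps follows, under stated assumptions: the batch dimension, RoPE and the Hadamard rotation are omitted, a causal mask is assumed, and both function names are hypothetical rather than taken from the PR.

```python
NEG_INF = float("-inf")


def index_scores(q, k, w):
    """q: [sq][heads][dim], k: [sk][dim], w: [sq][heads] -> scores [sq][sk]."""
    scores = []
    for t, (q_t, w_t) in enumerate(zip(q, w)):
        row = []
        for s, k_s in enumerate(k):
            if s > t:  # assumed causal mask: a query cannot select future keys
                row.append(NEG_INF)
                continue
            total = 0.0
            for q_head, w_head in zip(q_t, w_t):
                dot = sum(a * b for a, b in zip(q_head, k_s))
                total += w_head * max(dot, 0.0)  # ReLU, then per-head weight
            row.append(total)
        scores.append(row)
    return scores


def topk_indices(scores, topk):
    """For each query, return the sorted indices of its topk best visible keys."""
    selected = []
    for row in scores:
        visible = [s for s, v in enumerate(row) if v != NEG_INF]
        ranked = sorted(visible, key=lambda s: row[s], reverse=True)
        selected.append(sorted(ranked[:topk]))
    return selected


# Toy inputs: 3 positions, 1 indexer head, head dimension 2.
q = [[[1.0, 0.0]], [[0.0, 1.0]], [[1.0, 1.0]]]
k = [[1.0, 0.0], [0.0, 2.0], [-1.0, 0.5]]
w = [[1.0], [1.0], [0.5]]

scores = index_scores(q, k, w)
print(topk_indices(scores, topk=2))
```

The selected indices are what the sparse attention step would use to restrict each query to its chosen keys.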
graph TD
    A["Input: hidden_states<br/>[sq, b, hidden]"] --> B["MLA Q Compression<br/>linear_q_proj→linear_q_down_proj"]
    A --> C["MLA KV Compression<br/>linear_kv_down_proj"]
    B --> D["q_compressed<br/>[sq, b, q_lora_rank]"]
    C --> E["kv_compressed<br/>[sq, b, kv_lora_rank]"]
    D --> F["MLA Q Upsampling<br/>linear_q_up_proj + RoPE"]
    E --> G["MLA KV Upsampling<br/>linear_kv_up_proj + RoPE"]
    F --> H["query<br/>[sq, b, np, hn]"]
    G --> I["key<br/>[sk, b, np, hn]"]
    G --> J["value<br/>[sk, b, np, hnv]"]
    A --> K["x.detach()"]
    D --> L["q_compressed.detach()"]
    K --> M["DSAIndexer"]
    L --> M
    M --> N["Indexer Q Proj<br/>linear_wq_b<br/>[sq, b, index_n_heads, index_head_dim]"]
    K --> O["Indexer K Proj<br/>linear_wk + k_norm<br/>[sk, b, index_head_dim]"]
    K --> P["Indexer Weights<br/>linear_weights_proj<br/>[sq, b, index_n_heads]"]
    N --> Q["Apply RoPE"]
    O --> R["Apply RoPE"]
    Q --> S["rotate_activation<br/>(Hadamard transform)"]
    R --> T["rotate_activation<br/>(Hadamard transform)"]
    S --> U["Index Scores<br/>q @ k^T → ReLU → weighted sum<br/>[b, sq, sk]"]
    T --> U
    P --> U
    U --> V["TopK Selection<br/>[b, sq, index_topk]"]
    H --> W["Sparse Attention"]
    I --> W
    J --> W
    V --> W
    W --> X["attention_output<br/>[sq, b, hidden]"]
    U --> Y["KL Divergence Loss<br/>KL(true_attn || index_scores)"]
    V --> Y
    H --> Y
    I --> Y
    Y --> Z["indexer_loss<br/>scalar"]
    X --> AA["DSAIndexerLossAutoScaler.apply"]
    Z --> AA
    AA --> AB["Final Output<br/>(with loss attached)"]
    style K fill:#FFE4B5
    style L fill:#FFE4B5
    style M fill:#90EE90
    style U fill:#87CEEB
    style V fill:#87CEEB
    style W fill:#FFD700
    style Y fill:#FF6347
    style Z fill:#FF6347
    style AA fill:#DDA0DD
5. Module Relationships
classDiagram
    class TransformerConfig {
        +int num_layers
        +int hidden_size
        +str experimental_attention_variant
    }
    class MLATransformerConfig {
        +int q_lora_rank
        +int kv_lora_rank
        +int dsa_indexer_n_heads
        +int dsa_indexer_head_dim
        +int dsa_indexer_topk
        +float dsa_indexer_loss_coeff
    }
    class Attention {
        <<abstract>>
        +forward()*
    }
    class MultiLatentAttention {
        +get_query_key_value_tensors()
        +forward()
    }
    class MLASelfAttention {
        +linear_q_proj
        +linear_kv_down_proj
        +core_attention
        +get_query_key_value_tensors(return_compressed_tensors)
    }
    class DSAttention {
        +indexer: DSAIndexer
        +softmax_scale: float
        +forward(query, key, value, x, qr, ...)
    }
    class DSAIndexer {
        +linear_wq_b
        +linear_wk
        +k_norm
        +linear_weights_proj
        +rotary_pos_emb
        +forward(x, qr, mask)
        +forward_with_scores(x, qr, mask)
        -_apply_rope()
        -_compute_index_scores()
    }
    class DSAIndexerSubmodules {
        +linear_wq_b: ModuleSpec
        +linear_wk: ModuleSpec
        +k_norm: ModuleSpec
        +linear_weights_proj: ModuleSpec
    }
    class DSAttentionSubmodules {
        +indexer: ModuleSpec
    }
    class MLASelfAttentionSubmodules {
        +core_attention: ModuleSpec
        +linear_q_proj
        +linear_kv_down_proj
        +q_layernorm
        +kv_layernorm
    }
    class RotaryEmbedding {
        +forward(seq_len)
    }
    class DSAIndexerLossAutoScaler {
        <<autograd.Function>>
        +forward(output, loss)$
        +backward(grad_output)$
        +set_loss_scale(scale)$
        +main_loss_backward_scale$
    }
    class DSAIndexerLossLoggingHelper {
        +save_loss_to_tracker()$
        +reduce_loss_in_tracker()$
        +track_indexer_metrics()$
        +tracker: dict$
    }
    TransformerConfig <|-- MLATransformerConfig : extends
    Attention <|-- MultiLatentAttention : extends
    MultiLatentAttention <|-- MLASelfAttention : extends
    Attention <|-- DSAttention : implements (core_attention)
    MLASelfAttention --> DSAttention : uses as core_attention
    MLASelfAttention --> MLASelfAttentionSubmodules : configured by
    DSAttention --> DSAIndexer : contains
    DSAttention --> DSAttentionSubmodules : configured by
    DSAIndexer --> DSAIndexerSubmodules : configured by
    DSAIndexer --> RotaryEmbedding : uses
    DSAttention ..> DSAIndexerLossAutoScaler : uses
    DSAttention ..> DSAIndexerLossLoggingHelper : logs to
    MLASelfAttention ..> MLATransformerConfig : reads config
    DSAttention ..> MLATransformerConfig : reads config
    DSAIndexer ..> MLATransformerConfig : reads config
Key Relationships:
Composition:
MLASelfAttention contains DSAttention as its core_attention module
DSAttention contains DSAIndexer for computing sparse indices
Utility Classes:
DSAIndexerLossAutoScaler: Custom autograd for loss attachment
DSAIndexerLossLoggingHelper: Singleton for collecting losses across layers
New Dependencies Introduced:
fast_hadamard_transform (optional): For Hadamard rotation activation
6. Examples
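For readers without the optional fast_hadamard_transform package, the rotation it accelerates can be sketched as a plain Walsh-Hadamard transform. This is a pure-Python reference under assumptions: the function name is hypothetical, and the 1/sqrt(n) normalization is assumed rather than confirmed by this PR.

```python
import math


def hadamard_rotate(x: list[float]) -> list[float]:
    """Orthonormal Walsh-Hadamard transform of a vector of power-of-two length."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    out = list(x)
    h = 1
    while h < n:
        # Butterfly step: combine elements that are h apart.
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = out[i], out[i + h]
                out[i], out[i + h] = a + b, a - b
        h *= 2
    scale = 1.0 / math.sqrt(n)  # assumed normalization; makes the transform orthonormal
    return [v * scale for v in out]


print(hadamard_rotate([1.0, 0.0, 0.0, 0.0]))
```

With this normalization the transform is its own inverse, so applying it twice returns the input; it spreads each component evenly across all dimensions while leaving dot products between rotated vectors unchanged.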
6.1 Configuration Parameters
CLI Arguments Example (added in arguments.py):
TransformerConfig Example:
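As a rough sketch of how the DSA flags and config fields might fit together, the snippet below parses the flags used in section 6.2 into a small dataclass. Both the dataclass and the parser are hypothetical stand-ins, not the real MLATransformerConfig or the code in arguments.py; the field names follow the class diagram in section 5, while the types, defaults and validation are assumptions.

```python
import argparse
from dataclasses import dataclass
from typing import Optional


@dataclass
class DSAConfigSketch:
    """Hypothetical stand-in for the DSA-related config fields."""

    multi_latent_attention: bool = False
    context_parallel_size: int = 1
    experimental_attention_variant: Optional[str] = None
    dsa_indexer_n_heads: Optional[int] = None
    dsa_indexer_head_dim: Optional[int] = None
    dsa_indexer_topk: Optional[int] = None
    dsa_indexer_loss_coeff: Optional[float] = None

    def __post_init__(self):
        # Mirrors the invariant listed in section 3; the check itself is assumed.
        if self.experimental_attention_variant == "dsa":
            if not self.multi_latent_attention or self.context_parallel_size != 1:
                raise ValueError(
                    "DSA needs multi_latent_attention=True and context_parallel_size=1"
                )


def build_parser() -> argparse.ArgumentParser:
    """Declare only the flags relevant to DSA (types and defaults are assumed)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--multi-latent-attention", action="store_true")
    parser.add_argument("--experimental-attention-variant", type=str, default=None)
    parser.add_argument("--dsa-indexer-n-heads", type=int, default=None)
    parser.add_argument("--dsa-indexer-head-dim", type=int, default=None)
    parser.add_argument("--dsa-indexer-topk", type=int, default=None)
    parser.add_argument("--dsa-indexer-loss-coeff", type=float, default=None)
    return parser


def config_from_cli(argv: list[str]) -> DSAConfigSketch:
    """Map parsed flags onto the config fields of the same name."""
    args = build_parser().parse_args(argv)
    return DSAConfigSketch(**vars(args))


cfg = config_from_cli([
    "--multi-latent-attention",
    "--experimental-attention-variant", "dsa",
    "--dsa-indexer-n-heads", "16",
    "--dsa-indexer-head-dim", "128",
    "--dsa-indexer-topk", "256",
    "--dsa-indexer-loss-coeff", "0.001",
])
print(cfg)
```

Validating in `__post_init__` surfaces an unsupported combination at construction time rather than deep inside model building.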
6.2 Example Usage
Training a GPT model with DSA:
# Multi-Latent Attention (required for DSA): --multi-latent-attention through --v-head-dim.
# DeepSeek Sparse Attention: --experimental-attention-variant and the --dsa-indexer-* flags.
# The remaining flags are standard training args.
python pretrain_gpt.py \
    --num-layers 32 \
    --hidden-size 4096 \
    --num-attention-heads 32 \
    --seq-length 8192 \
    --multi-latent-attention \
    --q-lora-rank 512 \
    --kv-lora-rank 512 \
    --qk-head-dim 128 \
    --qk-pos-emb-head-dim 64 \
    --v-head-dim 128 \
    --experimental-attention-variant dsa \
    --dsa-indexer-n-heads 16 \
    --dsa-indexer-head-dim 128 \
    --dsa-indexer-topk 256 \
    --dsa-indexer-loss-coeff 0.001 \
    --micro-batch-size 1 \
    --global-batch-size 512 \
    --lr 1.0e-4 \
    --train-iters 100000 \
    --lr-decay-iters 100000 \
    --lr-decay-style cosine \
    --min-lr 1.0e-5 \
    --weight-decay 0.1 \
    --clip-grad 1.0 \
    --bf16
Expected Behavior:
indexer loss
Further Reading