For context lengths below 2K, sparse attention is identical to non-sparse attention, so we can skip the logits computation and directly generate the indices for the sparse MLA kernel, or use MHA when possible.
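A minimal sketch of this short-context fast path in PyTorch (the helper name and the -1 padding convention are assumptions, not the actual SGLang API):

```python
import torch

def maybe_skip_indexer(seq_lens: torch.Tensor, topk: int = 2048):
    """If every sequence fits within the sparse top-k, skip the indexer entirely."""
    if int(seq_lens.max()) > topk:
        return None  # some sequence exceeds top-k: run the real indexer (logits + top-k)
    # Each row attends to all of its own tokens, so the indices are simply
    # 0..len-1, padded with -1 so the sparse MLA kernel skips unused slots.
    bs = seq_lens.numel()
    indices = torch.arange(topk, device=seq_lens.device).expand(bs, topk).clone()
    indices[indices >= seq_lens.unsqueeze(1)] = -1
    return indices
```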
The current flashmla_decode kernel is not well optimized on B200, so a separate dequant kernel + flashmla_sparse_bf16 works better for prefill with an fp8 KV cache, provided the KV cache is not too long relative to the q sequence length. The heuristics will need to be updated whenever new optimizations land in either the prefill or decode kernels, which makes them a bit hard to use in practice. Detailed analysis is here
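As a rough illustration of such a heuristic (the function name and threshold below are made up; the real dispatch depends on the kernel measurements in the linked analysis):

```python
def prefer_dequant_bf16_path(q_len: int, kv_len: int, max_kv_to_q_ratio: float = 4.0) -> bool:
    # With an fp8 KV cache on B200, a separate dequant pass + flashmla_sparse_bf16
    # can beat flashmla_decode when the KV length is small relative to the query
    # length, because the dequant cost is amortized over many query tokens.
    return kv_len <= max_kv_to_q_ratio * q_len
```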
[Decode] Move deep_gemm.get_paged_mqa_logits_metadata to init time, similar to how the attention kernel metadata is computed (sketched below).
[Prefill] Optimize _get_topk_ragged, which currently launches a lot of small kernels: try multi-stream execution, torch.compile (sketched below), and add new kernels where necessary.
[MTP] Enable nextn = 2/4 in deep_gemm.fp8_paged_mqa_logits, which is faster than the current implementation that uses the nextn = 1 kernel regardless of the MTP size.
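A sketch of the [Decode] metadata refactor (the class and attribute names are hypothetical, and the exact argument list should be taken from DeepGEMM):

```python
import torch
import deep_gemm

class IndexerForwardMetadata:
    """Computed once per batch at metadata-init time and reused by every layer."""

    def __init__(self, context_lens: torch.Tensor, block_kv: int, num_sms: int):
        # Previously recomputed inside each layer's forward pass; now a single call.
        self.paged_mqa_logits_metadata = deep_gemm.get_paged_mqa_logits_metadata(
            context_lens, block_kv, num_sms
        )
```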
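And a sketch of the [Prefill] torch.compile route (the body below is a stand-in for the real _get_topk_ragged, which operates on ragged sequences; the point is wrapping it so Inductor can fuse the many small masking/top-k kernels):

```python
import torch

@torch.compile(dynamic=True)
def _get_topk_ragged(logits: torch.Tensor, seq_lens: torch.Tensor, topk: int) -> torch.Tensor:
    # logits: [batch, max_len] indexer scores; seq_lens: [batch] valid lengths.
    # Mask positions beyond each row's valid length, then take per-row top-k;
    # compiled, the mask + fill prologue launches far fewer kernels.
    positions = torch.arange(logits.shape[-1], device=logits.device)
    logits = logits.masked_fill(positions >= seq_lens.unsqueeze(-1), float("-inf"))
    return logits.topk(topk, dim=-1).indices
```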
Attention algorithm
Link to original table
The parts highlighted in blue are work that has been done or is in progress.
To summarize:
Kernel optimizations
Speed up torch.cat([q_nope, q_rope]) by either writing a fast triton/cuda kernel or using torch.compile. It's used for both prefill and decode, but the prefill one is much bigger and has more room for optimizations. The trtllm kernel supports separate q_nope and q_rope, but flashmla doesn't. Related PRs: [DeepseekV32]: use _concat_mla_absorb_q_general to replace torch.cat #12215; [Deepseek V3.2] Use torch.compile to speed up torch.cat in nsa #13022
Indexer optimizations
Min latency optimizations