Fix NSA FA3 shape mismatch under DP Attention + EAGLE#24235
junliu-mde wants to merge 2 commits into sgl-project:main
Conversation
Code Review
This pull request introduces deferred attention metadata initialization to handle Data Parallel (DP) attention padding, specifically for Native Sparse Attention (NSA) and FlashAttention-3 (FA3) backends. By tracking the original batch dimensions before padding, the system ensures that attention kernels operate on the correct token subsets while preserving the padded shapes needed for downstream MLP synchronization. Feedback suggests refining the backend type check in ModelRunner to be more comprehensive and consistent with other parts of the codebase.
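The idea of tracking pre-padding batch dimensions can be sketched roughly as follows. This is an illustrative toy, not the actual SGLang API: `AttnMetadata`, `init_metadata_deferred`, and `dp_pad_multiple` are hypothetical names.

```python
# Sketch: record the original batch size before DP attention padding so
# attention kernels can slice to the real tokens, while the padded shape
# is kept for downstream MLP synchronization. All names are illustrative.
from dataclasses import dataclass

@dataclass
class AttnMetadata:
    padded_batch_size: int   # shape after DP attention padding
    real_batch_size: int     # original number of tokens before padding

def init_metadata_deferred(batch_size: int, dp_pad_multiple: int) -> AttnMetadata:
    # Round up to the multiple required by the DP all-gather/all-reduce.
    padded = ((batch_size + dp_pad_multiple - 1) // dp_pad_multiple) * dp_pad_multiple
    return AttnMetadata(padded_batch_size=padded, real_batch_size=batch_size)

meta = init_metadata_deferred(batch_size=5, dp_pad_multiple=4)
print(meta.real_batch_size)    # 5 — what the attention kernel should see
print(meta.padded_batch_size)  # 8 — what MLP sync should see
```

The point of deferring is that padding happens after the batch is built, so the metadata must capture both sizes rather than only the padded one.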
        if forward_batch.input_ids is not None
        else 0
    )
    if attn_backend.__class__.__name__ != "NativeSparseAttnBackend":
The check attn_backend.__class__.__name__ != "NativeSparseAttnBackend" is a bit brittle as it only checks for one specific class name. A similar check in python/sglang/srt/speculative/eagle_worker.py uses a helper function _is_native_sparse_attn_backend which checks for both NativeSparseAttnBackend and NativeSparseAttnMultiStepBackend. To improve robustness and maintainability, consider making this check more comprehensive by checking against a set of class names. This would make the code more resilient to future changes, such as if a wrapper class is introduced.
-if attn_backend.__class__.__name__ != "NativeSparseAttnBackend":
+if attn_backend.__class__.__name__ not in ("NativeSparseAttnBackend", "NativeSparseAttnMultiStepBackend"):
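A self-contained sketch of the suggested pattern: membership in a frozen set of class names instead of a single string equality. The `NativeSparseAttnMultiStepBackend` class below is a stand-in for the real backend class, defined only so the example runs.

```python
# Checking against a set of class names is more robust than one equality
# check, and cheap to extend if a wrapper backend is added later.
NSA_BACKEND_NAMES = frozenset(
    {"NativeSparseAttnBackend", "NativeSparseAttnMultiStepBackend"}
)

def is_nsa_backend(attn_backend) -> bool:
    return attn_backend.__class__.__name__ in NSA_BACKEND_NAMES

class NativeSparseAttnMultiStepBackend:  # stand-in for the real class
    pass

print(is_nsa_backend(NativeSparseAttnMultiStepBackend()))  # True
print(is_nsa_backend(object()))                            # False
```

This mirrors the `_is_native_sparse_attn_backend` helper mentioned from `eagle_worker.py`, so both call sites stay consistent.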
Keep NSA imports at module scope and remove ambiguous or unused local variables so the touched files pass ruff.
Align NSA metadata padding with DP attention token padding and guard empty DP ranks before FA3 and DeepGEMM metadata paths.
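The empty-rank guard described in the commit message can be sketched as below. Under DP attention a rank may receive zero real tokens after the batch is split; names here (`compute_fa3_metadata`) are hypothetical, not the actual SGLang functions.

```python
# Sketch: skip metadata computation on an empty DP rank. Running the
# FA3/DeepGEMM metadata paths on a zero-token slice can produce
# degenerate shapes, so return early instead. Illustrative only.
from typing import Optional

def compute_fa3_metadata(num_tokens: int) -> Optional[dict]:
    if num_tokens == 0:
        return None  # empty DP rank: nothing to precompute
    return {"num_tokens": num_tokens}

print(compute_fa3_metadata(0))    # None
print(compute_fa3_metadata(128))  # {'num_tokens': 128}
```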
Force-pushed 3545ca3 to f2e386d
The scheduler_metadata buffer precomputed in `_compute_scheduler_metadata` (introduced by PR sgl-project#21104 to avoid per-layer `prepare_varlen_num_blocks`) can become inconsistent with the `num_splits` the C++ `mha_fwd` kernel derives from the live `cache_seqlens` once decode advances. The mismatch triggers an out-of-bounds read in the FA3 split-KV combine kernel and surfaces as a CUDA illegal-memory-access at `flash_fwd_combine_launch_template.h:52`.

Reproduces with Qwen3-0.6B + `--enable-dp-attention --dp 8 --tp 8 --chunked-prefill-size 131072` on H200 after ~65 decode steps. Single-GPU and TP-only paths are unaffected.

Skip the precompute when DP attention is on and let the C++ kernel recompute its own metadata per layer. PR sgl-project#21104's optimization is preserved on every other path. PR sgl-project#24235 had previously addressed a narrower variant on NSA + EAGLE.

Co-authored-by: Cursor <cursoragent@cursor.com>
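The shape of the fix described above can be sketched as a single conditional. `maybe_precompute_scheduler_metadata` and `precompute` are hypothetical names; the real change lives inside the FA3 backend.

```python
# Sketch: when DP attention is enabled, return None so the kernel
# recomputes its own metadata per layer from the live cache_seqlens,
# keeping num_splits consistent on every decode step. On all other
# paths, keep the PR #21104 precompute fast path. Illustrative only.
from typing import Optional

def precompute(cache_seqlens: list) -> dict:
    # Stand-in for the real _compute_scheduler_metadata buffer build.
    return {"seqlens": list(cache_seqlens)}

def maybe_precompute_scheduler_metadata(
    enable_dp_attention: bool, cache_seqlens: list
) -> Optional[dict]:
    if enable_dp_attention:
        return None  # kernel derives metadata itself; no stale buffer
    return precompute(cache_seqlens)

print(maybe_precompute_scheduler_metadata(True, [3, 7]))   # None
print(maybe_precompute_scheduler_metadata(False, [3, 7]))  # {'seqlens': [3, 7]}
```

The trade-off is re-deriving metadata per layer on the DP path, which the commit message accepts in exchange for correctness once decode advances.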
Summary
Fixes #24233
Test plan
Launch flags: `--tp 8 --dp 4 --enable-dp-attention --cuda-graph-max-bs 2 --speculative-algorithm NEXTN --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --page-size 1`.

Before the fix, `python3 test-dpa-repro.py http://localhost:30000 1v3` reproduces `RuntimeError: batch_size must be equal to batch_size_k`.

After the fix, the `test-dpa-repro.py ... 1v3` client returns `Total: 4/4 ok, 0 failed` after warmup. The local repro script exits non-zero on the fixed path because it encodes "crash confirmed" as success.

After the fix: 0 `batch_size must be equal`, 0 FA3 metadata/page table mismatch, 0 scheduler exception, 0 HTTP 500, and 0 traceback/assertion.