
[CPU][sgl-kernel] extend_attention_cpu and flash_attn_varlen_func: fix nan for large seq #22434

Merged

mingfeima merged 10 commits into sgl-project:main from chunyuan-w:chunyuan/pr_fix_nan_large_seq on Apr 17, 2026

Conversation

@chunyuan-w (Contributor) commented Apr 9, 2026

Motivation

Fixes `nan` for large inputs (sequence length > 4096) in the extend_attention_cpu and flash_attn_varlen_func kernels.

Below is the error message raised when feeding a large input to the model:

  File "/home/user/sglang/python/sglang/srt/layers/sampler.py", line 477, in top_k_top_p_min_p_sampling_from_probs_torch
    sampled_index = torch.multinomial(probs_sort, num_samples=1)
RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

#20051 also reported this issue.

Modifications

  1. Fixes the case where last_col is negative
  2. Fixes the condition to apply the causal mask

For input lengths > 4096, BLOCK_M = 512 and BLOCK_N = 768 are selected. At mb = 1, m = mb * BLOCK_M = 512, so num_keys = m + BLOCK_M = 1024.
For the n loop, the valid n values are n = 0 (0 < 1024) and n = 768 (768 < 1024), so there are two n-blocks:

  • first block: keys [0, 767] (size 768)
  • second block: keys [768, 1023] (size 256)
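For reference, a minimal standalone C++ sketch of this n-loop decomposition. The constants mirror the walkthrough above; the variable names are illustrative assumptions, not the kernel's actual ones:

#include <algorithm>
#include <cstdio>

int main() {
  const int BLOCK_M = 512;
  const int BLOCK_N = 768;
  const int mb = 1;
  const int m = mb * BLOCK_M;        // first query row of this m-block: 512
  const int num_keys = m + BLOCK_M;  // keys visible under the causal mask: 1024

  for (int n = 0; n < num_keys; n += BLOCK_N) {
    int size = std::min(BLOCK_N, num_keys - n);
    std::printf("n-block [%d, %d] (size %d)\n", n, n + size - 1, size);
  }
  // prints:
  //   n-block [0, 767] (size 768)
  //   n-block [768, 1023] (size 256)
  return 0;
}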

The second block can be entirely future for early rows (e.g. row 0 at absolute query pos 512).
Before the fix, for the second block [768, 1023], last_col was negative and the code wrote before row_ptr, producing nan. This PR sets last_col to -1 in that case and masks the entire row.
Before the fix, in the first block, keys [513, 767] needed to be masked for early rows, but the mask was only applied in the last block [768, 1023], yielding wrong results. This PR masks [513, 767] correctly.
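A minimal sketch of the corrected per-row masking, assuming illustrative names (row_ptr, last_col, block_size, q_pos, n_start) taken from the description above; the actual implementation in sgl-kernel/csrc/cpu/extend.cpp differs in its details:

#include <limits>

// Mask out future keys for one query row's scores within one n-block.
// q_pos: absolute query position; n_start: first key index of this n-block.
void apply_causal_mask(float* row_ptr, int block_size, int q_pos, int n_start) {
  // Index of the last key in this n-block that q_pos may attend to.
  int last_col = q_pos - n_start;

  // Fix 1: if the whole block lies in the future, last_col is negative.
  // Clamp it to -1 so the loop below masks the entire row instead of
  // writing before row_ptr (the out-of-bounds write that produced nan).
  if (last_col < 0) last_col = -1;

  // Fix 2: apply the mask in every block that extends past q_pos, not only
  // the final n-block; with BLOCK_M = 512 and BLOCK_N = 768, the first
  // block [0, 767] already holds future keys for early rows of m-block 1.
  for (int col = last_col + 1; col < block_size; ++col) {
    row_ptr[col] = -std::numeric_limits<float>::infinity();
  }
}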

Accuracy Tests

python -u test/srt/cpu/test_extend.py -k test_extend_attention_large_seq_causal_mask

python -u test/srt/cpu/test_flash_attn.py -k test_flash_attn_large_seq_causal_mask

Before the fix:

Mismatched elements: 281210 / 2560000 (11.0%)
Greatest absolute difference: nan at index (1601, 7, 0) (up to 0.01 allowed)
Greatest relative difference: nan at index (1536, 4, 0) (up to 0.01 allowed)


Mismatched elements: 358528 / 5119488 (7.0%)
Greatest absolute difference: nan at index (1665, 2, 0) (up to 0.01 allowed)
Greatest relative difference: nan at index (1665, 2, 0) (up to 0.01 allowed)
E

After the fix:

[CI Test Method] TestExtendAttention.test_extend_attention_large_seq_causal_mask
.
----------------------------------------------------------------------
Ran 1 test in 0.506s

OK

[CI Test Method] TestFlashAttn.test_flash_attn_large_seq_causal_mask
.
----------------------------------------------------------------------
Ran 1 test in 1.033s

OK

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

@chunyuan-w (Contributor, Author) commented:

/tag-run-ci-label

@github-actions bot added the run-ci label on Apr 9, 2026
@gemini-code-assist bot left a comment:

Code Review

This pull request addresses a bug in the causal masking logic within the extend_attention_kernel_impl function in sgl-kernel/csrc/cpu/extend.cpp. The original masking condition was updated to correctly handle cases where BLOCK_M > BLOCK_N/2, which previously led to future keys not being masked and potential out-of-bounds writes. A new test helper _test_extend_attention_fixed_lens and a specific test case test_extend_attention_large_seq_causal_mask were added in test/srt/cpu/test_extend.py to reproduce and validate the fix for this issue. A review comment suggests refactoring the new test helper to reduce code duplication with _test_extend_attention_once for improved maintainability.

@chunyuan-w (Contributor, Author) commented:

@mingfeima could you please help review this PR? Xeon CI has passed.

@chunyuan-w changed the title from "[CPU][sgl-kernel] extend_attention_cpu: fix nan for large seq" to "[CPU][sgl-kernel] extend_attention_cpu and flash_attn_varlen_func: fix nan for large seq" on Apr 14, 2026
@mingfeima merged commit 6c89214 into sgl-project:main on Apr 17, 2026
81 of 103 checks passed
jmamou pushed a commit to jmamou/sglang that referenced this pull request Apr 20, 2026
…: fix `nan` for large seq (sgl-project#22434)

Co-authored-by: Ma Mingfei <mingfei.ma@intel.com>
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
…: fix `nan` for large seq (sgl-project#22434)

Co-authored-by: Ma Mingfei <mingfei.ma@intel.com>
zhangying098 pushed a commit to zhangying098/sglang that referenced this pull request Apr 23, 2026
…: fix `nan` for large seq (sgl-project#22434)

Co-authored-by: Ma Mingfei <mingfei.ma@intel.com>
kyx1999 pushed a commit to KMSorSMS/sglang that referenced this pull request Apr 27, 2026
…: fix `nan` for large seq (sgl-project#22434)

Co-authored-by: Ma Mingfei <mingfei.ma@intel.com>