[CPU][sgl-kernel] extend_attention_cpu and flash_attn_varlen_func: fix nan for large seq #22434
/tag-run-ci-label
Code Review
This pull request addresses a bug in the causal masking logic within the extend_attention_kernel_impl function in sgl-kernel/csrc/cpu/extend.cpp. The original masking condition was updated to correctly handle cases where BLOCK_M > BLOCK_N/2, which previously led to future keys not being masked and potential out-of-bounds writes. A new test helper _test_extend_attention_fixed_lens and a specific test case test_extend_attention_large_seq_causal_mask were added in test/srt/cpu/test_extend.py to reproduce and validate the fix for this issue. A review comment suggests refactoring the new test helper to reduce code duplication with _test_extend_attention_once for improved maintainability.
@mingfeima could you please help review this PR? Xeon CI has passed.
extend_attention_cpu and flash_attn_varlen_func: fix nan for large seq
…: fix `nan` for large seq (sgl-project#22434) Co-authored-by: Ma Mingfei <mingfei.ma@intel.com>
Motivation
Fixes `nan` for large inputs (> 4096) in the `extend_attention_cpu` and `flash_attn_varlen_func` kernels.

Below is the error message raised when feeding a large input to the model:
#20051 also reported this issue.
Modifications
`last_col` is negative

For input lengths > 4096, `BLOCK_M = 512` and `BLOCK_N = 768` are selected. At `mb = 1`, `m = mb * BLOCK_M = 512`, thus `num_keys = m + BLOCK_M = 1024`.

For the `n` loop, the valid `n` values are `n = 0` (0 < 1024) and `n = 768` (768 < 1024), so there are two n-blocks:
- `[0, 767]` (size 768)
- `[768, 1023]` (size 256)

The second block can be entirely in the future for early rows (e.g. row 0 at absolute query position 512).

Before the fix, for the second block `[768, 1023]`, `last_col` was negative, so the previous code wrote before `row_ptr`, resulting in `nan`. This PR sets `last_col` to `-1` in this case and masks the entire row.

Before the fix, for the first block, `[513, 767]` needed to be masked, but the previous code only handled the last block `[768, 1023]`, making the result wrong. This PR masks `[513, 767]` correctly.

Accuracy Tests
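As a quick sanity check on the block arithmetic described under Modifications, the following is a minimal Python sketch, not the kernel code in `extend.cpp`. It assumes plain causal masking, where the query at absolute position `q_pos` may see keys `0..q_pos`. The helper names (`n_blocks`, `last_visible_col`, `masked_key_ranges`) are hypothetical and exist only for illustration.

```python
BLOCK_M, BLOCK_N = 512, 768  # block sizes the PR reports for inputs > 4096


def n_blocks(num_keys, block_n=BLOCK_N):
    """Split [0, num_keys) into (start, size) key blocks, mirroring the n loop."""
    return [(n, min(block_n, num_keys - n)) for n in range(0, num_keys, block_n)]


def last_visible_col(q_pos, n, n_size):
    """Index inside the block starting at n of the last key visible to q_pos.

    Clamped to [-1, n_size - 1]: -1 means the whole block lies in the future,
    so every column is masked and nothing is written before the row start.
    """
    return max(-1, min(q_pos - n, n_size - 1))


def masked_key_ranges(mb, row):
    """Inclusive absolute key ranges masked for one query row of m-block mb."""
    m = mb * BLOCK_M
    q_pos = m + row  # absolute query position (assumption: no extra prefix offset)
    ranges = []
    for n, n_size in n_blocks(m + BLOCK_M):  # num_keys = m + BLOCK_M
        last_col = last_visible_col(q_pos, n, n_size)
        if last_col < n_size - 1:  # some columns of this block are in the future
            ranges.append((n + last_col + 1, n + n_size - 1))
    return ranges


if __name__ == "__main__":
    print(n_blocks(1024))           # the two n-blocks for mb = 1
    print(masked_key_ranges(1, 0))  # row 0 of mb = 1, absolute position 512
```

For `mb = 1`, row 0, this reports both ranges named in the description: `[513, 767]` in the first block and the whole second block `[768, 1023]`. The clamp to `-1` is what keeps the fully-future block from producing a negative column index.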
Before the fix:
After the fix:
Review and Merge Process
/tag-and-rerun-ci,/tag-run-ci-label,/rerun-failed-ci