[HiCache] Fix illegal memory access caused by stream race between FA3 forward and D2H transfer#20611
Weili-0234 wants to merge 3 commits into sgl-project:main
Conversation
…ransfer Ensure write_stream waits for both schedule and forward streams before reading KV data for D2H transfer. Without this, write_stream may read KV slots that FA3 on forward_stream is still writing to. Also expose last_write_finish_event so scheduler can fence forward_stream before the next forward kernel launch.
Register forward_stream as the producer stream for HiCache controller, and wait on last_write_finish_event in the overlap path before launching forward_batch_generation. This prevents FA3 persistent kernel and HiCache SM copy kernel from running concurrently on the GPU.
Summary of Changes (Gemini Code Assist): This pull request enhances the stability of the SGLang serving system when using HiCache with FA3 on Hopper GPUs. By addressing two CUDA stream race conditions with precise GPU-side synchronization, it eliminates the illegal memory access errors and crashes that previously occurred under high eviction pressure. The changes ensure that data transfers and kernel executions are correctly ordered, preventing resource conflicts and data corruption without introducing performance overhead.
Code Review
This pull request addresses race conditions between HiCache's D2H transfers and the FA3 attention kernel by adding two CUDA stream synchronization points. The first change makes the write_stream wait for the forward_stream to prevent reading incomplete KV cache data. The second change ensures the forward_stream waits for the previous D2H transfer to finish, avoiding SM resource conflicts. The implementation in cache_controller.py and scheduler.py correctly introduces these synchronizations using wait_stream and wait_event. The logic appears sound and should resolve the reported illegal memory access issues.
Thank you very much for looking into this.
@Weili-0234 @xiezhq-hermann Is there a fix PR available currently?
This was unexpected to me as well and I could not find an explanation. I did observe:
The only change in the SASS was the addition of an "ACQBULK" instruction, which is not documented by NVIDIA AFAICT.
It seems that the problem shown in this PR hasn't been fixed yet?
Motivation
Fix two CUDA stream race conditions that cause illegal memory access (IMA) in the FA3 persistent kernel when HiCache's `write_through` policy is active with eviction pressure on Hopper GPUs. The bug is triggered when `--hicache-write-policy write_through` (the default) causes D2H transfers on `write_stream` to race with FA3 on `forward_stream`.

This is a long-standing issue reported across multiple models and configurations (#19737, #14363, #11842, #13912, #18785, #18166; many of them were closed due to inactivity, but the bug they pointed out remains). As #14363 noted, "it is very difficult to reproduce", since the crash requires sufficient eviction pressure (usually caused by high enough request concurrency) to trigger the race condition. I found that setting `--mem-fraction-static 0.40` reliably reproduces the crash within minutes by forcing frequent eviction and D2H transfers.

Together with #20049 (lock_ref leak fix), this PR completes a set of three fixes that address the HiCache + FA3 crash:
- `write_stream.wait_stream(forward_stream)` in `cache_controller.py` (this PR)
- `forward_stream.wait_event(last_write_finish_event)` in `scheduler.py` (this PR)
- `flush_write_through_acks()` in the decode path (#20049)

Root cause: HiCache's `write_stream` (D2H transfer) and the scheduler's `forward_stream` (FA3 kernel) lack bidirectional GPU-side synchronization. This creates two independent race windows:

1. **D2H reads before forward finishes writing:** `write_stream` starts reading KV cache slots that FA3 on `forward_stream` is still writing to. The original code only synchronized with the schedule stream via `start_event`, missing the `forward_stream` entirely. Fix: `write_stream.wait_stream(forward_stream)` before D2H.
2. **Forward launches before the D2H SM copy finishes:** The next batch's FA3 persistent kernel (occupying all 132 SMs) launches on `forward_stream` while the previous batch's HiCache SM copy kernel (kernel IO backend) is still running on `write_stream`. The concurrent SM resource conflict corrupts FA3's TMA descriptor uniform registers, causing `CUDBG_EXCEPTION_WARP_ILLEGAL_ADDRESS`. Fix: `forward_stream.wait_event(last_write_finish_event)` before forward.

The diagram below shows the three CUDA streams in SGLang's overlap scheduler and where each fix adds synchronization:
SGLang's overlap scheduler uses three CUDA streams: DEFAULT (scheduling/memory ops), FORWARD (FA3 + MLP forward), and WRITE (HiCache D2H transfer). Each iteration, the CPU submits `run_batch(N)` to FORWARD, then `pop_and_process(N-1)` triggers D2H on WRITE for the previous batch's KV data. Without synchronization, the two streams race; the fixes add:

- `write_stream.wait_stream(forward_stream)`, ensuring D2H waits for FA3 to finish writing the KV data
- `forward_stream.wait_event(last_write_finish_event)`, ensuring FA3 waits for the previous D2H copy to complete

With both fixes, FA3 and the D2H SM copy kernel never overlap on the GPU, eliminating the TMA descriptor corruption.
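The ordering these two fences enforce can be illustrated with a CPU-side analogue. The sketch below uses `threading.Event` as a stand-in for the CUDA event/stream waits; it models only the dependency graph, not the actual torch.cuda API, and all names are illustrative:

```python
import threading

# CPU-side analogue of the two GPU fences. threading.Event stands in for a
# CUDA event; the real code operates on torch.cuda streams.
log = []
forward_done = threading.Event()   # analogue of write_stream.wait_stream(forward_stream)
d2h_done = threading.Event()       # analogue of last_write_finish_event

def forward_batch(n):
    log.append(f"forward[{n}] writes KV")
    forward_done.set()             # FA3 for batch n has finished writing

def d2h_transfer(n):
    # Fence 1: D2H must not read KV slots FA3 is still writing
    forward_done.wait()
    log.append(f"d2h[{n}] reads KV")
    d2h_done.set()                 # analogue of recording last_write_finish_event

t_fwd = threading.Thread(target=forward_batch, args=(0,))
t_d2h = threading.Thread(target=d2h_transfer, args=(0,))
t_d2h.start(); t_fwd.start()
t_fwd.join(); t_d2h.join()

# Fence 2: the next forward waits on the previous D2H before launching
d2h_done.wait()
log.append("forward[1] launches")

assert log == ["forward[0] writes KV", "d2h[0] reads KV", "forward[1] launches"]
```

Without fence 1, `d2h[0]` could interleave before `forward[0]` completes; without fence 2, `forward[1]` could launch while `d2h[0]` is still in flight, which is exactly the SM resource conflict described above.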
Fixes #19737, fixes #18785, fixes #14363
Related issues: #11842, #13912, #18166
Related PR: #20049 (independent fix for write-through lock_ref leak during decode)
Note: #19737's Workaround 2 (`current_stream().wait_event(finish_event)`) fixes the same race but blocks the CPU thread synchronously. This PR uses `forward_stream.wait_event()` instead, which is a GPU-side fence that preserves overlap scheduling performance.

Modifications
`cache_controller.py`:

- Add `producer_stream` and `last_write_finish_event` attributes to `HiCacheController`
- Add a `set_producer_stream()` method for the scheduler to register the forward stream
- In `start_writing()`: replace the `start_event.record()` + `start_event.wait()` pattern with explicit `write_stream.wait_stream(current_stream)` + `write_stream.wait_stream(producer_stream)`, ensuring D2H waits for both the schedule and forward streams
- Record `last_write_finish_event` after the D2H finish event is recorded

`scheduler.py`:

- Call `set_producer_stream(self.forward_stream)` to register the forward stream with the cache controller
- In the overlap path (`run_batch`): add `forward_stream.wait_event(last_write_finish_event)` before `forward_batch_generation`, ensuring the forward kernel doesn't launch until the previous D2H SM copy kernel completes

Debugging with CUDA Coredump Analysis
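As a rough sketch, torch-flavored pseudocode for the two hooks might look like the following (this is not the actual SGLang code: the classes are heavily simplified, argument names are hypothetical, and running it requires a CUDA device):

```python
import torch

class HiCacheControllerSketch:
    def __init__(self):
        self.write_stream = torch.cuda.Stream()
        self.producer_stream = None                         # new attribute (this PR)
        self.last_write_finish_event = torch.cuda.Event()   # new attribute (this PR)

    def set_producer_stream(self, stream):                  # new method (this PR)
        self.producer_stream = stream

    def start_writing(self, kv_src, host_dst):
        # D2H must wait for BOTH the schedule stream and the forward stream
        self.write_stream.wait_stream(torch.cuda.current_stream())
        if self.producer_stream is not None:
            self.write_stream.wait_stream(self.producer_stream)
        with torch.cuda.stream(self.write_stream):
            host_dst.copy_(kv_src, non_blocking=True)       # D2H transfer
            self.last_write_finish_event.record(self.write_stream)

# Scheduler side (overlap path), before launching the next forward:
#   self.forward_stream.wait_event(controller.last_write_finish_event)
#   ... then forward_batch_generation(...) on self.forward_stream ...
```

Both `wait_stream` and `wait_event` enqueue GPU-side dependencies without blocking the CPU thread, which is why the overlap scheduler's pipelining is preserved.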
We collected CUDA coredumps using `SGLANG_CUDA_COREDUMP=1` from two crash scenarios on 8xH100 serving MiniMax-M2.5 (456B MoE, TP8 EP8) with `--mem-fraction-static 0.40 --enable-hierarchical-cache --hicache-io-backend kernel --attention-backend fa3`. All experiments had #20049 (lock_ref fix) applied to prevent OOM during sustained decode. In both cases the crash kernel is `FlashAttnFwdSm90`, the FA3 persistent kernel, with Grid=(132,1,1) occupying all 132 SMs on H100.

Crash A: with #20049 only (no stream sync), crashed at 2.2 min / 48 waves:
- Faulting instruction: `LDGSTS.E.BYPASS.LTC128B.128 desc[UR40]` in the FA3 mainloop (loading KV data via TMA)
- `write_stream` reads KV slots before `forward_stream` (FA3) finishes writing them

Crash B: with #20049 + `write_stream.wait_stream(forward_stream)`, crashed at 20 min / 427 waves:

- Faulting instruction: `STG.E desc[UR40] [R4.64], R3` in the FA3 epilogue (writing attention output via TMA; the stored value is `-inf`, the softmax LSE init value)
- The HiCache SM copy kernel (on `write_stream`) and the FA3 persistent kernel (on `forward_stream`) run concurrently, causing an SM resource conflict that corrupts FA3's uniform registers

The different crash sites (mainloop LDGSTS vs epilogue STG) with the same symptom (UR40=0x0 TMA descriptor corruption) confirm two independent race conditions. The first fix eliminates the read-before-write race, exposing the SM resource conflict race, which crashes at a different instruction.
CUDA coredump files and cuda-gdb analysis commands
Full cuda-gdb analysis output: `cuda-gdb-analysis.txt`

Stress test reproduction script: `repro-sglang-cuda-crash.py`

Coredump files:
- Crash A: `cuda_coredump_5ac27cf368dd.175.1773213005` (340 MB), `cuda_coredump_5ac27cf368dd.173.1773213005` (314 MB)
- Crash B (with `write_stream.wait_stream` only): `cuda_coredump_5e44d0e4f107.174.1773191203` (303 MB)

To analyze with cuda-gdb (requires CUDA Toolkit 12.x):
To reproduce and generate your own coredump:
Stress test script: `repro-sglang-cuda-crash.py` sends concurrent long-prompt requests to trigger eviction pressure.

Stress Test with High Request Concurrency
Validated on 8xH100 SXM serving MiniMax-M2.5 (456B MoE). All experiments had #20049 (lock_ref fix) applied.
Command
Client:
Scripts: `repro_sustained.py` (sustained multi-turn stress test), `repro-sglang-cuda-crash.py` (single-shot crash reproduction).

- `--mem-fraction-static 0.40` forces tight GPU memory, keeping token usage at 0.83 with eviction active throughout
- `wait_event` is a GPU-side no-op when D2H has already completed before the next forward launch (the common case)

Accuracy Tests
No changes to model forward code or kernel logic. The fix only adds GPU-side stream synchronization barriers.
Benchmarking and Profiling
Checklist
Review Process
`/tag-run-ci-label`, `/rerun-failed-ci`, `/tag-and-rerun-ci`