
[HiCache] Fix illegal memory access caused by stream race between FA3 forward and D2H transfer #20611

Open

Weili-0234 wants to merge 3 commits into sgl-project:main from Weili-0234:fix/hicache-fa3-stream-sync

Conversation

@Weili-0234 commented Mar 15, 2026

Motivation

Fix two CUDA stream race conditions that cause illegal memory access (IMA) in the FA3 persistent kernel when HiCache's write_through policy is active under eviction pressure on Hopper GPUs. The bug is triggered when --hicache-write-policy write_through (the default) causes D2H transfers on write_stream to race with FA3 on forward_stream.

This is a long-standing issue reported across multiple models and configurations (#19737, #14363, #11842, #13912, #18785, #18166; many were closed due to inactivity, but the bug they reported remains). As #14363 noted, "it is very difficult to reproduce", since the crash requires sufficient eviction pressure (usually from high request concurrency) to trigger the race. I found that setting --mem-fraction-static 0.40 reliably reproduces the crash within minutes by forcing frequent eviction and D2H transfers.

Together with #20049 (lock_ref leak fix), this PR completes a set of three fixes that address the HiCache + FA3 crash:

Fix | Description | Location
write_stream.wait_stream(forward_stream) | D2H must wait for FA3 to finish writing KV data | cache_controller.py (this PR)
forward_stream.wait_event(last_write_finish_event) | FA3 must wait for D2H SM copy kernel to finish | scheduler.py (this PR)
flush_write_through_acks() in decode path | Release lock_ref during decode to prevent OOM | #20049 (independent PR)

Root cause: HiCache's write_stream (D2H transfer) and the scheduler's forward_stream (FA3 kernel) lack bidirectional GPU-side synchronization. This creates two independent race windows:

  1. D2H reads before forward finishes writing: write_stream starts reading KV cache slots that FA3 on forward_stream is still writing to. The original code only synchronized with the schedule stream via start_event, missing the forward_stream entirely. Fix: write_stream.wait_stream(forward_stream) before D2H.

  2. Forward launches before D2H SM copy finishes: The next batch's FA3 persistent kernel (occupying all 132 SMs) launches on forward_stream while the previous batch's HiCache SM copy kernel (kernel IO backend) is still running on write_stream. The concurrent SM resource conflict corrupts FA3's TMA descriptor uniform registers, causing CUDBG_EXCEPTION_WARP_ILLEGAL_ADDRESS. Fix: forward_stream.wait_event(last_write_finish_event) before forward.

The diagram below shows the three CUDA streams in SGLang's overlap scheduler and where each fix adds synchronization:

[Figure: timeline of the DEFAULT, FORWARD, and WRITE CUDA streams across iterations, showing where each fix adds synchronization]

SGLang's overlap scheduler uses three CUDA streams: DEFAULT (scheduling/memory ops), FORWARD (FA3 + MLP forward), and WRITE (HiCache D2H transfer). Each iteration, the CPU submits run_batch(N) to FORWARD, then pop_and_process(N-1) triggers D2H on WRITE for the previous batch's KV data. Without synchronization:

  • WRITE's D2H may start reading KV data that FORWARD's FA3 hasn't finished writing → write_stream.wait_stream(forward_stream) ensures D2H waits
  • FORWARD's next FA3 may launch while WRITE's SM copy kernel is still running → forward_stream.wait_event(last_write_finish_event) ensures FA3 waits

With both fixes, FA3 and D2H SM copy kernel never overlap on the GPU, eliminating the TMA descriptor corruption.
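
For concreteness, the two fences can be sketched with plain PyTorch stream/event primitives. This is a minimal sketch, not the actual SGLang code: the stream and event names mirror the description above, while the helper functions and the D2H copy itself are illustrative assumptions.

import torch

# The three actors from the diagram; the DEFAULT (schedule) stream is
# whatever torch.cuda.current_stream() returns here.
forward_stream = torch.cuda.Stream()         # FA3 + MLP forward
write_stream = torch.cuda.Stream()           # HiCache D2H transfer
last_write_finish_event = torch.cuda.Event()

def start_d2h_sketch(kv_slice):
    # Fix 1: D2H must not read KV slots that FA3 is still writing.
    write_stream.wait_stream(torch.cuda.current_stream())  # schedule stream
    write_stream.wait_stream(forward_stream)               # producer stream
    with torch.cuda.stream(write_stream):
        # Pinned host memory is assumed for a truly asynchronous copy.
        host_copy = kv_slice.to("cpu", non_blocking=True)
        last_write_finish_event.record(write_stream)
    return host_copy

def run_batch_sketch(forward_fn, batch):
    # Fix 2: the next FA3 launch waits for the previous D2H copy kernel.
    # This is a GPU-side fence; the CPU thread does not block.
    forward_stream.wait_event(last_write_finish_event)
    with torch.cuda.stream(forward_stream):
        return forward_fn(batch)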

Fixes #19737, fixes #18785, fixes #14363
Related issues: #11842, #13912, #18166
Related PR: #20049 (independent fix for write-through lock_ref leak during decode)

Note: #19737's Workaround 2 (current_stream().wait_event(finish_event)) fixes the same race but blocks the CPU thread synchronously. This PR uses forward_stream.wait_event() instead, which is a GPU-side fence that preserves overlap scheduling performance.

Modifications

cache_controller.py:

  • Add producer_stream and last_write_finish_event attributes to HiCacheController
  • Add set_producer_stream() method for scheduler to register the forward stream
  • In start_writing(): replace start_event.record() + start_event.wait() pattern with explicit write_stream.wait_stream(current_stream) + write_stream.wait_stream(producer_stream), ensuring D2H waits for both schedule and forward streams
  • Expose last_write_finish_event after D2H finish event is recorded

scheduler.py:

  • In HiRadixCache initialization: call set_producer_stream(self.forward_stream) to register the forward stream with the cache controller
  • In overlap path (run_batch): add forward_stream.wait_event(last_write_finish_event) before forward_batch_generation, ensuring the forward kernel doesn't launch until the previous D2H SM copy kernel completes (both files' changes are sketched below)
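
A hedged sketch of how the two files might wire together, following the bullets above. Class internals, attribute paths, and the _launch_d2h_copy helper are illustrative assumptions, not the actual diff:

import torch

class HiCacheController:
    def __init__(self, write_stream):
        self.write_stream = write_stream
        self.producer_stream = None          # registered by the scheduler
        self.last_write_finish_event = None  # exposed after each D2H

    def set_producer_stream(self, stream):
        # The scheduler registers its forward_stream here.
        self.producer_stream = stream

    def start_writing(self, device_indices):
        # Replaces the old start_event.record() + start_event.wait() pattern:
        # wait for both the schedule stream and the forward (producer) stream.
        self.write_stream.wait_stream(torch.cuda.current_stream())
        if self.producer_stream is not None:
            self.write_stream.wait_stream(self.producer_stream)
        with torch.cuda.stream(self.write_stream):
            self._launch_d2h_copy(device_indices)  # hypothetical helper
            event = torch.cuda.Event()
            event.record(self.write_stream)
        self.last_write_finish_event = event

# Scheduler side (illustrative attribute paths):
#   on HiRadixCache init:
#       cache_controller.set_producer_stream(self.forward_stream)
#   in the overlap run_batch path, before forward_batch_generation:
#       ev = cache_controller.last_write_finish_event
#       if ev is not None:
#           self.forward_stream.wait_event(ev)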

Debugging with CUDA Coredump Analysis

We collected CUDA coredumps using SGLANG_CUDA_COREDUMP=1 from two crash scenarios on 8xH100 serving MiniMax-M2.5 (456B MoE, TP8 EP8) with --mem-fraction-static 0.40 --enable-hierarchical-cache --hicache-io-backend kernel --attention-backend fa3. All experiments had #20049 (lock_ref fix) applied to prevent OOM during sustained decode. In both cases the crash kernel is FlashAttnFwdSm90 — the FA3 persistent kernel, Grid=(132,1,1) occupying all 132 SMs on H100.

Crash A: with #20049 only (no stream sync), crashed at 2.2 min / 48 waves:

  • Crash instruction: LDGSTS.E.BYPASS.LTC128B.128 desc[UR40] in FA3 mainloop (loading KV data via TMA)
  • UR40 = 0x0, UR41 = 0x0 (null TMA descriptors)
  • Root cause: write_stream reads KV slots before forward_stream (FA3) finishes writing them

Crash B: with #20049 + write_stream.wait_stream(forward_stream), crashed at 20 min / 427 waves:

  • Crash instruction: STG.E desc[UR40][R4.64],R3 in FA3 epilogue (writing attention output via TMA)
  • UR40 = 0x0, UR41 = 0x0 (null TMA descriptors), R3 = 0xFF800000 (-inf, softmax LSE init value)
  • Two independent warps on different SMs (SM 4 and SM 6) both triggered the error — rules out hardware fault, confirms systematic TMA descriptor corruption
  • Root cause: HiCache SM copy kernel (on write_stream) and FA3 persistent kernel (on forward_stream) run concurrently, causing SM resource conflict that corrupts FA3's uniform registers

The different crash sites (mainloop LDGSTS vs epilogue STG) with the same symptom (UR40=0x0 TMA descriptor corruption) confirm two independent race conditions. The first fix eliminates the read-before-write race, exposing the SM resource conflict race which crashes at a different instruction.

CUDA coredump files and cuda-gdb analysis commands

Full cuda-gdb analysis output: cuda-gdb-analysis.txt
Stress test reproduction script: repro-sglang-cuda-crash.py

Coredump files:

To analyze with cuda-gdb (requires CUDA Toolkit 12.x):

cuda-gdb -batch \
  -ex "target cudacore cuda_coredump_5ac27cf368dd.175.1773213005" \
  -ex "info cuda kernels" \
  -ex "info cuda threads" \
  -ex "cuda kernel 0" \
  -ex "x/i \$pc" \
  -ex "info registers UR40 UR41 R3 R4 R5"

To reproduce and generate your own coredump:

# 1. Apply PR #20049 first (prevents OOM that masks the stream race)

# 2. Start server with tight cache + coredump collection
SGLANG_CUDA_COREDUMP=1 PYTHONUNBUFFERED=1 python3 -m sglang.launch_server \
  --model-path <any-model> \
  --enable-hierarchical-cache \
  --mem-fraction-static 0.40 \
  --hicache-io-backend kernel \
  --attention-backend fa3 \
  --port 8000

# 3. Send concurrent requests (crash within minutes on Hopper SXM)
python3 -m sglang.bench_serving \
  --backend sglang \
  --port 8000 \
  --dataset-name random \
  --random-input 1024 \
  --random-output 256 \
  --num-prompts 512 \
  --request-rate 10

# 4. Coredump saved to logs/cuda-coredumps/ after crash

Stress test script: repro-sglang-cuda-crash.py — sends concurrent long-prompt requests to trigger eviction pressure

Stress Test with High Request Concurrency

Validated on 8xH100 SXM serving MiniMax-M2.5 (456B MoE). All experiments had #20049 (lock_ref fix) applied.

Command

python3 -m sglang.launch_server \
  --model-path MiniMaxAI/MiniMax-M2.5 \
  --tp 8 --ep 8 \
  --mem-fraction-static 0.40 \
  --enable-hierarchical-cache \
  --hicache-ratio 2.0 \
  --hicache-size 64 \
  --hicache-write-policy write_through \
  --hicache-io-backend kernel \
  --attention-backend fa3 \
  --decode-attention-backend flashinfer \
  --chunked-prefill-size 8192 \
  --page-size 64 \
  --enable-cache-report \
  --trust-remote-code \
  --port 8000

Client:

# Sustained stress test: 64 concurrent requests per wave, looping for 120 min
python3 repro_sustained.py single \
  --url http://localhost:8000 \
  --n-requests 64 \
  --output-dir ./results
# Wrapped in a loop externally with 120-min duration target

Scripts: repro_sustained.py (sustained multi-turn stress test), repro-sglang-cuda-crash.py (single-shot crash reproduction).

Experiment | This PR's fixes applied | Survival duration | Total requests | Result
A (baseline) | neither | 5.6 min | 768 OK / 64 err | CRASH
B | write_stream.wait_stream only | 20 min | ~27k OK | CRASH
C (this PR) | both fixes | 120 min | 163,008 OK / 0 err | SURVIVED
  • --mem-fraction-static 0.40 forces tight GPU memory, keeping token usage at 0.83 with eviction active throughout
  • Throughput comparison (experiment B vs C, same workload): ~1,350 req/min in both — zero measurable regression from wait_event
  • wait_event is a GPU-side no-op when the D2H has already completed before the next forward launch (the common case); the rough micro-benchmark below illustrates this
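
A rough, standalone micro-benchmark (not from the PR) to sanity-check that an enqueued wait on an already-signaled event is essentially free:

import time
import torch

s = torch.cuda.Stream()
ev = torch.cuda.Event()
ev.record()               # record on the current stream
torch.cuda.synchronize()  # make sure the event has already fired

t0 = time.perf_counter()
for _ in range(1000):
    s.wait_event(ev)      # enqueue a wait on an already-signaled event
torch.cuda.synchronize()
print(f"avg wait_event enqueue cost: {(time.perf_counter() - t0) / 1000 * 1e6:.2f} us")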

Accuracy Tests

No changes to model forward code or kernel logic. The fix only adds GPU-side stream synchronization barriers.

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

…ransfer

Ensure write_stream waits for both schedule and forward streams before
reading KV data for D2H transfer. Without this, write_stream may read
KV slots that FA3 on forward_stream is still writing to.

Also expose last_write_finish_event so scheduler can fence forward_stream
before the next forward kernel launch.

Register forward_stream as the producer stream for HiCache controller,
and wait on last_write_finish_event in the overlap path before launching
forward_batch_generation. This prevents FA3 persistent kernel and HiCache
SM copy kernel from running concurrently on the GPU.
@gemini-code-assist (Contributor)

Summary of Changes


This pull request significantly enhances the stability of the SGLang serving system, particularly when utilizing HiCache with FA3 on Hopper GPUs. By meticulously addressing two critical CUDA stream race conditions through precise GPU-side synchronization, it eliminates illegal memory access errors and system crashes that previously occurred under high eviction pressure. The changes ensure that data transfers and kernel executions are correctly ordered, preventing resource conflicts and data corruption without introducing performance overhead.

Highlights

  • CUDA Stream Synchronization: Implemented bidirectional GPU-side synchronization between HiCache's write_stream (D2H transfer) and the scheduler's forward_stream (FA3 kernel) to resolve two distinct CUDA stream race conditions.
  • D2H Read-Before-Write Prevention: Introduced write_stream.wait_stream(forward_stream) to ensure D2H transfers wait for FA3 to complete KV data writes, preventing read-before-write issues and associated illegal memory access.
  • SM Resource Conflict Resolution: Added forward_stream.wait_event(last_write_finish_event) to ensure the FA3 kernel waits for the previous D2H SM copy kernel to finish, preventing SM resource conflicts and TMA descriptor corruption on Hopper GPUs.
  • Stability and Performance Validation: Verified the fixes through extensive stress testing on Hopper GPUs, achieving 120 minutes of stable operation under high memory pressure, eliminating crashes observed in previous configurations, with no measurable performance regression.


Changelog
  • python/sglang/srt/managers/cache_controller.py
    • Added producer_stream and last_write_finish_event attributes to HiCacheController for managing stream synchronization.
    • Implemented set_producer_stream to allow the scheduler to register the forward stream with the cache controller.
    • Modified start_writing to include write_stream.wait_stream(current_stream) and write_stream.wait_stream(producer_stream), ensuring D2H transfers wait for both the schedule and forward streams.
    • Exposed last_write_finish_event to enable external synchronization with the D2H completion.
  • python/sglang/srt/managers/scheduler.py
    • Configured HiRadixCache initialization to call set_producer_stream with self.forward_stream, linking the forward stream to the cache controller.
    • Integrated forward_stream.wait_event(last_write_finish_event) into the run_batch overlap path, ensuring the forward kernel waits for HiCache D2H transfers to complete before launching.
Activity
  • Identified and resolved two distinct CUDA stream race conditions causing illegal memory access (IMA) in FA3 persistent kernels on Hopper GPUs.
  • Utilized CUDA coredump analysis to diagnose TMA descriptor corruption, confirming two independent race conditions with different crash sites (mainloop LDGSTS vs epilogue STG).
  • Conducted rigorous stress testing on 8xH100 SXM, demonstrating a significant improvement in stability from crashing within minutes to surviving 120 minutes without errors under high memory pressure.
  • Confirmed that the implemented GPU-side synchronization (wait_event) introduces no measurable performance regression.
  • Provided detailed reproduction steps, coredump analysis commands, and stress test scripts for further validation.


@gemini-code-assist (Bot) left a comment


Code Review

This pull request addresses race conditions between HiCache's D2H transfers and the FA3 attention kernel by adding two CUDA stream synchronization points. The first change makes the write_stream wait for the forward_stream to prevent reading incomplete KV cache data. The second change ensures the forward_stream waits for the previous D2H transfer to finish, avoiding SM resource conflicts. The implementation in cache_controller.py and scheduler.py correctly introduces these synchronizations using wait_stream and wait_event. The logic appears sound and should resolve the reported illegal memory access issues.

@Weili-0234 changed the title from "Fix/hicache fa3 stream sync" to "[HiCache] Fix illegal memory access caused by stream race between FA3 forward and D2H transfer" on Mar 15, 2026
lawrence-harmonic added a commit to lawrence-harmonic/sglang that referenced this pull request Mar 19, 2026
@xiezhq-hermann (Collaborator)

Thank you very much for looking into this.
We have done similar profiling and reached similar findings. It appears that the TMA registers are somehow being corrupted by concurrent execution, which is unexpected and likely caused by undefined behavior in the driver or hardware; please let us know if anyone has a clearer explanation.
We also came up with a workaround by simply using cudaMemcpy for writes, though we have not merged that solution yet. This PR is excellent, but I think it should be limited to HiCache (kernel) + FA3. Otherwise it may introduce performance regressions by serializing execution and I/O activity.
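
For example, the extra fences could be gated on the exact configuration that exhibits the race. This is a hypothetical condition, with attribute names mirroring the CLI flags used in this PR's repro rather than a committed SGLang API:

def needs_stream_fence(server_args):
    # Only fence for HiCache kernel IO backend + FA3, so other
    # configurations keep fully overlapped execution and I/O.
    return (
        getattr(server_args, "hicache_io_backend", None) == "kernel"
        and getattr(server_args, "attention_backend", None) == "fa3"
    )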

@CSWYF3634076 (Contributor)

@Weili-0234 @xiezhq-hermann Is there a fix PR available currently?

@lawrence-harmonic (Contributor)

> It appears that the TMA registers are somehow being corrupted by concurrent execution, which is unexpected and is likely caused by undefined behavior in the driver or hardware.

This was unexpected to me as well and I could not find an explanation. I did observe:

It seems that -DCUTLASS_ENABLE_GDC_FOR_SM90 (from #18756) makes the issue go away.

The only change in the SASS was the addition of an ACQBULK instruction, which is not documented by NVIDIA AFAICT.

@Zhangmj0621

It seems that the problem described in this PR hasn't been fixed yet?

