fix(eagle-v2): zero-fill draft KV on radix cache prefix hits to prevent NaN (#19897)
ashtonchew wants to merge 3 commits into sgl-project:main
Conversation
…nt NaN

Draft KV slots at cached-prefix positions could contain uninitialized data when `req_to_token_pool` maps to target-only KV. Zero-fill these slots before the draft prefill forward to prevent NaN propagation. Add a regression test for multi-turn and shared-prefix cache-hit scenarios.
Follow-up: zero-fill vs. draft KV recomputation

Zero-filling is the safe fix here: it stops the crash without touching the verify path. The tradeoff is roughly a 15% drop in accept length on cache-hit requests (~1.53 vs. ~1.8 on cold requests). The draft model still speculates using real KV for the new tokens, so this is far better than the reporter's workaround of disabling speculation entirely (33 tok/s, no speculation at all). The better long-term fix is to actually compute the draft KV for cached-prefix positions instead of zeroing them out. That is a larger change touching scheduler coordination and prefill compute budgeting, so I think it makes more sense to land this as the crash fix and track the recomputation as a separate issue. Let me know if you'd like me to pick that up once this merges, since it is a larger effort.
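The 15% figure above follows directly from the two quoted accept lengths; a quick check (values taken from the comment, nothing here is newly measured):

```python
# Sanity check on the accept-length numbers quoted above.
cold = 1.8        # avg accept length on cold (no cache hit) requests
cache_hit = 1.53  # avg accept length on cache-hit requests after zero-fill
drop = (cold - cache_hit) / cold
print(f"relative drop: {drop:.0%}")  # -> 15%
```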
Can you paste your reproducible script?
The correct fix has been merged.
Motivation
Closes #19796
The server crashes with `ValueError: Detected errors during sampling! NaN in the logits` when a request hits a radix cache prefix (cached_tokens > 0) under Eagle V2 speculative decoding. The crash is 100% reproducible: first requests complete normally, but any subsequent request that triggers a radix cache prefix hit causes all TP workers to crash simultaneously. The traceback goes through `forward_batch_generation` -> `verify` -> `detect_nan` in `eagle_worker_v2.py`.

Investigation shows that `req_to_token_pool` slot mappings are shared between the target and draft models, but the KV tensors themselves live in separate pools. When the radix cache matches a prefix, the cached token slots point to valid target KV data, but the corresponding draft KV slots remain uninitialized (they contain garbage). When draft prefill runs its forward pass and reads these uninitialized draft KV slots, NaN values propagate through attention into the logits. Without speculative decoding the server is fully stable, confirming the bug is isolated to the Eagle V2 draft path.
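The failure mode can be illustrated with a toy model (illustrative sketch only, not SGLang code; the pool size and slot numbers are made up): the slot mapping is shared, but only the target pool was ever written for the cached prefix, so the draft pool still holds whatever was in those slots.

```python
import math

POOL_SIZE = 8

# Shared slot mapping: request 0's prefix tokens live in slots 2..5.
req_to_token = {0: [2, 3, 4, 5]}

# Separate KV pools. Target prefill wrote real values into its pool;
# the draft pool was never written at those slots (NaN stands in for
# uninitialized memory here).
target_kv = [0.0] * POOL_SIZE
draft_kv = [math.nan] * POOL_SIZE
for slot in req_to_token[0]:
    target_kv[slot] = 1.0

# A radix-cache prefix hit reuses slots 2..5 without re-running draft
# prefill over the prefix, so draft attention reads garbage:
prefix_read = [draft_kv[s] for s in req_to_token[0]]
print(any(math.isnan(x) for x in prefix_read))  # -> True
```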
Originally reported on 8x RTX PRO 6000 Blackwell (SM120) with GLM-5-NVFP4-MTP, but the bug is architecture-independent since it stems from the draft/target KV pool separation logic.
Modifications
`python/sglang/srt/speculative/eagle_worker_v2.py`

Added three instance variables to `EagleDraftWorker.__init__()` to cache the KV pool layout detection result once per worker lifetime instead of checking `hasattr()` on every call:
- `_draft_kv_pool_layout_cached` (bool)
- `_draft_kv_has_split_buffers` (bool, for `k_buffer`/`v_buffer` layout)
- `_draft_kv_has_packed_buffer` (bool, for `kv_buffer` layout)

Added a `_zero_fill_draft_kv_for_cached_prefix(self, batch: ModelWorkerBatch)` method that:
- uses `batch.extend_prefix_lens` and `batch.req_pool_indices` to identify which requests have cached prefixes (`prefix_len > 0`)
- gathers the affected slots from `req_to_token_pool.req_to_token` using vectorized indexing (no `.tolist()` host roundtrips)
- applies `torch.unique` deduplication only when `bs > 1` (single-request batches cannot have cross-request slot overlap)
- handles both split (`k_buffer`/`v_buffer`) and packed (`kv_buffer`) pool layouts

Inserted the call to `_zero_fill_draft_kv_for_cached_prefix(batch)` inside `_draft_extend_for_prefill()`, placed after `batch.spec_info` is set and before `ForwardBatch.init_new()` and `self.draft_runner.forward()`, so that draft KV is clean before the draft forward pass reads from it.

`test/registered/spec/eagle/test_eagle_cache_hit_nan.py`

Registered as a nightly CI test:
`register_cuda_ci(est_time=300, suite="nightly-1-gpu", nightly=True)`

Server launch config:
- `SGLANG_ENABLE_SPEC_V2=True` via `envs.SGLANG_ENABLE_SPEC_V2.override(True)`
- `--speculative-algorithm EAGLE`
- `--speculative-draft-model` set to `DEFAULT_DRAFT_MODEL_EAGLE`
- `--speculative-num-steps 5`, `--speculative-eagle-topk 1`, `--speculative-num-draft-tokens 6`
- `--mem-fraction-static 0.7`
- `--enable-nan-detection` (triggers the crash path if the fix regresses)

Two test cases:
- `test_multiturn_cache_hit_no_nan`: uses `run_multiturn_cache_hit_test()` from `sglang.test.kits.cache_hit_kit` with 4 clients and 3 rounds. Asserts server liveness, `total_cached_tokens > 0`, and `cache_hit_rate > 0.15`.
- `test_shared_prefix_cache_hit_no_nan`: flushes the cache, then sends three requests with a shared long prefix and different suffixes. Asserts HTTP 200 for all, server liveness, and `cached_tokens > 0` for at least one overlap request.

Accuracy Tests
This fix does not change model outputs or the verify/accept path. The fix only writes zeros into draft KV at cached-prefix slots before draft prefill forward. Target KV, token allocation, and the verify path are completely unchanged.
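The zero-fill step can be sketched in simplified pure Python (hypothetical helper name and dict-based pool; the real `_zero_fill_draft_kv_for_cached_prefix` operates on batch tensors with vectorized CUDA indexing):

```python
def zero_fill_cached_prefix_slots(req_to_token, req_pool_indices,
                                  extend_prefix_lens, draft_pool):
    """Zero draft KV at every slot covered by a cached prefix.

    Illustrative model only; loops stand in for vectorized tensor ops.
    """
    if not any(pl > 0 for pl in extend_prefix_lens):
        return []  # fast path: no cached prefixes in this batch
    slots = []
    for req_idx, prefix_len in zip(req_pool_indices, extend_prefix_lens):
        if prefix_len > 0:
            slots.extend(req_to_token[req_idx][:prefix_len])
    if len(req_pool_indices) > 1:
        slots = sorted(set(slots))  # dedup only for multi-request batches
    # Handle both packed (kv_buffer) and split (k_buffer/v_buffer) layouts.
    buffers = (["kv_buffer"] if "kv_buffer" in draft_pool
               else ["k_buffer", "v_buffer"])
    for name in buffers:
        for s in slots:
            draft_pool[name][s] = 0.0
    return slots


# Example: one request whose first 4 tokens came from the radix cache.
pool = {"k_buffer": [float("nan")] * 8, "v_buffer": [float("nan")] * 8}
touched = zero_fill_cached_prefix_slots(
    req_to_token={0: [2, 3, 4, 5]},
    req_pool_indices=[0],
    extend_prefix_lens=[4],
    draft_pool=pool,
)
print(touched)              # -> [2, 3, 4, 5]
print(pool["k_buffer"][2])  # -> 0.0
```

Untouched slots stay as-is; only the slots reachable through a cached prefix are cleared, which is why the verify/accept path is unaffected.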
All testing was performed on a Vast.ai cloud instance after deploying the single-file patch to the pre-installed SGLang.
Test environment:
- `/sgl-workspace/sglang` (main branch)
- Target model: `meta-llama/Llama-3.1-8B-Instruct` (8B params, float16)
- Draft model: `lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B`

Test 1: Existing Eagle3 test suite (regression check)
Purpose: Verify the fix does not break existing speculative decoding functionality on the standard (no cache hit) code path.
Test file:
`test/registered/spec/eagle/test_eagle3_basic.py` -- runs an MMLU evaluation (64 examples, 32 threads) through the full Eagle3 speculative decode pipeline and asserts both accuracy (score >= 0.72) and speculation quality (avg accept length > 2.26).

Server configuration:
- `--speculative-algorithm=EAGLE3`
- `--speculative-num-steps=2`, `--speculative-eagle-topk=1`, `--speculative-num-draft-tokens=3`
- `--dtype=float16`, `--chunked-prefill-size=1024`, `--mem-fraction-static=0.7`

Command:
Result:
Conclusion: No regression. The MMLU score and accept length both exceeded their thresholds. The `any(pl > 0)` guard in the fix correctly short-circuits when `extend_prefix_lens` is all zeros.

Test 2: Radix cache hit reproduction (the actual bug scenario)
Purpose: Directly reproduce the NaN crash scenario from #19796 by sending requests with overlapping prefixes so the radix cache triggers `#cached-token > 0`.

A custom `test_cache_hit.py` script sends four sequential requests. On a cache hit (`#cached-token > 0`), draft attention reads KV at those cached positions; without the fix these slots are uninitialized -> NaN -> crash.

Server launch:
Full output:
Key observations from server logs:
- `#cached-token: 661` confirmed: 661 cached tokens, only 9 new tokens. Without the fix, this crashes with `NaN in the logits` during verify. With the fix, it completed successfully.
- A second hit with `#cached-token: 661` also completed successfully; the fix works across multiple sequential cache hits.
- `--enable-nan-detection` was active throughout; any NaN would have immediately crashed the server.

Summary of test results:
- Cache-hit requests (`#cached-token: 661`) -- the crash trigger -- completed without NaN.

What remains to be tested:
- MLA models (packed `kv_buffer` branch): Llama-3.1-8B uses MHA with separate `k_buffer`/`v_buffer`. Models using MLA (e.g., DeepSeek) would exercise the `kv_buffer` branch. The code path is straightforward (same zero-fill, different buffer name) but was not tested with an actual MLA model.

Additional validation:
- `python3 -m py_compile python/sglang/srt/speculative/eagle_worker_v2.py` passed
- Regression test (`test_eagle_cache_hit_nan.py`) added to guard against future regressions via CI

Benchmarking and Profiling
- The zero-fill runs on-device (no `.tolist()` or `.item()` calls on the hot path).
- `torch.unique` deduplication is applied only when `bs > 1`, so single-request batches (the common case for prefill) skip it entirely.
- The KV pool layout check is cached via `_draft_kv_pool_layout_cached`, avoiding repeated `hasattr()` calls.

Checklist