Skip to content

fix(eagle-v2): zero-fill draft KV on radix cache prefix hits to prevent NaN#19897

Closed
ashtonchew wants to merge 3 commits intosgl-project:mainfrom
ashtonchew:fix/eagle-v2-cache-hit-nan
Closed

fix(eagle-v2): zero-fill draft KV on radix cache prefix hits to prevent NaN#19897
ashtonchew wants to merge 3 commits intosgl-project:mainfrom
ashtonchew:fix/eagle-v2-cache-hit-nan

Conversation

@ashtonchew
Copy link
Copy Markdown
Contributor

Motivation

Closes #19796

The server crashes with ValueError: Detected errors during sampling! NaN in the logits when a request hits a radix cache prefix (cached_tokens > 0) in Eagle V2 speculative decoding. The crash is 100% reproducible: first requests complete normally, but any subsequent request that triggers a radix cache prefix hit causes all TP workers to crash simultaneously. The crash traceback goes through forward_batch_generation -> verify -> detect_nan in eagle_worker_v2.py.

In investigation, req_to_token_pool slot mappings are shared between target and draft models, but the KV tensors themselves are separate pools. When the radix cache matches a prefix, the cached token slots point to valid target KV data but the corresponding draft KV slots remain uninitialized (contain garbage data). When draft prefill runs its forward pass reading from these uninitialized draft KV slots, NaN values propagate through attention into the logits.

Without speculative decoding the server is fully stable, confirming the bug is isolated to the Eagle V2 draft path.

Originally reported on 8x RTX PRO 6000 Blackwell (SM120) with GLM-5-NVFP4-MTP, but the bug is architecture-independent since it stems from the draft/target KV pool separation logic.

Modifications

python/sglang/srt/speculative/eagle_worker_v2.py

  1. Added three instance variables to EagleDraftWorker.__init__() to cache the KV pool layout detection result once per worker lifetime instead of checking hasattr() on every call:

    • _draft_kv_pool_layout_cached (bool)
    • _draft_kv_has_split_buffers (bool, for k_buffer/v_buffer layout)
    • _draft_kv_has_packed_buffer (bool, for kv_buffer layout)
  2. Added _zero_fill_draft_kv_for_cached_prefix(self, batch: ModelWorkerBatch) method that:

    • Early-exits on idle batches, missing prefix lens, or zero-length inputs
    • Reads batch.extend_prefix_lens and batch.req_pool_indices to identify which requests have cached prefixes
    • Handles both tensor and list inputs for prefix_lens and req_pool_indices, moving them to the correct device
    • Clamps and filters to only requests with prefix_len > 0
    • Gathers slot indices from req_to_token_pool.req_to_token using vectorized indexing (no .tolist() host roundtrips)
    • Applies torch.unique deduplication only when bs > 1 (single-request batches cannot have cross-request slot overlap)
    • Zeros out all draft KV entries at the gathered slots, supporting both split (k_buffer/v_buffer) and packed (kv_buffer) pool layouts
  3. Inserted the call to _zero_fill_draft_kv_for_cached_prefix(batch) inside _draft_extend_for_prefill(), placed after batch.spec_info is set and before ForwardBatch.init_new() and self.draft_runner.forward() so that draft KV is clean before the draft forward pass reads from it.

test/registered/spec/eagle/test_eagle_cache_hit_nan.py

Registered as a nightly CI test: register_cuda_ci(est_time=300, suite="nightly-1-gpu", nightly=True)

Server launch config:

  • SGLANG_ENABLE_SPEC_V2=True via envs.SGLANG_ENABLE_SPEC_V2.override(True)
  • --speculative-algorithm EAGLE
  • --speculative-draft-model set to DEFAULT_DRAFT_MODEL_EAGLE
  • --speculative-num-steps 5, --speculative-eagle-topk 1, --speculative-num-draft-tokens 6
  • --mem-fraction-static 0.7
  • --enable-nan-detection (triggers the crash path if fix regresses)

Two test cases:

  • test_multiturn_cache_hit_no_nan: uses run_multiturn_cache_hit_test() from sglang.test.kits.cache_hit_kit with 4 clients and 3 rounds. Asserts server liveness, total_cached_tokens > 0, and cache_hit_rate > 0.15.
  • test_shared_prefix_cache_hit_no_nan: flushes cache, sends three requests with a shared long prefix and different suffixes. Asserts HTTP 200 for all, server liveness, and cached_tokens > 0 for at least one overlap request.

Accuracy Tests

This fix does not change model outputs or the verify/accept path. The fix only writes zeros into draft KV at cached-prefix slots before draft prefill forward. Target KV, token allocation, and the verify path are completely unchanged.

All testing was performed on a Vast.ai cloud instance after deploying the single-file patch to the pre-installed SGLang.

Test environment:

Component Details
Instance Vast.ai cloud GPU instance
GPUs 2x NVIDIA GeForce RTX 5090 (32 GB each, SM120 / Blackwell)
Driver NVIDIA 580.126.09
CUDA 13.0
Python 3.12.3
SGLang Source at /sgl-workspace/sglang (main branch)
Target model meta-llama/Llama-3.1-8B-Instruct (8B params, float16)
Draft model lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B
Tensor parallelism TP=2 (one model shard per GPU)

Test 1: Existing Eagle3 test suite (regression check)

Purpose: Verify the fix does not break existing speculative decoding functionality on the standard (no cache hit) code path.

Test file: test/registered/spec/eagle/test_eagle3_basic.py -- runs an MMLU evaluation (64 examples, 32 threads) through the full Eagle3 speculative decode pipeline and asserts both accuracy (score >= 0.72) and speculation quality (avg accept length > 2.26).

Server configuration:

  • --speculative-algorithm=EAGLE3
  • --speculative-num-steps=2, --speculative-eagle-topk=1, --speculative-num-draft-tokens=3
  • --dtype=float16, --chunked-prefill-size=1024, --mem-fraction-static=0.7

Command:

SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 python -m pytest test/registered/spec/eagle/test_eagle3_basic.py -x -v

Result:

test/registered/spec/eagle/test_eagle3_basic.py::TestEagle3Basic::test_mmlu PASSED [100%]
1 passed, 1 warning in 400.65s (0:06:40)

Conclusion: No regression. MMLU score and accept length both exceeded thresholds. The any(pl > 0) guard in the fix correctly short-circuits when extend_prefix_lens is all zeros.

Test 2: Radix cache hit reproduction (the actual bug scenario)

Purpose: Directly reproduce the NaN crash scenario from #19796, send requests with overlapping prefixes so the radix cache triggers #cached-token > 0.

A custom test_cache_hit.py script sends four sequential requests:

  1. Request A -- cold request, unique prompt, no cache hit. Baseline sanity check.
  2. Request B -- uses a long shared prefix (~670 tokens). Populates the radix cache.
  3. Request C -- same shared prefix + different suffix. This is the crash trigger. Radix cache sets #cached-token > 0, draft attention reads KV at those cached positions. Without the fix, these are uninitialized -> NaN -> crash.
  4. Request D -- another cache hit with the same prefix, different suffix. Verifies fix works across multiple cache-hit requests.

Server launch:

SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B-Instruct \
  --speculative-algorithm EAGLE3 \
  --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \
  --speculative-num-steps 2 --speculative-eagle-topk 1 --speculative-num-draft-tokens 3 \
  --dtype float16 --mem-fraction-static 0.6 --enable-nan-detection \
  --port 30000 --host 127.0.0.1 --tensor-parallel-size 2 --disable-cuda-graph

Full output:

============================================================
Eagle V2 Radix Cache Hit NaN Reproduction Test
============================================================
Server is ready!

--- Sending Request A (cold, no cache hit) ---
  Prompt length: ~61 chars
  Status: OK
  Output: The capital of France is Paris, which has a rich and fascinating history ...

Server alive after A: True
  Accept length: 1.5

--- Sending Request B (populates cache with shared prefix) ---
  Prompt length: ~4018 chars
  Status: OK
  Output: The Roman Empire, one of the most influential civilizations in human history ...

Server alive after B: True
  Accept length: 1.5076923076923077

--- Sending Request C (CACHE HIT - this is the crash trigger) ---
  Prompt length: ~4055 chars
  [Server log: Prefill batch, #new-seq: 1, #new-token: 9, #cached-token: 661, ...]
  [Server log: Decode batch, #running-req: 1, #token: 719, accept len: 1.55, ...]
  Status: OK
  Output: Ancient Egypt, one of the most fascinating and enduring civilizations ...

Server alive after C: True
  Accept length: 1.525

--- Sending Request D (another cache hit, extra verification) ---
  Prompt length: ~4055 chars
  [Server log: Prefill batch, #new-seq: 1, #new-token: 9, #cached-token: 661, ...]
  [Server log: Decode batch, #running-req: 1, #token: 719, accept len: 1.55, accept rate: 0.52, ...]
  Status: OK
  Output: How did their system of governance influence modern democracy? The ancient Greeks...

  Accept length: 1.53125

============================================================
ALL PASSED: No NaN crash on radix cache prefix hits!
============================================================

Key observations from server logs:

  • Request C (crash trigger): #cached-token: 661 confirmed. 661 cached tokens, only 9 new tokens. Without the fix, this crashes with NaN in the logits during verify. With the fix, completed successfully.
  • Request D (second cache hit): Also #cached-token: 661, completed successfully. Fix works across multiple sequential cache hits.
  • Accept length ~1.53: Slightly lower than typical (~1.8-2.0) because the draft model's KV for cached prefix positions is zero-filled rather than computed. This is expected -- draft quality only affects speculation, not final results.
  • No NaN detected: --enable-nan-detection active throughout. Any NaN would have immediately crashed the server.

Summary of test results:

Test What it validates Result
Eagle3 basic (MMLU) No regression on standard path, accuracy + accept rate thresholds PASSED
Cache hit Request A Cold request, no cache hit, baseline sanity PASSED
Cache hit Request B Populates radix cache with shared prefix PASSED
Cache hit Request C Cache hit (#cached-token: 661) -- the crash trigger PASSED
Cache hit Request D Second sequential cache hit -- extra verification PASSED

What remains to be tested:

  • Original bug report hardware (8x RTX PRO 6000 Blackwell 96GB, SM120, GLM-5-NVFP4-MTP with NEXTN): not available. However, the bug is architecture-level (uninitialized draft KV on shared slot indices), not hardware-specific. Validated on same GPU generation (RTX 5090, also SM120/Blackwell) with a different model.
  • MLA KV pool layout (kv_buffer branch): Llama-3.1-8B uses MHA with separate k_buffer/v_buffer. Models using MLA (e.g., DeepSeek) would exercise the kv_buffer branch. Code path is straightforward (same zero-fill, different buffer name) but was not tested with an actual MLA model.
  • Concurrent/batched cache-hit requests: test sent requests sequentially. The code iterates per-request with independent slot lists so correctness is expected.

Additional validation:

  • python3 -m py_compile python/sglang/srt/speculative/eagle_worker_v2.py passed
  • Pre-commit hooks (isort, ruff, black, codespell) all passed on commit
  • Nightly regression test (test_eagle_cache_hit_nan.py) added to guard against future regressions via CI

Benchmarking and Profiling

  • The zero-fill operation is O(cached_slots * num_draft_layers) per prefill batch, which is negligible compared to the draft model forward pass cost.
  • Slot collection uses device-native tensor indexing with no host-device synchronization (no .tolist() or .item() calls on the hot path).
  • torch.unique deduplication is applied only when bs > 1, so single-request batches (the common case for prefill) skip it entirely.
  • KV pool layout detection is cached once per worker lifetime via _draft_kv_pool_layout_cached, avoiding repeated hasattr() calls.
  • No mandatory micro-benchmark is required: the overhead is bounded by a single indexed tensor write per layer, which is orders of magnitude cheaper than the draft forward pass.

Checklist

…nt NaN

Draft KV slots at cached-prefix positions could contain uninitialized
data when req_to_token_pool maps to target-only KV. Zero-fill these
slots before draft prefill forward to prevent NaN propagation.

Add regression test for multi-turn and shared-prefix cache hit scenarios.
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@ashtonchew
Copy link
Copy Markdown
Contributor Author

Follow-up: zero-fill vs draft KV recomputation

Zero-filling is the safe fix here. It stops the crash without touching the verify path. The tradeoff is about a 15% drop in accept rate on cache-hit requests (1.53 vs ~1.8 on cold requests). The draft model still speculates using real KV for the new tokens, so this is way better than the reporter's workaround of just turning off speculation entirely (33 tok/s, no speculation at all).

The better long-term fix is to actually compute the draft KV for cached-prefix positions instead of zeroing them out. That's a bigger change that touches scheduler coordination and prefill compute budgeting so I think it makes more sense to land this as the crash fix and track the recomputation as a separate issue. Let me know if you'd like me to work on that once this merges because it is a larger effort.

@kpham-sgl
Copy link
Copy Markdown
Collaborator

Can you paste your reproducible script test_cache_hit.py in a gist. I cannot replicate this issue on B200 using your server launch script.

SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B-Instruct \
  --speculative-algorithm EAGLE3 \
  --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \
  --speculative-num-steps 2 --speculative-eagle-topk 1 --speculative-num-draft-tokens 3 \
  --dtype float16 --mem-fraction-static 0.6 --enable-nan-detection \
  --port 30000 --host 127.0.0.1 --tensor-parallel-size 2 --disable-cuda-graph

@b8zhong
Copy link
Copy Markdown
Collaborator

b8zhong commented Mar 28, 2026

Correct fix has been merged

@b8zhong b8zhong closed this Mar 28, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] Eagle V2 speculative decoding crashes with NaN in logits when radix cache prefix hit occurs (SM120 / 8 * RTX PRO 6000 Blackwell)

3 participants