NaN in hidden states with modelopt_fp4 quantization under concurrent load (GLM-5-NVFP4-MTP on Blackwell) #20043

@voipmonitor

Summary

GLM-5-NVFP4-MTP with --quantization modelopt_fp4 produces NaN values in model hidden states under concurrent request load, causing torch.multinomial to crash with "probability tensor contains either `inf`, `nan` or element < 0".

This is not related to speculative decoding, radix cache, or any specific attention backend — the NaN originates directly from the model's transformer layers.

Environment

  • Hardware: 8× NVIDIA RTX PRO 6000 Blackwell Server Edition (SM120, 96GB each)
  • Model: GLM-5-NVFP4-MTP (--quantization modelopt_fp4)
  • SGLang: Built from main (commit 346a4131cfbac5)
  • PyTorch: 2.7+ (nightly, Blackwell support)
  • Docker: Custom image based on nvcr.io/nvidia/pytorch:25.03-py3
  • CUDA: 12.8+

Launch command

python3 -m sglang.launch_server \
  --model-path /mnt/GLM-5-NVFP4-MTP \
  --tp 8 \
  --trust-remote-code \
  --attention-backend flashinfer \
  --moe-runner-backend cutlass \
  --kv-cache-dtype bf16 \
  --quantization modelopt_fp4 \
  --disable-custom-all-reduce \
  --enable-flashinfer-allreduce-fusion \
  --mem-fraction-static 0.88 \
  --cuda-graph-max-bs 32 \
  --host 0.0.0.0 --port 5000 \
  --served-model-name glm-5 \
  --max-running-requests 64 \
  --enable-metrics

Reproducer

The bug requires concurrent requests — serial requests don't trigger it. Save this as test_nan_concurrent.py:

"""
Concurrent reproducer for NaN crash with modelopt_fp4 quantization.

Sends bursts of 64 concurrent requests with a shared prefix.
Typically crashes within 2-5 rounds without --enable-nan-detection.

Usage:
    python3 test_nan_concurrent.py --url http://localhost:5000 --concurrency 64
"""

import argparse, sys, time, requests
from concurrent.futures import ThreadPoolExecutor, as_completed

def send(base, prefix, suffix, model, i):
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prefix + suffix}],
        "max_tokens": 64,
        "temperature": 0.7,
    }
    try:
        r = requests.post(f"{base}/v1/chat/completions", json=payload, timeout=180)
        r.raise_for_status()
        return i, True, ""
    except Exception as e:
        return i, False, str(e)[:200]

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", default="http://localhost:5000")
    parser.add_argument("--concurrency", type=int, default=64)
    parser.add_argument("--rounds", type=int, default=0, help="0=infinite")
    parser.add_argument("--model", default="glm-5")
    parser.add_argument("--prefix-repeat", type=int, default=60)
    args = parser.parse_args()

    prefix = (
        "You are an expert assistant specializing in science, technology, "
        "and history. You provide detailed, accurate, and well-structured "
        "answers. Always cite relevant facts and provide context. "
    ) * args.prefix_repeat

    suffixes = [
        "Explain the theory of general relativity in simple terms.",
        "What caused the fall of the Roman Empire?",
        "How do mRNA vaccines work?",
        "Describe the process of nuclear fusion in stars.",
        "What are the main differences between TCP and UDP?",
        "Explain quantum entanglement to a 10-year-old.",
        "What is the significance of the Turing test?",
        "How does photosynthesis work at the molecular level?",
    ]

    # Flush cache and seed
    try:
        requests.post(f"{args.url}/flush_cache", timeout=30)
    except requests.RequestException:
        pass
    time.sleep(0.5)
    send(args.url, prefix, "Say hello.", args.model, 0)
    print(f"Running bursts of {args.concurrency} concurrent requests...")

    rnd = 0
    max_rounds = args.rounds if args.rounds > 0 else float("inf")
    t0 = time.monotonic()

    while rnd < max_rounds:
        rnd += 1
        with ThreadPoolExecutor(max_workers=args.concurrency) as pool:
            futs = [
                pool.submit(send, args.url, prefix, suffixes[i % len(suffixes)], args.model, i)
                for i in range(args.concurrency)
            ]
            fails = sum(1 for f in as_completed(futs) if not f.result()[1])
        elapsed = time.monotonic() - t0
        if fails:
            print(f"Round {rnd}: {fails}/{args.concurrency} FAILED after {elapsed:.0f}s")
            print("Check server logs for NaN tracebacks.")
            sys.exit(1)
        if rnd % 10 == 0:
            print(f"Round {rnd}: OK ({elapsed:.0f}s)")

    print(f"Completed {rnd} rounds without crash in {time.monotonic()-t0:.0f}s")

if __name__ == "__main__":
    main()

Run: python3 test_nan_concurrent.py --url http://localhost:5000 --concurrency 64

Crash traceback

/pytorch/aten/src/ATen/native/cuda/TensorCompare.cu:109: _assert_async_cuda_kernel: block: [0,0,0], thread: [0,0,0]
Assertion `probability tensor contains either `inf`, `nan` or element < 0` failed.

Traceback (most recent call last):
  File "scheduler.py", line 1211, in event_loop_overlap
    batch_result = self.run_batch(batch)
  File "scheduler.py", line 2370, in run_batch
    batch_result = self.model_worker.forward_batch_generation(...)
  File "tp_worker.py", line 500, in forward_batch_generation
    batch_result.next_token_ids = self.model_runner.sample(...)
  File "model_runner.py", line 2570, in sample
    next_token_ids = self.sampler(...)
  File "sampler.py", line 161, in forward
    batch_next_token_ids = self._sample_from_probs(...)
  File "sampler.py", line 202, in _sample_from_probs
    batch_next_token_ids = sampling_from_probs_torch(...)
  File "sampler.py", line 600, in sampling_from_probs_torch
    sampled_index = torch.multinomial(probs, num_samples=1)
torch.AcceleratorError: CUDA error: device-side assert triggered

Root cause analysis

We instrumented the code with synchronous NaN checks at multiple stages to trace exactly where the NaN originates:

1. NaN comes from model hidden_states, NOT from sampling/softmax

We added a NaN check in DeepseekV2ForCausalLM.forward() (which GlmMoeDsaForCausalLM inherits) before logits_processor:

!!! DS_MODEL hidden_states NaN before logits_processor: 1007616/7403520 elements, shape=torch.Size([1205, 6144])

~13.6% of hidden state elements are NaN in affected batches. This is not a single-token numerical edge case — entire sequences within the batch have NaN hidden states.
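A minimal sketch of that instrumentation (the helper name and exact placement inside sglang's forward() are ours, paraphrased from what we ran):

import torch

def report_nan(hidden_states: torch.Tensor, tag: str = "DS_MODEL") -> None:
    # Synchronous check: .item() forces a device sync, so the report is
    # attributed to the batch that actually produced the NaN instead of
    # surfacing later as a generic device-side assert.
    nan_count = int(torch.isnan(hidden_states).sum().item())
    if nan_count:
        print(f"!!! {tag} hidden_states NaN before logits_processor: "
              f"{nan_count}/{hidden_states.numel()} elements, "
              f"shape={hidden_states.shape}")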

2. NaN propagates through the pipeline

hidden_states (NaN) → lm_head matmul → logits (NaN) → softmax → probs (NaN) → multinomial → CRASH
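A standalone toy example (independent of sglang) showing how a single NaN logit poisons the whole sampling step:

import torch

logits = torch.zeros(1, 8)
logits[0, 3] = float("nan")             # one poisoned logit
probs = torch.softmax(logits, dim=-1)   # NaN spreads across the entire row
try:
    # On CPU this raises RuntimeError immediately; on CUDA the same
    # condition fires the device-side assert shown in the traceback above.
    torch.multinomial(probs, num_samples=1)
except RuntimeError as e:
    print("multinomial failed:", e)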

3. It's NOT related to speculative decoding

Crash reproduces identically without any speculative decoding args (--speculative-eagle-model etc. removed).

4. It's NOT related to radix cache

Crash reproduces with --disable-radix-cache.

5. High concurrency is required

  • Serial requests: 50+ rounds, no crash
  • 8 concurrent: Crashes after ~5 runs of the reproducer
  • 64 concurrent: Crashes within 2-5 rounds consistently

This suggests the NaN may be related to batch size / memory access patterns rather than a cache race condition.
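To locate the concurrency threshold more precisely, the reproducer can be driven at increasing concurrency levels (a hypothetical sweep script, assuming the reproducer is saved as test_nan_concurrent.py and the server is running):

import subprocess, sys

for conc in (1, 2, 4, 8, 16, 32, 64):
    print(f"--- concurrency={conc} ---")
    rc = subprocess.call([
        sys.executable, "test_nan_concurrent.py",
        "--url", "http://localhost:5000",
        "--concurrency", str(conc),
        "--rounds", "10",   # bounded rounds so passing levels terminate
    ])
    if rc != 0:
        # The reproducer exits with status 1 on the first failed round.
        print(f"First failure at concurrency={conc}")
        break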

6. --enable-nan-detection is an effective workaround

With --enable-nan-detection, the sampler replaces NaN logits with -1e5 (effectively zero probability after softmax). The affected tokens are simply not selected.
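Functionally, the mitigation amounts to something like the following (our paraphrase of the observed behavior, not the exact sglang code path):

import torch

def sanitize_logits(logits: torch.Tensor) -> torch.Tensor:
    # NaN logits become a large negative value; after softmax those
    # vocabulary entries get ~0 probability and are never sampled.
    return torch.where(torch.isnan(logits),
                       torch.full_like(logits, -1e5),
                       logits)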

Stress test result with --enable-nan-detection:

  • 200 rounds × 64 concurrent requests = 12,800 requests over 21 minutes — ZERO crashes
  • 3,136 NaN detections were logged during this run (NaN still occurs but is handled gracefully)

What we investigated (not the cause)

Hypothesis                       | Result
---------------------------------|------------------------------------------------------------------------
Speculative decoding (EAGLE-v2)  | ❌ Crashes without spec decoding
Radix cache prefix hit race      | ❌ Crashes with --disable-radix-cache
Softmax numerical instability    | ❌ NaN already present before softmax
Temperature division by zero     | ❌ Temperature is 0.7 (valid)
top_k/top_p renormalization      | ❌ NaN present before these are called
SGLANG_SPEC_NAN_DETECTION=1      | ⚠️ Uses torch._assert_async; same generic CUDA assert message, doesn't isolate the source

Likely root cause

The NaN originates in the transformer layers of the model with modelopt_fp4 quantization. With ~13.6% of hidden state elements NaN in affected batches, this points to:

  1. FP4 dequantization numerical instability under specific input patterns that only manifest with larger batch sizes
  2. Possible race condition in quantized weight access when multiple sequences are processed concurrently
  3. Kernel-level issue in the cutlass MoE runner or attention computation with FP4 weights
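To narrow the source down to a specific layer, a generic forward-hook tripwire could be installed on the model (a hypothetical debugging helper, not part of sglang):

import torch

def install_nan_hooks(model: torch.nn.Module) -> None:
    # Flags every module whose output contains NaN; within one forward
    # pass, the first module printed is the earliest offender.
    def make_hook(name):
        def hook(module, inputs, output):
            out = output[0] if isinstance(output, tuple) else output
            if torch.is_tensor(out) and torch.isnan(out).any():
                print(f"NaN in output of module: {name}")
        return hook
    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))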

Workaround

Add --enable-nan-detection to the launch command. This replaces NaN logits with -1e5 before sampling, preventing the crash. Quality impact appears minimal since the NaN tokens receive ~0 probability and valid tokens are selected instead.
