
NSA FA3 crash with DP Attention on padded speculative batches #24233

@junliu-mde

Description

Summary

NSA FA3 can crash when DP Attention pads a local batch during speculative decoding.

The failure happens because the FA3 kernel receives padded Q rows, while the KV metadata still describes the real, unpadded batch, so the kernel rejects the call with "batch_size must be equal to batch_size_k".
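A minimal sketch of the mismatch (the tensor shapes and the 1-vs-3 sizes here are illustrative assumptions, not taken from the backend code):

import torch

# Assumed sizes: DP0 has 1 real decode row but is padded to 3 to match DP1,
# so FA3 sees 3 query rows while the KV metadata still describes 1 sequence.
# (Exact row counts depend on the speculative step; these are illustrative.)
real_bs, padded_bs = 1, 3

q = torch.randn(padded_bs, 16, 64)            # padded Q: batch_size = 3
cache_seqlens = torch.tensor([37] * real_bs)  # unpadded KV metadata: batch_size_k = 1

# flash_attn_with_kvcache checks these against each other; 3 != 1 is the
# "RuntimeError: batch_size must be equal to batch_size_k" from the report.
print(q.shape[0], cache_seqlens.shape[0])     # 3 vs 1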

Reproduction

Start an NSA model server with DP Attention and speculative decoding enabled. Example config used to reproduce:

python3 -m sglang.launch_server \
  --model-path <NSA_MODEL> \
  --tp 8 \
  --dp 4 \
  --enable-dp-attention \
  --served-model-name GLM-5-FP8-DEBUG \
  --cuda-graph-max-bs 2 \
  --speculative-algorithm NEXTN \
  --speculative-num-steps 3 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 4 \
  --page-size 1

Then send an asymmetric batch to different DP ranks. This is the core repro client:

import asyncio
import sys
import time

import httpx
import openai

BASE_URL = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:30000"
# Pattern "AvB" sends A requests to DP rank 0 and B requests to DP rank 1.
PATTERN = sys.argv[2] if len(sys.argv) > 2 else "1v3"

async def send(client, dp_rank, req_id):
    """Send one chat completion pinned to a specific DP rank."""
    try:
        t0 = time.time()
        await client.chat.completions.create(
            model="GLM-5-FP8-DEBUG",
            messages=[
                {
                    "role": "user",
                    "content": f"Write an essay about request #{req_id}, be creative.",
                }
            ],
            max_tokens=128,
            # Pin the request to a specific DP rank so the local batch
            # sizes stay asymmetric across ranks.
            extra_body={"routed_dp_rank": dp_rank},
        )
        return True, time.time() - t0
    except Exception as e:
        return False, str(e)[:100]

async def main():
    client = openai.AsyncOpenAI(base_url=f"{BASE_URL}/v1", api_key="none")

    # Make sure the server is up before firing the asymmetric batch.
    async with httpx.AsyncClient() as hc:
        r = await hc.get(f"{BASE_URL}/health_generate", timeout=5)
        print(f"Health: {r.status_code}")

    print("--- warmup ---")
    print("  warmup:", await send(client, 0, 0))

    # Split the pattern, e.g. "1v3" -> 1 request for DP0 and 3 for DP1.
    dp0, dp1 = map(int, PATTERN.split("v"))
    tasks = []
    rid = 100
    for _ in range(dp0):
        tasks.append(send(client, 0, rid))
        rid += 1
    for _ in range(dp1):
        tasks.append(send(client, 1, rid))
        rid += 1

    results = await asyncio.gather(*tasks, return_exceptions=True)
    # send() catches its own exceptions, so results should all be tuples;
    # the isinstance check guards against anything gather still surfaces.
    ok = sum(1 for r in results if isinstance(r, tuple) and r[0])
    print(f"Total: {ok}/{len(results)} ok")
    for r in results:
        print(" ", r)

asyncio.run(main())

Run it with 1v3 so DP0 gets 1 request and DP1 gets 3 requests:

python3 test-dpa-repro.py http://localhost:30000 1v3
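The asymmetry matters because DP Attention pads the smaller local batch to match the other ranks. A rough sketch of that mechanism (the pad-to-max rule shown here is an assumption for illustration; the real scheduler logic is more involved):

# Each DP rank must run the same batch shape, so a smaller local batch is
# (by assumption here) padded up to the largest local batch across ranks.
local_batch_sizes = {"dp0": 1, "dp1": 3}  # from the 1v3 pattern
padded_bs = max(local_batch_sizes.values())
padding_rows = {rank: padded_bs - bs for rank, bs in local_batch_sizes.items()}
print(padding_rows)  # {'dp0': 2, 'dp1': 0} -> DP0 carries 2 padding Q rows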

Observed Result Without Any Patch

The request batch fails and several scheduler ranks report the same FA3 error:

RuntimeError: batch_size must be equal to batch_size_k

A representative stack trace is:

File "/sgl-workspace/sglang/python/sglang/srt/layers/attention/nsa_backend.py", line 1584, in forward_decode
  return self._forward_fa3(
File "/sgl-workspace/sglang/python/sglang/srt/layers/attention/nsa_backend.py", line 1633, in _forward_fa3
  o = flash_attn_with_kvcache(
File "/usr/local/lib/python3.12/dist-packages/sgl_kernel/flash_attn.py", line 189, in flash_attn_with_kvcache
  out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
RuntimeError: batch_size must be equal to batch_size_k

After the first rank fails, other ranks start reporting distributed cleanup errors such as:

[c10d] recvValue failed ... Failed to recv, got 0 bytes. Connection was likely closed. Did the remote server shutdown or crash?

In Kubernetes, this shows up as failed in-flight requests followed by scheduler/model-worker process termination and pod restart.

Expected Result

Attention should run only on the real, unpadded batch. DP padding rows are only needed for the later MLP sync path and should not be treated as real FA3 Q rows.
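A minimal sketch of that behavior (the helper below is hypothetical and only illustrates the slice-then-repad idea; the actual change is in the candidate fix referenced below):

import torch

def forward_attention_unpadded(q_padded, real_bs, attn_fn):
    # Hypothetical helper, not the actual backend API: run attention only on
    # the real rows, then restore the padded shape for the later DP MLP sync.
    q_real = q_padded[:real_bs]            # drop DP padding rows before FA3
    o_real = attn_fn(q_real)               # KV metadata now matches real_bs
    o = q_padded.new_zeros(q_padded.shape[0], *o_real.shape[1:])
    o[:real_bs] = o_real                   # padding rows stay zero
    return o                               # MLP sync still sees the padded batch

# Example with a padded batch of 3 and 1 real row; a toy attn_fn stands in for FA3.
q = torch.randn(3, 16, 64)
out = forward_attention_unpadded(q, real_bs=1, attn_fn=lambda x: x * 2.0)
print(out.shape, float(out[1:].abs().sum()))  # torch.Size([3, 16, 64]) 0.0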

Candidate fix: #24235
