PD disagg with NIXL Connector: GDN support (Qwen3.5) #41869

Merged: NickLucche merged 2 commits into vllm-project:main from ZhanqiuHu:nixl-gdn-support on May 14, 2026

@ZhanqiuHu ZhanqiuHu commented May 6, 2026

Contributor

Closes #41886

Summary

  • Add GDN (Gated Delta Net) conv-state layout support for NIXL KV transfer
  • Fix heterogeneous TP kernel block matching for mamba hybrid models
  • Handle physical_blocks_per_logical mismatch between P and D in disaggregated serving
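To make the third point concrete: with the hybrid KV cache manager, one logical block can be backed by several kernel-sized physical blocks, and that multiplier can differ between the prefill and decode instances. Below is a minimal sketch of the expansion, assuming a contiguous mapping in which logical block b covers kernel blocks b*k through b*k + k - 1. The function name is hypothetical and this is not the PR's implementation.

```python
def expand_logical_block_ids(logical_ids: list[int],
                             blocks_per_logical: int) -> list[int]:
    """Expand logical block IDs into contiguous kernel block IDs.

    Assumed mapping (for illustration): logical block b owns kernel blocks
    [b * k, b * k + k) where k = blocks_per_logical.
    """
    return [
        b * blocks_per_logical + offset
        for b in logical_ids
        for offset in range(blocks_per_logical)
    ]


# The same two logical blocks expand to different kernel block counts when
# the multipliers differ (values chosen for illustration only).
p_side = expand_logical_block_ids([3, 7], blocks_per_logical=2)
d_side = expand_logical_block_ids([3, 7], blocks_per_logical=3)
print(p_side)  # [6, 7, 14, 15]
print(d_side)  # [9, 10, 11, 21, 22, 23]
```

Because the two lists differ in length, P-side and D-side kernel block IDs cannot be paired one-to-one unless the ratio is taken into account.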

Test plan

  • Unit tests for GDN conv-state split derivation (test_derive_mamba_conv_split)
  • Unit tests for kernel block ID mapping (test_logical_to_remote_kernel_block_ids); these were moved to another PR
  • E2e accuracy tests: Qwen3.5-0.8B GSM8K across 9 TP configs
  • Added Qwen3.5-0.8B to hybrid_ssm_configs in Buildkite CI

Accuracy Results

Model: Qwen/Qwen3.5-0.8B (GDN)
Benchmark: GSM8K exact_match,strict-match (5-shot)
Standalone baseline: ~0.323

Serve command

VLLM_SSM_CONV_STATE_LAYOUT=DS \
vllm serve Qwen/Qwen3.5-0.8B \
  --enforce-eager \
  --block-size 128 \
  --gpu-memory-utilization 0.3 \
  --max-model-len 8192 \
  --trust-remote-code \
  --no-disable-hybrid-kv-cache-manager \
  --no-async-scheduling \
  --tensor-parallel-size <TP> \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
  • Attention backend: FLASH_ATTN (auto-selected)
  • User input block size: 128
  • Final block size (after _align_hybrid_block_size):
    • TP=1, TP=2: 640 tokens (physical_per_logical=10, kernel_block_size=64)
    • TP=4: 384 tokens (physical_per_logical=6, kernel_block_size=64)
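The final block sizes listed above are consistent with final block size = physical_per_logical × kernel_block_size. A quick check of that arithmetic (the helper name is made up for illustration):

```python
KERNEL_BLOCK_SIZE = 64  # tokens per kernel block, as reported above


def logical_block_tokens(physical_per_logical: int,
                         kernel_block_size: int = KERNEL_BLOCK_SIZE) -> int:
    """Tokens covered by one logical block made of several kernel blocks."""
    return physical_per_logical * kernel_block_size


print(logical_block_tokens(10))  # 640  (TP=1, TP=2)
print(logical_block_tokens(6))   # 384  (TP=4)
```

So a configuration that pairs a TP=4 instance with a TP=1 or TP=2 instance has a different physical_per_logical on each side (6 versus 10).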

Results

Config  P_TP  D_TP  Score
1P1D    1     1     0.335
1P2D    1     2     0.328
2P1D    2     1     0.331
2P2D    2     2     0.321
4P1D    4     1     0.328
4P2D    4     2     0.327
1P4D    1     4     0.322
2P4D    2     4     0.329
4P4D    4     4     0.330

All 9 configurations score within ±0.03 of the standalone baseline (0.323).
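The pass criterion can be re-checked directly from the table. The snippet below is a small sketch that applies the ±0.03 band as an absolute tolerance, which is how the reproduction script uses its RTOL constant:

```python
BASELINE = 0.323
TOLERANCE = 0.03  # absolute band around the baseline

# Scores copied from the results table above.
scores = {
    "1P1D": 0.335, "1P2D": 0.328, "2P1D": 0.331,
    "2P2D": 0.321, "4P1D": 0.328, "4P2D": 0.327,
    "1P4D": 0.322, "2P4D": 0.329, "4P4D": 0.330,
}

deviations = {cfg: abs(score - BASELINE) for cfg, score in scores.items()}
worst_cfg = max(deviations, key=deviations.get)
all_pass = all(dev <= TOLERANCE for dev in deviations.values())

print(f"all within tolerance: {all_pass}")
print(f"largest deviation: {worst_cfg} ({deviations[worst_cfg]:.3f})")
```

The largest deviation is 0.012 (1P1D), well inside the band.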

@mergify mergify Bot added qwen Related to Qwen models kv-connector labels May 6, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request generalizes the KV transfer mechanism for SSM-based models to support GDN (Gated Delta Net) in addition to Mamba2. Key changes include refactoring MambaConvSplitInfo to support generic sub-projections, updating the NIXL worker to handle block ID trimming for Mamba hybrid models under heterogeneous TP configurations, and simplifying block mapping logic. I have no feedback to provide.

@ZhanqiuHu ZhanqiuHu changed the title from "[WIP] NIXL Connector: GDN support (Qwen3.5)" to "[WIP] PD disagg with NIXL Connector: GDN support (Qwen3.5)" May 6, 2026
@ZhanqiuHu ZhanqiuHu marked this pull request as ready for review May 6, 2026 21:23

@claude claude Bot left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@ZhanqiuHu ZhanqiuHu changed the title from "[WIP] PD disagg with NIXL Connector: GDN support (Qwen3.5)" to "PD disagg with NIXL Connector: GDN support (Qwen3.5)" May 6, 2026
@ZhanqiuHu

Contributor Author

Reproduction Script

Usage

The config string has the form <P_TP>p<D_TP>d (e.g. 1p1d, 2p1d) and gives the prefill and decode TP sizes.
Use --num-prefill / --num-decode for multiple instances.
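The GPU counts quoted in the examples follow from num_prefill × P_TP + num_decode × D_TP. A standalone sketch of that calculation, mirroring the check in the script (the function name gpus_needed is only for illustration):

```python
import re


def gpus_needed(config: str, num_prefill: int = 1, num_decode: int = 1) -> int:
    """Return the number of GPUs a '<P_TP>p<D_TP>d' config requires."""
    m = re.fullmatch(r"(\d+)p(\d+)d", config.lower())
    if m is None:
        raise ValueError(f"Unknown config: {config!r}; expected e.g. '1p1d'")
    p_tp, d_tp = int(m.group(1)), int(m.group(2))
    return num_prefill * p_tp + num_decode * d_tp


print(gpus_needed("1p1d"))                               # 2
print(gpus_needed("2p2d"))                               # 4
print(gpus_needed("1p1d", num_prefill=2, num_decode=2))  # 4
```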

# Run from vllm repo root

# No P/D (single server):
python run_lm_eval.py standalone --model Qwen/Qwen3.5-0.8B --gpu-ids 0

# P_TP=1, D_TP=1 (2 GPUs):
python run_lm_eval.py 1p1d --model Qwen/Qwen3.5-0.8B --gpu-ids 0,1

# P_TP=2, D_TP=2 (4 GPUs):
python run_lm_eval.py 2p2d --model Qwen/Qwen3.5-0.8B --gpu-ids 0,1,2,3

# P_TP=1, D_TP=1, 2 prefill + 2 decode instances (4 GPUs):
python run_lm_eval.py 1p1d --model Qwen/Qwen3.5-0.8B --gpu-ids 0,1,2,3 \
    --num-prefill 2 --num-decode 2

# Quick sanity (single prompt, no lm_eval):
python run_lm_eval.py 1p1d --model Qwen/Qwen3.5-0.8B --gpu-ids 0,1 --quick

# Custom block size / attention backend:
python run_lm_eval.py 1p1d --model Qwen/Qwen3.5-0.8B --gpu-ids 0,1 \
    --block-size 64 --backend FLASH_ATTN

Requirements: vllm (this branch), lm_eval

Script: run_lm_eval.py
#!/usr/bin/env python3
"""lm_eval accuracy test for vLLM PD disaggregation.

Launches vLLM server(s) with NIXL KV transfer and runs lm_eval gsm8k 5-shot.
No external dependencies beyond vllm and lm_eval.

The config string (e.g. 1p1d, 2p1d) specifies TP sizes: P_TP x D_TP.
Use --num-prefill / --num-decode for multiple instances.

Usage:
  # No P/D (single server):
  python run_lm_eval.py standalone --model Qwen/Qwen3.5-0.8B --gpu-ids 0

  # P_TP=1, D_TP=1 (2 GPUs):
  python run_lm_eval.py 1p1d --model Qwen/Qwen3.5-0.8B --gpu-ids 0,1

  # P_TP=2, D_TP=2 (4 GPUs):
  python run_lm_eval.py 2p2d --model Qwen/Qwen3.5-0.8B --gpu-ids 0,1,2,3

  # P_TP=1, D_TP=1, 2 prefill + 2 decode instances (4 GPUs):
  python run_lm_eval.py 1p1d --model Qwen/Qwen3.5-0.8B --gpu-ids 0,1,2,3 \
      --num-prefill 2 --num-decode 2

  # Quick sanity (single prompt, no lm_eval):
  python run_lm_eval.py 1p1d --model Qwen/Qwen3.5-0.8B --gpu-ids 0,1 --quick

  # Custom block size / attention backend:
  python run_lm_eval.py 1p1d --model Qwen/Qwen3.5-0.8B --gpu-ids 0,1 \
      --block-size 64 --backend FLASH_ATTN
"""
from __future__ import annotations

import argparse
import atexit
import json
import os
import re
import selectors
import shlex
import signal
import socket
import subprocess
import sys
import threading
import time
from datetime import datetime
from pathlib import Path
from urllib.error import URLError
from urllib.request import Request, urlopen

# ── Model configs ─────────────────────────────────────────────────────────

MODELS = {
    "Qwen/Qwen3.5-0.8B": {
        "max_model_len": 8192,
        "block_size": 128,
        "gpu_mem_util": 0.3,
        "expected_gsm8k": 0.323,
        "hma": True,
    },
    "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8": {
        "max_model_len": 8192,
        "block_size": 128,
        "gpu_mem_util": 0.8,
        "expected_gsm8k": 0.84,
        "hma": True,
    },
}

# ── Constants ─────────────────────────────────────────────────────────────

VLLM_ROOT = Path(os.environ.get("VLLM_ROOT", ".")).resolve()
PROXY_SCRIPT = str(
    VLLM_ROOT / "tests/v1/kv_connector/nixl_integration/toy_proxy_server.py"
)

TASK = "gsm8k"
NUM_FEWSHOT = 5
NUM_CONCURRENT = 100
FILTER = "exact_match,strict-match"
RTOL = 0.03  # applied as an absolute tolerance: |score - expected| <= RTOL

_child_procs: list[subprocess.Popen] = []
_server_procs: list[subprocess.Popen] = []
_cleanup_done = False

# ── Terminal colors ───────────────────────────────────────────────────────

_RESET = "\033[0m"
_BOLD = "\033[1m"
_DIM = "\033[2m"
_RED = "\033[31m"
_GREEN = "\033[32m"
_YELLOW = "\033[33m"
_BLUE = "\033[34m"
_MAGENTA = "\033[35m"
_CYAN = "\033[36m"

_LABEL_COLORS = {"P": _BLUE, "D": _MAGENTA, "PROXY": _CYAN, "S": _GREEN}

_NOISE_PATTERNS = (
    "Route: /", "Methods: ", "Loading safetensors",
    "Autotuning process", "autotuner.py",
    "Waiting for application startup", "Application startup complete",
    "Started server process", "compilation.py",
    "Enabled custom fusions", "non-default args",
)


def log(msg: str, color: str = ""):
    ts = datetime.now().strftime("%H:%M:%S.%f")[:-3]
    reset = _RESET if color else ""
    print(f"{_BOLD}[{ts}]{_RESET} {color}{msg}{reset}", flush=True)


# ── Utilities ─────────────────────────────────────────────────────────────

def find_free_ports(n: int) -> list[int]:
    socks, ports = [], []
    for _ in range(n):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("", 0))
        ports.append(s.getsockname()[1])
        socks.append(s)
    for s in socks:
        s.close()
    return ports


def _kill_tree(pid: int, grace_seconds: float = 0.5):
    try:
        os.killpg(os.getpgid(pid), signal.SIGTERM)
        time.sleep(grace_seconds)
        os.killpg(os.getpgid(pid), signal.SIGKILL)
    except (ProcessLookupError, PermissionError, OSError):
        pass
    try:
        os.kill(pid, signal.SIGKILL)
    except (ProcessLookupError, PermissionError):
        pass


def cleanup():
    global _cleanup_done
    if _cleanup_done:
        return
    _cleanup_done = True
    log("Cleaning up servers...")
    for proc in _child_procs:
        _kill_tree(proc.pid)
        try:
            proc.wait(timeout=10)
        except subprocess.TimeoutExpired:
            pass
    _child_procs.clear()
    _server_procs.clear()
    log("Cleanup done.")


def _colorize(line: str, label: str) -> str | None:
    lc = label.upper()
    base_lc = lc.rstrip("0123456789")
    label_color = _LABEL_COLORS.get(base_lc, _LABEL_COLORS.get(lc, ""))
    colored_prefix = f"{label_color}{_BOLD}[{label}]{_RESET}"

    if any(p in line for p in _NOISE_PATTERNS):
        return None
    if "ERROR" in line or "Traceback" in line:
        return f"  {colored_prefix} {_RED}{_BOLD}{line}{_RESET}"
    if "WARNING" in line:
        return f"  {colored_prefix} {_YELLOW}{line}{_RESET}"
    if "ready" in line.lower() or "DONE" in line:
        return f"  {colored_prefix} {_GREEN}{line}{_RESET}"
    return f"  {colored_prefix} {_DIM}{line}{_RESET}"


def start_process(cmd: list[str], env: dict, log_file: str,
                  label: str, is_server: bool = True) -> subprocess.Popen:
    full_env = os.environ.copy()
    full_env.update(env)
    fh = open(log_file, "w")
    proc = subprocess.Popen(
        cmd, env=full_env,
        stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
        start_new_session=True,
    )

    def _tee(proc, fh, label):
        try:
            for raw in iter(proc.stdout.readline, b""):
                line = raw.decode("utf-8", errors="replace")
                fh.write(line)
                fh.flush()
                colored = _colorize(line, label)
                if colored is not None:
                    print(colored, end="", flush=True)
        except (ValueError, OSError):
            pass
        finally:
            fh.close()

    threading.Thread(target=_tee, args=(proc, fh, label), daemon=True).start()
    _child_procs.append(proc)
    if is_server:
        _server_procs.append(proc)
    return proc


def wait_for_server(port: int, timeout: int = 600) -> bool:
    log(f"Waiting for server on port {port}...")
    start = time.time()
    while time.time() - start < timeout:
        try:
            resp = urlopen(f"http://localhost:{port}/v1/models", timeout=5)
            body = resp.read().decode()
            if '"id"' in body:
                elapsed = int(time.time() - start)
                log(f"Server on port {port} ready ({elapsed}s)")
                return True
        except (URLError, OSError, TimeoutError):
            pass
        time.sleep(5)
    log(f"TIMEOUT: Server on port {port} not ready after {timeout}s", _RED)
    return False


def check_server_health() -> bool:
    for proc in _server_procs:
        if proc.poll() is not None:
            log(f"Server process pid={proc.pid} died (exit={proc.returncode})",
                _RED)
            return False
    return True


def parse_config(config: str) -> tuple[int, int]:
    m = re.match(r"^(\d+)p(\d+)d$", config.lower())
    if not m:
        sys.exit(f"Unknown config: {config}. Use format <N>p<M>d, e.g. 1p1d")
    return int(m.group(1)), int(m.group(2))


# ── Eval ──────────────────────────────────────────────────────────────────

def run_quick_sanity(eval_url: str, model_name: str) -> tuple[bool, str]:
    prompt = "The capital of France is"
    payload = json.dumps({
        "model": model_name, "prompt": prompt,
        "max_tokens": 30, "temperature": 0.0,
    }).encode()
    req = Request(f"{eval_url}/completions", data=payload,
                  headers={"Content-Type": "application/json"})
    try:
        resp = urlopen(req, timeout=60)
        body = json.loads(resp.read().decode())
        text = body["choices"][0]["text"].strip()
        log(f"  Prompt:     \"{prompt}\"")
        log(f"  Completion: \"{text}\"")
        is_coherent = len(text) > 0 and not all(c in ' \n\t|.' for c in text)
        return is_coherent, text
    except Exception as e:
        log(f"Quick sanity FAILED: {e}", _RED)
        return False, str(e)


def run_lm_eval(base_url: str, model_name: str, log_file: str,
                limit: int | None = None) -> float | None:
    model_args = (
        f"model={model_name},"
        f"base_url={base_url}/completions,"
        f"num_concurrent={NUM_CONCURRENT},"
        f"tokenized_requests=False"
    )
    cmd = [
        sys.executable, "-m", "lm_eval",
        "--model", "local-completions",
        "--model_args", model_args,
        "--tasks", TASK,
        "--num_fewshot", str(NUM_FEWSHOT),
        "--output_path", log_file.replace(".log", ""),
        "--gen_kwargs", "temperature=0.0",
    ]
    if limit is not None:
        cmd.extend(["--limit", str(limit)])

    log(f"Running lm_eval: {' '.join(cmd)}")
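    # Run lm_eval in its own session: cleanup() signals the child's whole
    # process group, which would otherwise be this script's own group.
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT,
                            start_new_session=True)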
    _child_procs.append(proc)

    deadline = time.time() + 600
    output_lines = []
    sel = selectors.DefaultSelector()
    sel.register(proc.stdout, selectors.EVENT_READ)
    timed_out = False

    while True:
        remaining = deadline - time.time()
        if remaining <= 0:
            timed_out = True
            break
        events = sel.select(timeout=min(remaining, 5.0))
        if events:
            raw = proc.stdout.readline()
            if not raw:
                break
            line = raw.decode("utf-8", errors="replace")
            output_lines.append(line)
            print(line, end="", flush=True)
        else:
            if not check_server_health():
                timed_out = True
                break
        if proc.poll() is not None:
            for raw in iter(proc.stdout.readline, b""):
                line = raw.decode("utf-8", errors="replace")
                output_lines.append(line)
                print(line, end="", flush=True)
            break
    sel.close()

    if timed_out:
        log("lm_eval timed out or server died — killing", _RED)
        proc.kill()
        proc.wait()
        return None

    proc.wait()
    full_output = "".join(output_lines)
    with open(log_file, "w") as f:
        f.write(full_output)

    if proc.returncode != 0:
        log(f"lm_eval exited with code {proc.returncode}", _RED)
        return None

    for line in output_lines:
        if "strict-match" in line and "exact_match" in line:
            parts = [p.strip() for p in line.split("|")]
            for part in parts:
                if part.startswith("0.") or part.startswith("1."):
                    try:
                        score = float(part)
                        if 0.0 <= score <= 1.0:
                            log(f"Parsed gsm8k strict-match score: {score}")
                            return score
                    except ValueError:
                        continue

    log("Could not parse lm_eval score from output", _YELLOW)
    return None


def scrape_cache_hit_rate(decoder_ports: list[int]) -> tuple[bool, float, str]:
    total_queries = total_hits = 0.0
    for port in decoder_ports:
        try:
            resp = urlopen(f"http://localhost:{port}/metrics", timeout=10)
            body = resp.read().decode()
            for line in body.split("\n"):
                if line.startswith("#") or not line.strip():
                    continue
                m = re.match(r'^(\S+?)(?:\{[^}]*\})?\s+([\d.eE+\-]+)', line)
                if m:
                    name, val = m.group(1), float(m.group(2))
                    if name.endswith("_created"):
                        # Skip counter-creation timestamps if they are exported.
                        continue
                    if "prefix_cache_queries" in name:
                        total_queries += val
                    elif "prefix_cache_hits" in name:
                        total_hits += val
        except Exception:
            pass

    if total_queries == 0:
        return False, 0.0, "No cache queries recorded"
    rate = total_hits / total_queries
    return rate >= 0.99, rate, f"hits={total_hits:.0f}/{total_queries:.0f}"


# ── Main ──────────────────────────────────────────────────────────────────

def main():
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("config",
                        help="'standalone' or P/D config: 1p1d, 2p2d, etc.")
    parser.add_argument("--model", required=True,
                        choices=list(MODELS.keys()),
                        help="Model name")
    parser.add_argument("--gpu-ids", required=True,
                        help="Comma-separated GPU IDs (e.g. '0,1')")
    parser.add_argument("--tp", type=int, default=None,
                        help="TP for standalone mode (default: 1)")
    parser.add_argument("--num-prefill", type=int, default=1,
                        help="Number of prefill instances (default: 1)")
    parser.add_argument("--num-decode", type=int, default=1,
                        help="Number of decode instances (default: 1)")
    parser.add_argument("--quick", action="store_true",
                        help="Quick sanity only (single prompt, no lm_eval)")
    parser.add_argument("--limit", type=int, default=None,
                        help="Run only N lm_eval examples")
    parser.add_argument("--block-size", type=int, default=None,
                        help="Override block size (default: from model config)")
    parser.add_argument("--backend", default=None,
                        help="Attention backend override "
                             "(e.g. FLASHINFER, FLASH_ATTN, TRITON_ATTN)")
    parser.add_argument("--extra-serve-args", default=None,
                        help="Extra args for 'vllm serve'")
    args = parser.parse_args()

    model_name = args.model
    model_cfg = MODELS[model_name]
    block_size = args.block_size or model_cfg["block_size"]
    is_standalone = args.config.lower() == "standalone"
    gpu_ids = [int(g) for g in args.gpu_ids.split(",")]

    if is_standalone:
        tp = args.tp or 1
        p_tp = d_tp = None
        num_p = num_d = 0
        needed = tp
    else:
        p_tp, d_tp = parse_config(args.config)
        tp = None
        num_p = args.num_prefill
        num_d = args.num_decode
        needed = num_p * p_tp + num_d * d_tp

    if len(gpu_ids) < needed:
        sys.exit(f"Need {needed} GPUs, got {len(gpu_ids)}")

    atexit.register(cleanup)
    signal.signal(signal.SIGINT, lambda *_: (cleanup(), sys.exit(130)))
    signal.signal(signal.SIGTERM, lambda *_: (cleanup(), sys.exit(143)))

    results_dir = f"/tmp/lm_eval_{args.config}_{datetime.now():%Y%m%d_%H%M%S}"
    os.makedirs(results_dir, exist_ok=True)

    log("=" * 60)
    if is_standalone:
        log(f"Model: {model_name} | Config: standalone TP={tp}")
    else:
        log(f"Model: {model_name} | Config: {args.config} "
            f"P_TP={p_tp} x{num_p}, D_TP={d_tp} x{num_d}")
    log(f"block_size: {block_size} | GPUs: {gpu_ids} | Results: {results_dir}")
    log("=" * 60)

    hma_flag = "--no-disable-hybrid-kv-cache-manager" if model_cfg["hma"] else ""
    decoder_ports: list[int] = []

    base_serve_args = [
        sys.executable, "-m", "vllm.entrypoints.openai.api_server",
        "--model", model_name,
        "--enforce-eager",
        "--block-size", str(block_size),
        "--gpu-memory-utilization", str(model_cfg["gpu_mem_util"]),
        "--max-model-len", str(model_cfg["max_model_len"]),
        "--trust-remote-code",
        "--no-async-scheduling",
    ]
    if hma_flag:
        base_serve_args.append(hma_flag)
    if args.backend:
        base_serve_args.extend(["--override-attention-backend", args.backend])
    if args.extra_serve_args:
        base_serve_args.extend(shlex.split(args.extra_serve_args))

    base_env = {"VLLM_SSM_CONV_STATE_LAYOUT": "DS"}

    if is_standalone:
        ports = find_free_ports(1)
        cmd = base_serve_args + ["--port", str(ports[0]),
                                 "--tensor-parallel-size", str(tp)]
        env = {**base_env,
               "CUDA_VISIBLE_DEVICES": ",".join(str(g) for g in gpu_ids)}
        log(f"Starting standalone server on port {ports[0]}")
        start_process(cmd, env, f"{results_dir}/server.log", "S")
        if not wait_for_server(ports[0]):
            sys.exit("Server did not start")
        eval_url = f"http://localhost:{ports[0]}/v1"
    else:
        kv_config = json.dumps({
            "kv_connector": "NixlConnector", "kv_role": "kv_both",
        })
        pd_args = base_serve_args + ["--kv-transfer-config", kv_config]

        total_instances = num_p + num_d
        all_ports = find_free_ports(total_instances + 1 + total_instances)
        proxy_port = all_ports[total_instances]
        sc_ports = all_ports[total_instances + 1:]

        gpu_cursor = 0
        p_ports = []
        d_ports = []

        for i in range(num_p):
            p_port = all_ports[i]
            p_ports.append(p_port)
            inst_gpus = gpu_ids[gpu_cursor:gpu_cursor + p_tp]
            gpu_cursor += p_tp

            p_cmd = pd_args + ["--port", str(p_port),
                               "--tensor-parallel-size", str(p_tp)]
            p_env = {**base_env,
                     "CUDA_VISIBLE_DEVICES": ",".join(str(g) for g in inst_gpus),
                     "VLLM_KV_CACHE_LAYOUT": "HND",
                     "VLLM_NIXL_SIDE_CHANNEL_PORT": str(sc_ports[i]),
                     "UCX_TLS": "tcp,cuda_ipc,cuda_copy,self"}
            label = f"P{i}" if num_p > 1 else "P"
            log_name = f"prefiller_{i}.log" if num_p > 1 else "prefiller.log"
            log(f"Starting prefiller {i} on port {p_port} "
                f"(GPUs {inst_gpus}, TP={p_tp})")
            start_process(p_cmd, p_env, f"{results_dir}/{log_name}", label)

        for i in range(num_d):
            d_port = all_ports[num_p + i]
            d_ports.append(d_port)
            decoder_ports.append(d_port)
            inst_gpus = gpu_ids[gpu_cursor:gpu_cursor + d_tp]
            gpu_cursor += d_tp

            d_cmd = pd_args + ["--port", str(d_port),
                               "--tensor-parallel-size", str(d_tp)]
            d_env = {**base_env,
                     "CUDA_VISIBLE_DEVICES": ",".join(str(g) for g in inst_gpus),
                     "VLLM_KV_CACHE_LAYOUT": "HND",
                     "VLLM_NIXL_SIDE_CHANNEL_PORT": str(sc_ports[num_p + i]),
                     "UCX_TLS": "tcp,cuda_ipc,cuda_copy,self"}
            label = f"D{i}" if num_d > 1 else "D"
            log_name = f"decoder_{i}.log" if num_d > 1 else "decoder.log"
            log(f"Starting decoder {i} on port {d_port} "
                f"(GPUs {inst_gpus}, TP={d_tp})")
            start_process(d_cmd, d_env, f"{results_dir}/{log_name}", label)

        for i, port in enumerate(p_ports):
            if not wait_for_server(port):
                sys.exit(f"Prefiller {i} did not start")
        for i, port in enumerate(d_ports):
            if not wait_for_server(port):
                sys.exit(f"Decoder {i} did not start")

        proxy_cmd = [
            sys.executable, PROXY_SCRIPT,
            "--port", str(proxy_port),
            "--prefiller-hosts", *["localhost"] * len(p_ports),
            "--prefiller-ports", *[str(p) for p in p_ports],
            "--decoder-hosts", *["localhost"] * len(d_ports),
            "--decoder-ports", *[str(p) for p in d_ports],
        ]
        log(f"Starting proxy on port {proxy_port} "
            f"(P ports: {p_ports}, D ports: {d_ports})")
        start_process(proxy_cmd, {}, f"{results_dir}/proxy.log", "PROXY")
        time.sleep(3)
        eval_url = f"http://localhost:{proxy_port}/v1"

    # ── Quick sanity ──────────────────────────────────────────────────
    log("\nRunning quick sanity check...")
    ok, text = run_quick_sanity(eval_url, model_name)
    if not ok:
        log("Sanity check FAILED — server may be unhealthy", _RED)
        sys.exit(1)
    log("Sanity check passed", _GREEN)

    if args.quick:
        log("--quick mode: done (skipping lm_eval)")
        return

    # ── lm_eval ───────────────────────────────────────────────────────
    log(f"\nRunning lm_eval gsm8k {NUM_FEWSHOT}-shot...")
    score = run_lm_eval(eval_url, model_name,
                        f"{results_dir}/lm_eval.log",
                        limit=args.limit)

    # ── Results ───────────────────────────────────────────────────────
    expected = model_cfg["expected_gsm8k"]
    log("\n" + "=" * 60)
    if score is not None:
        log(f"  Score: {score:.4f}")
        if expected is not None:
            diff = abs(score - expected)
            passed = diff <= RTOL
            status = f"{_GREEN}PASS{_RESET}" if passed else f"{_RED}FAIL{_RESET}"
            log(f"  Expected: {expected:.4f} ± {RTOL}")
            log(f"  Diff: {diff:.4f} {status}")
        else:
            log("  No expected baseline (informational)")
            passed = True
    else:
        log(f"  {_RED}lm_eval did not return a score{_RESET}")
        passed = False

    if decoder_ports:
        hit_ok, hit_rate, hit_detail = scrape_cache_hit_rate(decoder_ports)
        log(f"  Cache hit rate: {hit_rate:.2%} ({hit_detail}) "
            f"→ {'PASS' if hit_ok else 'FAIL'}")
        passed = passed and hit_ok

    log("=" * 60)
    if passed:
        log(f"  {_GREEN}{_BOLD}ALL CHECKS PASSED{_RESET}")
    else:
        log(f"  {_RED}{_BOLD}SOME CHECKS FAILED{_RESET}")
        sys.exit(1)


if __name__ == "__main__":
    try:
        main()
    finally:
        cleanup()
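The metric-line regex used in scrape_cache_hit_rate can be exercised on its own. The sketch below runs the same pattern over a few made-up lines in Prometheus text format. The sample metric names and values are illustrative assumptions, not output captured from a server.

```python
import re

# Same pattern as in scrape_cache_hit_rate: name, optional {labels}, value.
METRIC_LINE = re.compile(r'^(\S+?)(?:\{[^}]*\})?\s+([\d.eE+\-]+)')

# Illustrative sample only; not real server output.
sample = """\
# HELP vllm:prefix_cache_queries_total illustrative sample
vllm:prefix_cache_queries_total{engine="0"} 200.0
vllm:prefix_cache_hits_total{engine="0"} 198.0
"""

queries = hits = 0.0
for line in sample.splitlines():
    if line.startswith("#") or not line.strip():
        continue  # skip comments and blank lines
    m = METRIC_LINE.match(line)
    if m:
        name, val = m.group(1), float(m.group(2))
        if "prefix_cache_queries" in name:
            queries += val
        elif "prefix_cache_hits" in name:
            hits += val

rate = hits / queries
print(f"{rate:.2%}")  # 99.00%
```

With these sample values the rate sits exactly on the script's 0.99 pass threshold.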

@NickLucche NickLucche left a comment

Member

OK, I see the issue with prefix caching, thanks for the work @ZhanqiuHu.
Could you split that bugfix into a separate PR though? I think it's important enough that we should land it separately so it is easier to reference and isolate.

# TODO: (NickLucche) Address async scheduling issue with TP>1 separately as this may impact other models.
"VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=ibm-granite/granite-4.0-h-tiny VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code,--no-async-scheduling"
# GDN (Qwen3.5)
"VLLM_SSM_CONV_STATE_LAYOUT=DS ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=Qwen/Qwen3.5-0.8B VLLM_SERVE_EXTRA_ARGS=--no-async-scheduling"

Member

we can do async sched with TP 1

Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
@ZhanqiuHu ZhanqiuHu requested a review from Harry-Chen as a code owner May 12, 2026 14:46
@NickLucche NickLucche added the ready ONLY add when PR is ready to merge/full CI is needed label May 13, 2026

@NickLucche NickLucche left a comment

Member

LGTM

# P-sized offsets. Scale down by |tp_ratio|.
abs_ratio = -tp_ratio
remote_conv0 = conv0 // abs_ratio
remote_conv1 = conv1 // abs_ratio

Member

is this another fix?

@NickLucche NickLucche merged commit 24337fb into vllm-project:main May 14, 2026
70 checks passed
mfylcek pushed a commit to mfylcek/vllm that referenced this pull request May 19, 2026
)

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
jhu960213 pushed a commit to jhu960213/vllm that referenced this pull request May 20, 2026
)

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
h1t35h pushed a commit to h1t35h/vllm that referenced this pull request May 21, 2026
)

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Liuweixiong0118 pushed a commit to Liuweixiong0118/vllm that referenced this pull request Jun 1, 2026
)

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Signed-off-by: Liuweixiong0118 <lwx34158427@gmail.com>
mvanhorn pushed a commit to mvanhorn/vllm that referenced this pull request Jun 4, 2026
)

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
@ZhanqiuHu ZhanqiuHu deleted the nixl-gdn-support branch June 4, 2026 17:45
andakai pushed a commit to andakai/vllm that referenced this pull request Jun 4, 2026
)

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
knight0528 pushed a commit to knight0528/vllm that referenced this pull request Jun 8, 2026
)

Signed-off-by: Zhanqiu Hu <zhu@redhat.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Labels

ci/build, kv-connector, qwen (Related to Qwen models), ready (ONLY add when PR is ready to merge/full CI is needed), v1

Development

Successfully merging this pull request may close these issues.

[Feature]: NIXL P/D Disaggregation: GDN support (Qwen3.5)

2 participants