
Add NCCL/RCCL pre-warming to reduce P99 TTFT cold-start latency (#20477)

Merged
HaiShaw merged 4 commits into sgl-project:main from hubertlu-tw:pre_warm_nccl
Mar 17, 2026
Conversation

@hubertlu-tw (Collaborator)

Motivation

When using multi-GPU tensor parallelism (TP > 1), the first collective communication operation triggers NCCL/RCCL communicator initialization, causing severe P99 TTFT degradation (up to 1400ms) for the first 2-3 requests.

This PR implements NCCL/RCCL pre-warming during server startup to eliminate cold-start latency, inspired by InstantTensor's implementation.

Measured Impact on AMD MI355X:

  • P99 TTFT improvement: 74.9% (1413ms → 357ms)
  • Latency stability: 87.8% lower std dev (327ms → 40ms)
  • Warmup overhead: 4.7s one-time cost

Default Behavior:

  • Enabled by default for AMD/HIP (RCCL) - validated on MI355X
  • Disabled by default for NVIDIA/CUDA (NCCL) - pending validation

Modifications

Server Arguments (server_args.py)

Added pre_warm_nccl field with platform-aware default:

pre_warm_nccl: bool = dataclasses.field(
    default_factory=lambda: is_hip()
)  # Default: True for AMD/HIP, False for NVIDIA/CUDA
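The `default_factory` pattern above can be illustrated in isolation. The sketch below uses a hypothetical stand-in for sglang's `is_hip()` probe (hard-coded to a CUDA host for the example); the point is that the factory defers the platform check to instantiation time rather than evaluating it once when the class is defined:

```python
import dataclasses

def is_hip():
    # Hypothetical stand-in for sglang's real platform probe; assume a
    # non-HIP (CUDA) host here, so the default resolves to False.
    return False

@dataclasses.dataclass
class ServerArgs:
    # default_factory defers the platform check to instantiation time
    # instead of baking in a value at class-definition time.
    pre_warm_nccl: bool = dataclasses.field(default_factory=lambda: is_hip())

print(ServerArgs().pre_warm_nccl)  # → False on this assumed CUDA host
```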

Added CLI argument:

parser.add_argument(
    "--pre-warm-nccl",
    action="store_true",
    help="Pre-warm NCCL/RCCL communicators during startup to reduce P99 TTFT cold-start latency. Default: enabled for AMD/HIP (RCCL), disabled for NVIDIA/CUDA (NCCL).",
)

Model Runner (model_runner.py)

Added warmup logic during initialization:

if self.server_args.pre_warm_nccl and (self.tp_size > 1 or self.pp_size > 1 or self.moe_ep_size > 1):
    warmup_start = time.perf_counter()
    tp_group_handle = get_tp_group().device_group

    # Single warmup all_reduce to initialize NCCL/RCCL communicator
    warmup_tensor = torch.zeros(1, device=torch.cuda.current_device())
    dist.all_reduce(warmup_tensor, group=tp_group_handle)
    torch.cuda.synchronize()

    warmup_elapsed = time.perf_counter() - warmup_start
    logger.info(f"NCCL/RCCL warmup completed in {warmup_elapsed:.3f}s ...")
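The warmup pattern itself (first collective initializes the communicator, then synchronize and measure) can be exercised without GPUs. The sketch below is a single-process illustration using the CPU `gloo` backend, not the NCCL/RCCL backend the PR targets; the structure mirrors the snippet above:

```python
import os
import time
import torch
import torch.distributed as dist

# Single-process illustration; MASTER_ADDR/PORT are arbitrary local values.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29512")
dist.init_process_group("gloo", rank=0, world_size=1)

start = time.perf_counter()
warmup_tensor = torch.zeros(1)
dist.all_reduce(warmup_tensor)  # first collective sets up the communicator
elapsed = time.perf_counter() - start
print(f"warmup all_reduce finished in {elapsed:.3f}s")

dist.destroy_process_group()
```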

Accuracy Tests

No accuracy impact expected: this is a latency-only optimization and does not change model outputs.

Validated with GSM8K (100 questions):

  • Without pre-warm: 97.0%
  • With pre-warm: 98.0% (the 1-point difference is within run-to-run noise)

Benchmarking and Profiling

Test Environment

  • Platform: AMD MI355X (8 GPUs)
  • Model: DeepSeek-R1-MXFP4-Preview
  • Configuration: TP=8, 128-token prompts, 16 output tokens

Results

| Configuration    | P99 TTFT | P95 TTFT | Median TTFT | Std Dev | Improvement      |
|------------------|----------|----------|-------------|---------|------------------|
| Without pre-warm | 1413 ms  | 1413 ms  | 210 ms      | 327 ms  | Baseline         |
| With pre-warm    | 357 ms   | 354 ms   | 207 ms      | 40 ms   | 74.9% faster P99 |

Key Findings:

  • P99 TTFT: 74.9% improvement (1413ms → 357ms)
  • Latency stability: 87.8% lower std dev (327ms → 40ms)
  • Warmup overhead: 4.7s one-time cost
  • ROI: Warmup pays for itself after 4-5 requests
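The break-even claim can be checked with quick arithmetic from the numbers above. This treats the P99 delta as the per-request saving for the affected cold requests, which is a simplification:

```python
p99_cold_ms = 1413    # P99 TTFT without pre-warm (from the table above)
p99_warm_ms = 357     # P99 TTFT with pre-warm
warmup_cost_s = 4.7   # one-time warmup overhead

saving_per_request_s = (p99_cold_ms - p99_warm_ms) / 1000  # 1.056 s saved
break_even_requests = warmup_cost_s / saving_per_request_s
print(f"warmup pays for itself after ~{break_even_requests:.1f} requests")
# → ~4.5, consistent with the "4-5 requests" figure
```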

Reproduction Commands

Click to expand

Test without pre-warming:

# Start server (disable via Python API: ServerArgs(pre_warm_nccl=False))
python3 -m sglang.launch_server \
  --model-path /data/DeepSeek-R1-MXFP4-Preview \
  --tp-size 8

# Send requests - first 2-3 will be slow (~1400ms)
for i in {1..20}; do
  time curl -X POST http://127.0.0.1:30000/generate \
    -H "Content-Type: application/json" \
    -d '{"text": "Hello", "sampling_params": {"max_new_tokens": 16}}'
done

Test with pre-warming (default for AMD):

# Start server (pre-warm enabled by default on AMD)
python3 -m sglang.launch_server \
  --model-path /data/DeepSeek-R1-MXFP4-Preview \
  --tp-size 8

# Expected log: "NCCL/RCCL warmup completed in 4.561s"

# Send requests - all fast (~300ms)
for i in {1..20}; do
  time curl -X POST http://127.0.0.1:30000/generate \
    -H "Content-Type: application/json" \
    -d '{"text": "Hello", "sampling_params": {"max_new_tokens": 16}}'
done

NVIDIA users (pre-warm disabled by default):

# Enable explicitly with --pre-warm-nccl
python3 -m sglang.launch_server \
  --model-path /data/model \
  --tp-size 8 \
  --pre-warm-nccl

Checklist

  • Accuracy validation: No impact on model outputs (GSM8K: 97.0% vs 98.0%)
  • Performance benchmarks: 74.9% P99 TTFT improvement on AMD MI355X
  • Code style: Follows SGLang conventions
  • Unit tests: TODO (warmup is tested via integration)
  • Documentation: Inline comments added
  • CI tests: Pending

Review Process

  1. Ping Merge Oncalls to start PR flow
  2. Get approvals from CODEOWNERS
  3. Trigger CI tests: /tag-run-ci-label, /rerun-failed-ci
  4. After green CI + approvals, merge

Implements NCCL/RCCL communicator pre-warming during server startup to
eliminate cold-start latency (up to 1400ms) for first requests when
using multi-GPU tensor parallelism.

Measured on AMD MI355X:
- P99 TTFT improvement: 74.9% (1413ms → 357ms)
- Latency stability: 87.8% lower std dev (327ms → 40ms)
- Warmup overhead: 4.7s one-time cost (5.2% of model loading)

Changes:
- server_args.py: Add pre_warm_nccl field with platform-aware default
  (enabled for AMD/HIP, disabled for NVIDIA/CUDA until validation)
- server_args.py: Add --pre-warm-nccl CLI argument
- model_runner.py: Implement warmup via single all_reduce operation
  during ModelRunner initialization

Default behavior:
- AMD/HIP: Enabled (validated 74.9% improvement)
- NVIDIA/CUDA: Disabled (pending validation)

Inspired by InstantTensor's implementation which achieved 71%
improvement on NVIDIA GPUs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

@HaiShaw (Collaborator) left a comment


@hubertlu-tw
Can you make the change rocm/hip specific (nccl->rccl), or cuda&hip specific (to avoid regression to other platforms).

@hubertlu-tw (Collaborator, Author)

> @hubertlu-tw Can you make the change rocm/hip specific (nccl->rccl), or cuda&hip specific (to avoid regression to other platforms).

@HaiShaw I have modified server_args.py so that --pre-warm-nccl is only applicable to CUDA and HIP, and it defaults to True only on AMD GPUs.

@HaiShaw HaiShaw merged commit 943f34f into sgl-project:main Mar 17, 2026
71 of 91 checks passed
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026