[Intel GPU] Enable pipeline parallelism on XPU #23472

mingfeima merged 5 commits into sgl-project:main
Conversation
Replace hard-coded torch.cuda.{Event,current_stream,synchronize} calls in
SchedulerPPMixin with device-agnostic torch.get_device_module() lookups so
the PP scheduler loop runs on XPU (and any other non-CUDA backend) in
addition to CUDA.
Reorder send/recv in _pp_send_recv_and_preprocess_output_tensors by
pp_rank parity (even ranks send-then-recv, odd ranks recv-then-send).
The original always-send-first ordering livelocks on backends where
point-to-point isend busy-polls for a matching recv rendezvous: with
XCCL on XPU, every PP rank entered isend simultaneously and none posted
a recv, pinning all ranks at 100% CPU inside torch.distributed. Parity
ordering guarantees each adjacent rank pair has one sender and one
receiver posted at the same time.
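A minimal sketch of the parity ordering (function and argument names are illustrative, not the PR's exact code):

```python
import torch.distributed as dist

def parity_send_recv(pp_rank, send_buf, recv_buf, next_rank, prev_rank):
    # Even ranks post send first, odd ranks post recv first: every
    # adjacent (even, odd) pair then has one receiver already waiting,
    # so the isend rendezvous completes instead of busy-polling.
    if pp_rank % 2 == 0:
        send_req = dist.isend(send_buf, dst=next_rank)
        recv_req = dist.irecv(recv_buf, src=prev_rank)
    else:
        recv_req = dist.irecv(recv_buf, src=prev_rank)
        send_req = dist.isend(send_buf, dst=next_rank)
    send_req.wait()
    recv_req.wait()
```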
Verified on 4x Intel XPU with Llama-3.1-8B:
- TP=1 PP=2, PP=3, PP=4
- TP=2 PP=2
- Full warmup + bench_one_batch_server cycle at PP=4
Code Review
This pull request refactors the pipeline parallelism scheduler to be device-agnostic by replacing CUDA-specific calls with a generic device module and updating device assignments. It also introduces rank-parity-based ordering of send and recv operations to prevent livelocks on specific hardware backends. The review feedback flags several type-hint issues: `torch.Event` requires PyTorch 2.4+, some return type hints do not match the values actually returned, and values that can be `None` should be marked `Optional`.
Address review feedback on type hints in `SchedulerPPMixin`.
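As a hedged illustration of the kind of cleanup requested (names and container shapes are assumed, not the PR's actual diff):

```python
from collections import deque
from typing import Any, Deque, Optional, Tuple

# torch.Event only exists as a public alias on PyTorch >= 2.4, and the
# concrete type is backend-specific (torch.cuda.Event / torch.xpu.Event),
# so the runtime annotation stays generic.
d2h_events: Deque[Tuple[Any, ...]] = deque()

def pop_ready_event(events: Deque[Tuple[Any, ...]]) -> Optional[Tuple[Any, ...]]:
    # Returns None when nothing is queued, hence the Optional marker.
    return events.popleft() if events else None
```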
Gate the parity-based send/recv ordering in _pp_send_recv_and_preprocess_output_tensors so it applies only on XPU. CUDA/NCCL keeps the original send-first behavior since isend there is eager stream-enqueue and reordering two non-blocking ops has no effect. On XPU, isend is effectively blocking and does not return until the peer posts a matching recv; if every PP rank sends first, all ranks block waiting for a receiver and the ring deadlocks. Parity ordering (even: send->recv, odd: recv->send) guarantees each adjacent pair has one sender and one receiver posted simultaneously, and generalizes across all PP sizes (PP=2, 3, 4, ...).
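A sketch of that gate, with `device`, `do_send`, and `do_recv` as hypothetical stand-ins for the real state and the `_do_send()`/`_do_recv()` helpers:

```python
import torch

def ordered_send_recv(pp_rank: int, device: torch.device, do_send, do_recv):
    if device.type != "xpu" or pp_rank % 2 == 0:
        # CUDA/NCCL (where isend is an eager stream enqueue) and even
        # XPU ranks keep the original send-first order.
        do_send()
        do_recv()
    else:
        # Odd XPU ranks receive first, pairing each even sender with an
        # already-posted receiver.
        do_recv()
        do_send()
```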
@mingfeima Could you please review and apply the label?

/tag-and-rerun-ci

@siju-samuel, could you resubmit this PR and let us run the related CI? The device stream sync modification might break the CUDA setup. Something like this: https://github.com/sgl-project/sglang/actions/runs/24879511440/job/72844149700
Motivation
Pipeline parallelism (PP) only ran on CUDA. On Intel XPU, launching any `--pp-size > 1` server crashed at startup with `RuntimeError: Tried to instantiate dummy base class Event` because `SchedulerPPMixin` hard-codes `torch.cuda.{Event, current_stream, synchronize}`. Even after fixing the hard-coded CUDA calls, `PP >= 2` livelocked during the first multi-rank communication: with XCCL on XPU, `torch.distributed.isend` busy-polls waiting for a matching `recv` rendezvous, so when every PP rank sent before receiving, all ranks spun at 100% CPU inside `torch.distributed` and none ever reached its `recv`.

This PR makes PP work on XPU (and generalizes to any non-CUDA backend `torch.get_device_module()` supports) without changing CUDA behavior.

Modifications
Device-agnostic event/stream/sync calls in `python/sglang/srt/managers/scheduler_pp_mixin.py`:

- `torch.cuda.Event()` → `get_device_module().Event()`
- `torch.cuda.current_stream()` → `get_device_module().current_stream()`
- `torch.cuda.synchronize()` → `get_device_module().synchronize()`
- `deque[Tuple[torch.cuda.Event, ...]]` type hints replaced with backend-agnostic forms
- `get_device_module` added to the existing `from sglang.srt.utils import ...` block
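For illustration, the device-agnostic pattern looks roughly like this (a minimal sketch, not the file's literal diff; `torch.get_device_module()` resolves to `torch.cuda` on CUDA and `torch.xpu` on XPU):

```python
import torch

device_module = torch.get_device_module()  # torch.cuda, torch.xpu, ...

# Formerly torch.cuda.Event() / torch.cuda.current_stream() /
# torch.cuda.synchronize(); now dispatched through the generic module.
copy_done = device_module.Event()
copy_done.record(device_module.current_stream())
device_module.synchronize()
```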
Parity-based send/recv ordering in `_pp_send_recv_and_preprocess_output_tensors`:

- Even `pp_rank` ranks: send → recv
- Odd `pp_rank` ranks: recv → send
- Each `isend` therefore always finds a matching `recv` already waiting, and the rendezvous completes instead of busy-spinning.
- `_do_send()` / `_do_recv()` keep the two branches symmetric and avoid duplicating the profiler/copy-stream/d2h-event logic.

No CUDA behavior change: on CUDA, `get_device_module()` returns `torch.cuda`, and parity ordering is a pure reordering of already-independent send and recv operations.
Accuracy Tests

Verified on 4× Intel XPU with `meta-llama/Llama-3.1-8B-Instruct`:

- `TestPPAccuracy.test_logprob` (TP=2 PP=2)
- CUDA behavior unchanged (no codepath difference for the `torch.cuda` backend)

Speed Tests and Profiling
`bench_one_batch_server` full warmup + bench cycle at PP=4 on Intel XPU, Llama-3.1-8B, `batch_size=8`, `input_len=1024`, `output_len=128`:

```
======== Warmup Begin ========
Warmup with batch_size=[8]
#Input tokens: 8192
#Output tokens: 128
batch size: 8
input_len: 1024
output_len: 16
latency: 4.04 s
input throughput: 6569.72 tok/s
output throughput: 45.89 tok/s
======== Warmup End ========
#Input tokens: 8192
#Output tokens: 1024
batch size: 8
input_len: 1024
output_len: 128
latency: 25.09 s
input throughput: 13051.32 tok/s
output throughput: 41.85 tok/s
last_ttft: 0.63 s
last generation throughput: 40.53 tok/s
```
Before this fix the same command hung indefinitely at the first large warmup batch (all 4 ranks at 100% CPU in `torch.distributed`).

Checklist
(Existing `test_pp_single_node` tests cover this path; no new tests added since it's a backend-portability fix.)
`/tag-and-rerun-ci`, `/tag-run-ci-label`, `/rerun-failed-ci`

@mingfeima @Kangyan-Zhou, @iforgetmyname, @Fridge003, @merrymercy, @ispobock, @JustinTong0323, @BBuf, @Edwardf0t1, @HaiShaw, @Ying1123, @ch-wan @hnyls2002 and @kushanam please review and merge this PR.