
Add pipeline parallelism for DeepSeekV2 #6434

Closed
zhjc1124 wants to merge 29 commits into sgl-project:main from zhjc1124:pp_deepseek

Conversation

@zhjc1124 (Contributor) commented May 19, 2025

Motivation

#5724 #5925

Modifications

Checklist

@zhjc1124 (Contributor, Author) commented May 19, 2025

Run test_pp_consistency:

python3 -m unittest test_pp_single_node.TestDeepSeekPPAccuracy.test_pp_consistency
[DS PP Comparison] Baseline: {'accuracy': np.float64(0.83), 'latency': 22.90764766279608, 'output_throughput': 1124.8645159602909} | PP: {'accuracy': np.float64(0.83), 'latency': 20.94851907994598, 'output_throughput': 1228.153641878648}      
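
For reference, a minimal sketch of the kind of check this test performs, reusing the numbers above (illustrative only, not the actual test code):

# Sketch of a PP-consistency check over result dicts shaped like the output
# above; the values are copied from this run, the code is not the real test.
baseline = {"accuracy": 0.83, "latency": 22.91, "output_throughput": 1124.86}
pp = {"accuracy": 0.83, "latency": 20.95, "output_throughput": 1228.15}

# PP must not change model outputs, so accuracy has to match the baseline.
assert abs(baseline["accuracy"] - pp["accuracy"]) < 1e-6, "PP changed accuracy"

# Latency and throughput may differ; report the relative change.
speedup = pp["output_throughput"] / baseline["output_throughput"]
print(f"PP output-throughput ratio vs baseline: {speedup:.2f}x")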

@zhjc1124 (Contributor, Author)

Launch DeepSeek-R1 across three nodes (tp_size=8, pp_size=3):

# node 1
python3 -m sglang.launch_server --model-path /data/modelscope/DeepSeek-R1/ --dist-init-addr 10.0.0.1:5000 --nnodes 3 --trust-remote-code --tp 8 --pp 3 --node-rank 0 --attention-backend=flashinfer
# node 2
python3 -m sglang.launch_server --model-path /data/modelscope/DeepSeek-R1/ --dist-init-addr 10.0.0.1:5000 --nnodes 3 --trust-remote-code  --tp 8 --pp 3 --node-rank 1 --attention-backend=flashinfer
# node 3
python3 -m sglang.launch_server --model-path /data/modelscope/DeepSeek-R1/ --dist-init-addr 10.0.0.1:5000 --nnodes 3 --trust-remote-code  --tp 8 --pp 3 --node-rank 2 --attention-backend=flashinfer
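
A quick sanity check on the topology implied by these commands: with tp=8 and pp=3 the world size is 24 GPUs, so each of the three nodes (assuming 8 GPUs per node) hosts exactly one pipeline stage, sharded TP=8 across its local GPUs:

# Sanity check for the 3-node launch above (assumes 8 GPUs per node).
tp, pp = 8, 3
nnodes, gpus_per_node = 3, 8
# The world size tp * pp must match the total GPU count across nodes.
assert tp * pp == nnodes * gpus_per_node
print(f"{pp} pipeline stages, each a TP-{tp} group on one node")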

Test with bench_serving:

# python3 -m sglang.bench_serving --dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json --backend sglang --model /data/modelscope/DeepSeek-R1/ --dataset-name sharegpt --num-prompts 20 --max-concurrency 1
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 1
Successful requests:                     20
Benchmark duration (s):                  180.25
Total input tokens:                      8203
Total generated tokens:                  4559
Total generated tokens (retokenized):    4549
Request throughput (req/s):              0.11
Input token throughput (tok/s):          45.51
Output token throughput (tok/s):         25.29
Total token throughput (tok/s):          70.80
Concurrency:                             1.00
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   9011.89
Median E2E Latency (ms):                 8842.04
---------------Time to First Token----------------
Mean TTFT (ms):                          391.64
Median TTFT (ms):                        229.77
P99 TTFT (ms):                           1958.60
---------------Inter-Token Latency----------------
Mean ITL (ms):                           38.03
Median ITL (ms):                         37.42
P95 ITL (ms):                            38.84
P99 ITL (ms):                            40.86
Max ITL (ms):                            234.84
==================================================
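
As a sanity check, the headline numbers above are internally consistent; recomputing them from the raw counts (20 requests, as reported):

# Recompute the benchmark summary above from its raw counts.
duration_s = 180.25
num_requests = 20
input_tokens, output_tokens = 8203, 4559

print(input_tokens / duration_s)   # ~45.51 tok/s input throughput
print(output_tokens / duration_s)  # ~25.29 tok/s output throughput
print(num_requests / duration_s)   # ~0.11 req/s request throughput

# Mean ITL is roughly (mean E2E - mean TTFT) / (mean output length - 1).
mean_e2e_ms, mean_ttft_ms = 9011.89, 391.64
mean_output_len = output_tokens / num_requests
print((mean_e2e_ms - mean_ttft_ms) / (mean_output_len - 1))  # ~38 ms, matches ITL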

@HaiShaw (Collaborator) left a comment

@zhjc1124 does this support PP within 1 node? Any usage example?

@zhjc1124 (Contributor, Author)

> @zhjc1124 does this support PP within 1 node? Any usage example?

Yes.

python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --attention-backend flashinfer --trust-remote-code --pp-size 2
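
Once the server is up, a quick smoke test can go through the OpenAI-compatible endpoint (assuming the default port 30000):

# Minimal smoke test for the single-node PP launch above, assuming the
# default port 30000 and the OpenAI-compatible chat completions route.
import requests

resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
        "messages": [{"role": "user", "content": "Write hello world in C."}],
        "max_tokens": 64,
    },
)
print(resp.json()["choices"][0]["message"]["content"])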

@billishyahao (Contributor)

Tried this patch but hit the issue:

The command is:

python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --trust-remote-code --pp-size 2

Meanwhile with the following command:

python3 -m sglang.launch_server --model Qwen/Qwen3-30B-A3B --pp 2

The server is up.

@zhjc1124 (Contributor, Author) commented May 23, 2025

> Tried this patch but hit the issue […]. Meanwhile with Qwen/Qwen3-30B-A3B and --pp 2, the server is up.

Sorry about that. I forgot to import Union in deepseek_v2.py when fixing conflicts, and I found other bugs after merging main.
I have fixed them now and the server launches successfully.

@MichoChan commented May 23, 2025

Did you test with tp=2, pp=8 on 8 nodes?
I encountered an error when capturing the CUDA graph:

fused_moe_kernel[grid](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 691, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/usr/local/lib/python3.10/dist-packages/triton/backends/nvidia/driver.py", line 365, in __call__
    self.launch(*args, **kwargs)
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
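
(A standard first step for narrowing down an illegal memory access like this, not specific to this PR, is to rerun with synchronous kernel launches so the traceback points at the failing kernel, and with CUDA graphs disabled to rule out capture itself:)

CUDA_LAUNCH_BLOCKING=1 python3 -m sglang.launch_server <usual flags> --disable-cuda-graph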

@zhjc1124 zhjc1124 requested review from BBuf and ch-wan as code owners May 23, 2025 15:14
@zhjc1124 (Contributor, Author) commented May 23, 2025

> Did you test with tp=2, pp=8 on 8 nodes? I encountered an error when capturing the CUDA graph: […]

I only have 3 nodes.
I also fail to launch DeepSeek-R1 with tp=2, pp=12 on 3 nodes, but that looks like an OOM problem, because the launch succeeds with --cuda-graph-max-bs 1.
In CudaGraphRunner, cuda-graph-max-bs bounds the size of pp_proxy_tensors:

            # pipeline parallelism
            if self.pp_size > 1:
                self.pp_proxy_tensors = {
                    "hidden_states": torch.zeros(
                        (self.max_bs, self.model_runner.model_config.hidden_size),
                        dtype=torch.bfloat16,
                    ),
                    "residual": torch.zeros(
                        (self.max_bs, self.model_runner.model_config.hidden_size),
                        dtype=torch.bfloat16,
                    ),
                }
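
For scale, a rough estimate of how these two proxy buffers grow with max_bs (assuming hidden_size = 7168 as in DeepSeek-R1/V3, and bf16, i.e. 2 bytes per element):

# Rough size of the pp_proxy_tensors above as a function of max_bs, assuming
# hidden_size = 7168 (DeepSeek-R1/V3) and bf16 (2 bytes per element).
def proxy_tensor_bytes(max_bs: int, hidden_size: int = 7168) -> int:
    return 2 * max_bs * hidden_size * 2  # two buffers: hidden_states, residual

print(proxy_tensor_bytes(1))    # ~28 KiB at --cuda-graph-max-bs 1
print(proxy_tensor_bytes(160))  # ~4.4 MiB at a larger capture batch size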

I also succeeded in launching DeepSeek-Coder-V2-Lite-Instruct with tp=2, pp=12 on 3 nodes.
So could you please test DeepSeek-Coder-V2-Lite-Instruct with tp=2, pp=8 on 8 nodes?
Or test whether DeepSeek-R1 or other large DeepSeek models work with configurations like tp=4, pp=4 or tp=8, pp=2?

@MichoChan

> I only have 3 nodes. […] Could you please test DeepSeek-Coder-V2-Lite-Instruct with tp=2, pp=8 on 8 nodes? Or test whether DeepSeek-R1 or other large DeepSeek models work with tp=4, pp=4 or tp=8, pp=2?

DeepSeek-Coder-V2-Lite-Chat with tp=2, pp=8 on 8 nodes works, but DeepSeek-V3 still errors.

@xiaobochen-amd (Contributor)

I encountered an error while testing DeepSeek-V3 on MI300X with PP=8. The issue can be reproduced as follows:

python3 -m sglang.bench_offline_throughput \
    --model-path /PATH/TO/DeepSeek-V3-0324 \
    --disable-radix-cache \
    --trust-remote-code \
    --pp-size 8 \
    --dataset-name random \
    --random-input-len 16384 \
    --random-output-len 10 \
    --random-range-ratio 1.0 \
    --num-prompts 64

@zhjc1124 (Contributor, Author) commented May 27, 2025

New test cases:

DeepSeek-R1 with tp=4, pp=8 on 4 nodes: success
DeepSeek-R1 with tp=8, pp=4 on 4 nodes: success
DeepSeek-V3-0324 with tp=8, pp=4 on 4 nodes: success
DeepSeek-V3-0324 with tp=4, pp=8 on 4 nodes: success
DeepSeek-R1 with tp=2, pp=16 on 4 nodes: fail
DeepSeek-V3-0324 with tp=2, pp=16 on 4 nodes: fail

@zhjc1124 (Contributor, Author)

I found a bug: the PP partition is unbalanced, which may cause OOM. See #6666.
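
For illustration, a balanced partition spreads the remainder layers across stages instead of piling extra layers onto a few of them; a minimal sketch (not the PR's actual partition code), assuming 61 hidden layers as in DeepSeek-R1/V3:

# Minimal sketch of a balanced PP layer partition (illustrative, not the
# PR's actual code), assuming 61 hidden layers as in DeepSeek-R1/V3.
def partition(num_layers: int, pp_size: int) -> list[int]:
    base, rem = divmod(num_layers, pp_size)
    # Give one extra layer to each of the first `rem` stages.
    return [base + 1 if i < rem else base for i in range(pp_size)]

print(partition(61, 16))  # 13 stages of 4 layers, then 3 stages of 3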

@MichoChan commented Jun 3, 2025

@zhjc1124 tp=2 errors in the fused MoE Triton kernel, so I used --enable-ep-moe and it runs successfully. However, the current pipeline-parallelism implementation does not send hidden states asynchronously, so it is much slower than vLLM's pipeline parallelism, which uses Ray.
Could we use Ray for pipeline parallelism?
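
For context, the overlap being asked about can in principle be expressed with async point-to-point ops, without Ray; a minimal sketch using torch.distributed.isend (an illustration of the idea, not this PR's implementation or vLLM's design):

# Sketch of overlapping the hidden-state send to the next stage with compute
# on the next microbatch, via async P2P (illustrative only).
import torch
import torch.distributed as dist

def stage_step(hidden_states: torch.Tensor, next_rank: int):
    # Start the send to the next pipeline stage without blocking.
    work = dist.isend(hidden_states, dst=next_rank)
    # ... run the next microbatch's forward pass here, overlapped ...
    return work  # caller must work.wait() before reusing the buffer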

@fzyzcjy (Collaborator) commented Oct 16, 2025

Hi, could you please rebase the code?

@zhjc1124 (Contributor, Author) commented Oct 16, 2025

> Hi, could you please rebase the code?

This PR has been included in #8846
Closed.

@zhjc1124 zhjc1124 closed this Oct 16, 2025