
[WIP][Feature] support tp-sp on qwen2/3 & deepseek v2/3/3.2#12820

Open
randgun wants to merge 1 commit intosgl-project:mainfrom
randgun:new_sp

Conversation

Contributor

@randgun randgun commented Nov 7, 2025

Motivation

For the classic dense decoder layer structure (self-attention + MLP), in the pure TP case the tensors are parallelized inside the attention and MLP layers, but every device still holds the full hidden states before and after each TP region. Assuming hidden states of shape [B, S, H] in 2-byte precision, the two layernorms therefore store 2*2BSH bytes of activations per device, i.e. 4*BSH bytes that are redundant across devices. SP splits this redundant data across multiple devices along the sequence dimension and performs layernorm independently on each shard. For more details, please refer to the research paper https://arxiv.org/pdf/2205.05198.pdf
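The activation accounting above can be sketched with some illustrative arithmetic (the function name and the example shapes are ours, chosen to match the 36K benchmark below; 2 bytes/element assumes bf16 or fp16):

```python
def layernorm_activation_bytes(B, S, H, tp, enable_sp, bytes_per_elem=2):
    """Bytes of layernorm activations held per device in one dense layer."""
    per_norm = bytes_per_elem * B * S * H  # full copy on every rank under pure TP
    if enable_sp:
        per_norm //= tp  # sequence dimension sharded across tp ranks
    return 2 * per_norm  # two norms: pre-attention and pre-MLP

B, S, H, tp = 1, 36000, 7168, 16
print(layernorm_activation_bytes(B, S, H, tp, enable_sp=False))  # 4*B*S*H bytes
print(layernorm_activation_bytes(B, S, H, tp, enable_sp=True))   # 1/tp of that
```

Under pure TP every rank pays the full 4*BSH bytes; with SP each rank pays only its 1/tp slice.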

(Screenshot 2025-11-10 114901)

We add a server argument "--enable-sp", which enables sequence parallelism when set. For example:

python3 -m sglang.launch_server --model-path $MODEL_PATH \
        --tp-size 16 --dp-size 1 --enable-sp \
        --trust-remote-code --attention-backend ascend --device npu --host $HOST_IP --port $PORT \
        --quantization w8a8_int8 --mem-fraction-static 0.8 \
        --chunked-prefill-size 16000 --context-length 16000 --max-prefill-tokens 16000 --max-total-tokens 16000 \
        --disable-radix-cache --moe-a2a-backend deepep --deepep-mode auto

TP-SP has two benefits:

  1. Reduces TTFT on long-sequence datasets (36K) by 10% on deepseek v3 and 7% on deepseek v3.2.
  2. Reduces peak memory because the RMSNorm layers hold fewer activations.

Modifications

  1. For the RowParallel linear, we replace the all-reduce comm op with reduce-scatter on the TP group when SP is enabled.
  2. For the dense layers of the models (qwen2/3 & deepseek v2/3/3.2), we split the residual before layernorm at the first layer; the hidden states are already split at the RowParallel linear, so they do not need to be split again.
  3. After the first dense layer, all data stays in the scattered state; we only do layernorm in prepare_attn and prepare_mlp.
  4. Before the MLP and attention modules, we do an extra all-gather because the weights have been split along the tensor dimension, to make sure the inputs are complete.
  5. We do an all-gather at the last dense layer.
  6. For the sparse layers of the deepseek models, when DeepEP is enabled the hidden states are already scattered; we take advantage of this and move _scattered_to_tp_attn_full from prepare_attention to after obtaining the q_lora and latent cache, which further reduces computation.
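The collective swaps described above can be illustrated with a single-process simulation (numpy stands in for the HCCL/NCCL collectives; each list element plays the role of one TP rank's tensor, the partial output of a RowParallel linear). This is a sketch of the communication pattern only, not the sglang implementation:

```python
import numpy as np

def all_reduce(partials):
    # pure TP: every rank ends up with the full summed tensor
    total = np.sum(partials, axis=0)
    return [total.copy() for _ in partials]

def reduce_scatter(partials):
    # TP-SP: every rank keeps only its 1/tp slice of the summed tensor,
    # so layernorm runs on a sequence shard instead of the full sequence
    total = np.sum(partials, axis=0)
    return list(np.split(total, len(partials), axis=0))

def all_gather(shards):
    # before attention/MLP: reassemble the full sequence on every rank
    full = np.concatenate(shards, axis=0)
    return [full.copy() for _ in shards]

tp, seq, hidden = 4, 8, 16
rng = np.random.default_rng(0)
partials = [rng.standard_normal((seq, hidden)) for _ in range(tp)]

tp_out = all_reduce(partials)      # each rank holds the full (8, 16) tensor
sp_out = reduce_scatter(partials)  # each rank holds a (2, 16) shard
regathered = all_gather(sp_out)    # back to (8, 16) before the next matmul

assert np.allclose(regathered[0], tp_out[0])
```

Reduce-scatter plus a later all-gather moves the same data as an all-reduce, but the window between them is where SP saves activation memory and redundant layernorm work.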

NOTE:

  1. We only adapt qwen2/3 and deepseek v2/3/3.2 on the Ascend backend; other backends can support SP with a small amount of additional code.
  2. TP-SP can be enabled together with CP ([Ascend] Deepseek v3 and v3.2 support Context Parallelism #12207) on the Ascend backend. We tested it on the dsv3.2 model with CP=16, TP=2, which further reduces TTFT by 1 percent and decreases peak runtime memory by 2 GB per device.

Accuracy Tests

The test_gsm8k.py is based on benchmark/gsm8k/bench_sglang.py; you can run it with
python3 benchmark/gsm8k/bench_sglang.py --host http://127.0.0.1 --port 8000 --num-questions 300
(accuracy results screenshot)

Benchmarking and Profiling

export cann_path=/usr/local/Ascend/ascend-toolkit/latest
source /usr/local/Ascend/driver/bin/setenv.bash
source ${cann_path}/../set_env.sh
source ${cann_path}/../../nnal/atb/set_env.sh
source ${cann_path}/opp/vendors/customize/bin/set_env.bash
export ASCEND_HOME_PATH=${cann_path}

# CPU high performance
echo performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor
sysctl -w vm.swappiness=0
sysctl -w kernel.numa_balancing=0
sysctl -w kernel.sched_migration_cost_ns=50000

export SGLANG_SET_CPU_AFFINITY=1

# Memory Fragmentation
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export STREAMS_PER_DEVICE=32

# HCCL
export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=16
export HCCL_BUFFSIZE=1600
#export HCCL_RDMA_PCIE_DIRECT_POST_NOSTRICT=TRUE
export HCCL_OP_EXPANSION_MODE=AIV
export HCCL_ALGO="level0:NA;level1:ring"

# Your NIC
export HCCL_SOCKET_IFNAME=enp48s3u1u1
export GLOO_SOCKET_IFNAME=enp48s3u1u1

export PYTHONPATH=$PWD/python/:$PYTHONPATH

python -m sglang.launch_server --model-path $MODEL_PATH \
        --tp-size 16 --dp-size 1 --enable-sp \
        --trust-remote-code --attention-backend ascend --device npu --host 127.0.0.1 --port 8000 \
        --quantization w8a8_int8 --mem-fraction-static 0.79 \
        --chunked-prefill-size 36000 --context-length 36000 --max-prefill-tokens 36000 --max-total-tokens 36000 \
        --disable-radix-cache --moe-a2a-backend deepep --deepep-mode auto
(benchmark results screenshot)

Checklist


@randgun randgun marked this pull request as draft November 7, 2025 07:31
@randgun randgun changed the title support tp-sp on qwen2/3 & deepseek v2/3 [feat] support tp-sp on qwen2/3 & deepseek v2/3 Nov 10, 2025
@randgun randgun changed the title [feat] support tp-sp on qwen2/3 & deepseek v2/3 [feat] support tp-sp on qwen2/3 & deepseek v2/3/3.2 Nov 10, 2025
@randgun randgun changed the title [feat] support tp-sp on qwen2/3 & deepseek v2/3/3.2 [Feature] support tp-sp on qwen2/3 & deepseek v2/3/3.2 Nov 10, 2025
@randgun randgun changed the title [Feature] support tp-sp on qwen2/3 & deepseek v2/3/3.2 [WIP][Feature] support tp-sp on qwen2/3 & deepseek v2/3/3.2 Nov 10, 2025
@randgun randgun marked this pull request as ready for review November 10, 2025 07:46
@randgun randgun requested a review from yizhang2077 as a code owner November 15, 2025 14:26
output = torch.empty(
    dim_size, dtype=output_parallel.dtype, device=output_parallel.device
)
self.tp_group.reduce_scatter_tensor(output, output_parallel.contiguous())
Collaborator

@iforgetmyname iforgetmyname Nov 16, 2025


Just wondering why we have to do reduce_scatter here in the linear layer; the communication could be handled in the layer communicator.

Contributor Author


If we do not use reduce_scatter in linear, we have to set "skip_all_reduce=enable_sp" and do reduce_scatter at both the attention o_proj and the MLP RowParallel linear, and that code is unavoidable for any other model that wants to adapt SP. This parameter is similar to "skip_all_reduce"; maybe renaming it to "use_reduce_scatter" would be better?

Comment thread python/sglang/srt/layers/communicator.py Outdated
Comment thread python/sglang/srt/layers/linear.py Outdated
Contributor

@merrymercy merrymercy left a comment


  1. need a review from @ch-wan
  2. Think about how to avoid inserting communication primitives into model forward code and make them more reusable for more models
  3. Add a GPU test case

mamba_full_memory_ratio: float = 0.9

# Sequence parallelism
enable_sp: bool = False
Contributor


move to under # Runtime options

Collaborator

@ch-wan ch-wan left a comment


I had a quick review. Overall, I feel that the code quality can be much improved. Many changes can be simplified.

if mlp_mode == ScatterMode.SCATTERED:
    return ScatterMode.SCATTERED
if mlp_mode == ScatterMode.FULL:
    return ScatterMode.TP_ATTN_FULL
raise NotImplementedError

@classmethod
def _compute_layer_output_mode(cls, context: _LayerModeComputationContext):
mlp_mode = cls._compute_mlp_mode(context)
def _compute_layer_output_mode(
Collaborator


Why do we need to add mlp_mode to all mode propagation? Passing context is enough. Probably mlp_mode can be a @property function of ctx.

@@ -89,13 +90,13 @@ def __init__(
)
self.act_fn = SiluAndMul()

def forward(self, x):
if get_global_server_args().rl_on_policy_target is not None:
def forward(self, x, enable_sp: bool = False):
Collaborator


We can define a util function is_sp_layernorm_enabled to avoid passing this arg to multiple functions. Or we can check get_global_server_args().enable_sp in linear.py. You can refer to how we implemented is_dp_attention_enabled.
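A minimal sketch of the pattern suggested here (all names are hypothetical, modeled on the is_dp_attention_enabled pattern rather than copied from sglang; _server_args stands in for sglang's global server args object):

```python
from types import SimpleNamespace

_server_args = None  # stand-in for sglang's global server args

def set_global_server_args(args):
    global _server_args
    _server_args = args

def is_sp_layernorm_enabled() -> bool:
    # Read the flag once from global config instead of threading an
    # enable_sp argument through every forward() signature.
    return bool(getattr(_server_args, "enable_sp", False))

set_global_server_args(SimpleNamespace(enable_sp=True))
print(is_sp_layernorm_enabled())
```

Callers such as linear.py or the layer communicator would then query the util instead of taking enable_sp as a parameter.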

@@ -3164,6 +2979,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="The ratio of mamba state memory to full kv cache memory.",
)

# Sequence parallelism
parser.add_argument(
"--enable-sp",
Collaborator


I recommend renaming it to --enable-sp-layernorm for clarity.

else:
forward_batch.prepare_attn_tp_scatter_input(self)
forward_batch.prepare_mlp_sync_batch(
self, get_global_server_args().enable_sp
Collaborator


We can get this server arg internally.

Contributor

@merrymercy merrymercy left a comment


Let us hold this until we fix all code quality issues
