[WIP][Feature] support tp-sp on qwen2/3 & deepseek v2/3/3.2 #12820

randgun wants to merge 1 commit into sgl-project:main
Conversation
    output = torch.empty(
        dim_size, dtype=output_parallel.dtype, device=output_parallel.device
    )
    self.tp_group.reduce_scatter_tensor(output, output_parallel.contiguous())
Just wondering why we have to do reduce_scatter here in linear; communication could be handled in the layer communicator instead.
If we do not use reduce_scatter in linear, we have to set "skip_all_reduce=enable_sp" and do reduce_scatter at both the attention o_proj and the MLP row-parallel linear, and that code would be needed in every other model that adapts SP. This parameter is similar to "skip_all_reduce"; maybe renaming it to "use_reduce_scatter" would be better?
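For illustration, a minimal sketch of keeping the reduce-scatter inside the row-parallel linear, keyed off a use_reduce_scatter flag as suggested above. The class is a stand-in, not the actual sglang layer, and tp_group is assumed to expose reduce_scatter_tensor, all_reduce, and world_size:

```python
import torch


class RowParallelLinearSketch:
    """Illustrative row-parallel linear; not the actual sglang class."""

    def __init__(self, tp_group, use_reduce_scatter: bool = False):
        self.tp_group = tp_group
        # Hypothetical flag (the rename suggested above); in this PR the
        # behaviour is keyed off enable_sp / skip_all_reduce instead.
        self.use_reduce_scatter = use_reduce_scatter

    def _combine(self, output_parallel: torch.Tensor) -> torch.Tensor:
        if self.use_reduce_scatter:
            # Sequence parallelism: each rank keeps only its sequence shard, so
            # partial sums are combined with reduce-scatter instead of all-reduce.
            dim_size = list(output_parallel.shape)
            dim_size[0] //= self.tp_group.world_size
            output = torch.empty(
                dim_size, dtype=output_parallel.dtype, device=output_parallel.device
            )
            self.tp_group.reduce_scatter_tensor(output, output_parallel.contiguous())
            return output
        # Plain TP: every rank needs the full, summed activation.
        return self.tp_group.all_reduce(output_parallel)
```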
merrymercy left a comment
- need a review from @ch-wan
- Think about how to avoid inserting communication primitives into model forward code and make them more reusable for more models
- Add a GPU test case
    mamba_full_memory_ratio: float = 0.9

    # Sequence parallelism
    enable_sp: bool = False
Move this under # Runtime options.
ch-wan left a comment
I had a quick review. Overall, I feel that the code quality can be much improved. Many changes can be simplified.
    if mlp_mode == ScatterMode.SCATTERED:
        return ScatterMode.SCATTERED
    if mlp_mode == ScatterMode.FULL:
        return ScatterMode.TP_ATTN_FULL
    raise NotImplementedError

    @classmethod
    def _compute_layer_output_mode(cls, context: _LayerModeComputationContext):
        mlp_mode = cls._compute_mlp_mode(context)
    def _compute_layer_output_mode(
Why do we need to add mlp_mode to all mode propagation? Passing context is enough. Probably mlp_mode can be a @property function of ctx.
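A minimal sketch of that suggestion, assuming the context object carries what is needed to derive the MLP mode; the fields and decision logic below are placeholders, not the real implementation:

```python
from dataclasses import dataclass
from enum import Enum, auto
from functools import cached_property


class ScatterMode(Enum):
    SCATTERED = auto()
    TP_ATTN_FULL = auto()
    FULL = auto()


@dataclass
class _LayerModeComputationContext:
    layer_id: int
    num_layers: int

    @cached_property
    def mlp_mode(self) -> ScatterMode:
        # Placeholder decision logic: compute the MLP mode once from the context
        # itself, so _compute_layer_output_mode(context) and the other propagation
        # helpers take only `context` instead of an extra mlp_mode argument.
        return ScatterMode.FULL
```

Callers would then read ctx.mlp_mode wherever the mode is needed, instead of passing it alongside the context.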
    @@ -89,13 +90,13 @@ def __init__(
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        if get_global_server_args().rl_on_policy_target is not None:
    def forward(self, x, enable_sp: bool = False):
We can define a util function is_sp_layernorm_enabled to avoid passing this arg to multiple functions. Or we can check get_global_server_args().enable_sp in linear.py. You can refer to how we implemented is_dp_attention_enabled.
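A minimal sketch of the suggested helper, mirroring the shape of is_dp_attention_enabled; the import path and the attribute name (enable_sp, as in this PR) are assumptions:

```python
# Import path is an assumption; adjust to wherever get_global_server_args lives.
from sglang.srt.server_args import get_global_server_args


def is_sp_layernorm_enabled() -> bool:
    """Return True when sequence-parallel layernorm is enabled via server args."""
    return getattr(get_global_server_args(), "enable_sp", False)
```

forward() and the linear layers could then call is_sp_layernorm_enabled() internally instead of accepting enable_sp as a parameter.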
    @@ -3164,6 +2979,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
        help="The ratio of mamba state memory to full kv cache memory.",
    )

    # Sequence parallelism
    parser.add_argument(
        "--enable-sp",
I recommend renaming it to --enable-sp-layernorm for clarity.
    else:
        forward_batch.prepare_attn_tp_scatter_input(self)
        forward_batch.prepare_mlp_sync_batch(
            self, get_global_server_args().enable_sp
We can get this server arg internally.
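For illustration only, a sketch of reading the flag inside the callee so the call site stays forward_batch.prepare_mlp_sync_batch(self); the class and method body are simplified stand-ins for the real ForwardBatch:

```python
from sglang.srt.server_args import get_global_server_args  # path is an assumption


class ForwardBatchSketch:
    """Illustrative stand-in for ForwardBatch; only the relevant method is shown."""

    def prepare_mlp_sync_batch(self, model_runner):
        # Fetch the flag internally instead of having the caller pass
        # get_global_server_args().enable_sp as an extra argument.
        enable_sp = get_global_server_args().enable_sp
        if enable_sp:
            pass  # set up sequence-parallel metadata here
```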
merrymercy left a comment
Let us hold this until we fix all code quality issues
Motivation
For the classic dense decoder layer structure (self-attention + MLP), in the pure TP case tensors are parallelized inside the attention and MLP layers, but every device still holds the full activation before and after each TP region. The two layernorms therefore store 2 * 2BSH = 4BSH bytes of activations per layer (assuming fp16/bf16 elements and hidden states of shape [B, S, H]). SP splits this 4BSH of redundant data along the sequence dimension across the TP ranks and performs layernorm independently on each shard. For more details, please refer to https://arxiv.org/pdf/2205.05198.pdf.
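As a back-of-the-envelope illustration (all numbers below are made-up example values, not measurements from this PR):

```python
# Illustrative activation-memory estimate for the two layernorms in one layer.
B, S, H = 8, 4096, 4096          # batch, sequence length, hidden size (example values)
bytes_per_elem = 2               # fp16 / bf16
tp = 8                           # TP (and SP) world size (example value)

per_layer_no_sp = 2 * bytes_per_elem * B * S * H   # 4*B*S*H bytes on every rank
per_layer_with_sp = per_layer_no_sp // tp          # sequence dim sharded across ranks

print(f"without SP: {per_layer_no_sp / 2**20:.0f} MiB per rank per layer")  # 512 MiB
print(f"with SP:    {per_layer_with_sp / 2**20:.0f} MiB per rank per layer")  # 64 MiB
```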
We add an "--enable-sp" flag to the server args; setting it enables sequence parallelism. For example:
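(illustrative launch command; the model path and TP size are placeholders, not taken from this PR)

    python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tp 8 --enable-sp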
TP-SP has two benefits:
Modifications
NOTE:
Accuracy Tests
The test_gsm8k.py script is based on benchmark/gsm8k/bench_sglang.py; you can reproduce the test with:

    python3 benchmark/gsm8k/bench_sglang.py --host http://127.0.0.1 --port 8000 --num-questions 300

Benchmarking and Profiling
Checklist