
[Performance] Qwen3-Next: optimize causal_conv1d_fn triton kernel - up to 9% faster #10552

Closed
byjiang1996 wants to merge 2 commits into sgl-project:main from byjiang1996:byjiang1996/qwen3nextgpucpusync

Conversation


@byjiang1996 byjiang1996 commented Sep 17, 2025

Motivation

Currently, the causal_conv1d_fn triton kernel is used for extend & draft_extend. The profile shows it is inefficient: there are many CPU<->GPU copies/syncs, and its grid function is very slow (139us just to compute the triton grid), while the actual triton kernel launch takes only 17us.

[profiler trace screenshot]

Modifications

Take this request as an example (a sketch deriving these lists follows the list):

  • seq_lens = [1, 9, 2, 17, 4]
  • BLOCK_M=8
  • num_of_blockm_s_per_seq_len = torch.tensor([1, 2, 1, 3, 1]) -> total_blockm_s = 1+2+1+3+1=8
  • m_list = [0, 1, 1, 2, 3, 3, 3, 4]: len(m_list) = 8 = total_blockm_s
  • offset_list = [0, 0, 1, 0, 0, 1, 2, 0]: len(offset_list) = 8 = total_blockm_s
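For concreteness, here is a minimal Python sketch of how these per-block lists can be derived from seq_lens. The helper name and shape are illustrative, not the PR's actual code:

```python
def build_block_index_lists(seq_lens, BLOCK_M=8):
    # Number of BLOCK_M-sized token blocks each sequence needs (ceil division).
    num_blocks_per_seq = [(l + BLOCK_M - 1) // BLOCK_M for l in seq_lens]
    m_list, offset_list = [], []
    for seq_idx, n_blocks in enumerate(num_blocks_per_seq):
        for blk in range(n_blocks):
            m_list.append(seq_idx)     # which sequence this block belongs to
            offset_list.append(blk)    # block index within that sequence
    return num_blocks_per_seq, m_list, offset_list

# build_block_index_lists([1, 9, 2, 17, 4], BLOCK_M=8) returns:
#   num_blocks_per_seq = [1, 2, 1, 3, 1]   (total_blockm_s = 8)
#   m_list             = [0, 1, 1, 2, 3, 3, 3, 4]
#   offset_list        = [0, 0, 1, 0, 0, 1, 2, 0]
```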

Before

  • Each program_id(0) in _causal_conv1d_fwd_kernel processes one BLOCK_M chunk out of the 8 total
  • Each BLOCK_M chunk maps to one block of tokens: program_id(0) = i maps to blockm[i], i.e. the (m_list[i], offset_list[i]) chunk of tokens
  • For seq_lens = [1, 9, 2, 17, 4] & BLOCK_M=8, we have the following programs, which align with m_list = [0, 1, 1, 2, 3, 3, 3, 4] and offset_list = [0, 0, 1, 0, 0, 1, 2, 0] (a host-side launch sketch follows this list):
    -- program_id(0) = 0
    -- program_id(0) = 1
    -- program_id(0) = 2
    -- program_id(0) = 3
    -- program_id(0) = 4
    -- program_id(0) = 5
    -- program_id(0) = 6
    -- program_id(0) = 7
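A simplified sketch of what this 1D launch implies on the host side, assuming seq_lens lives on the GPU (illustrative only; the real grid function does more work):

```python
import torch

BLOCK_M = 8
seq_lens = torch.tensor([1, 9, 2, 17, 4], device="cuda")

# Sizing the 1D grid requires the total block count on the CPU,
# which forces a device->host copy (the sync point the profile flags as slow).
num_blocks_per_seq = (seq_lens + BLOCK_M - 1) // BLOCK_M
total_blocks = int(num_blocks_per_seq.sum().item())   # CPU<->GPU sync
grid = (total_blocks,)                                 # one program per BLOCK_M chunk
# m_list / offset_list (see above) must also be materialized as GPU tensors
# so each program can look up its (sequence, block-offset) pair.
```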

After: same processing logic as in triton_backend.py's extend_attention() method

  • Each program_id(0) in _causal_conv1d_fwd_kernel processes one seq_len out of total 5 seq_lens
  • Each program_id(1) in _causal_conv1d_fwd_kernel processes one block of tokens belonging to that seq_len
  • For seq_lens = [1, 9, 2, 17, 4] & BLOCK_M=8, we have the following programs, which align with num_of_blockm_s_per_seq_len = torch.tensor([1, 2, 1, 3, 1]) (a sketch of the new launch follows this list):
    -- program_id(0) = 0, program_id(1) = 0
    -- program_id(0) = 1, program_id(1) = 0
    -- program_id(0) = 1, program_id(1) = 1
    -- program_id(0) = 2, program_id(1) = 0
    -- program_id(0) = 3, program_id(1) = 0
    -- program_id(0) = 3, program_id(1) = 1
    -- program_id(0) = 3, program_id(1) = 2
    -- program_id(0) = 4, program_id(1) = 0
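A sketch of the corresponding 2D launch, assuming the sequence lengths are already available on the CPU (e.g. via forward_batch.extend_seq_len_cpu); names and the early-exit pattern are illustrative, not necessarily the exact code in the PR:

```python
import triton

BLOCK_M = 8
seq_lens_cpu = [1, 9, 2, 17, 4]   # plain Python list, no GPU tensor involved

# Axis 0: one program per sequence; axis 1: enough blocks for the longest sequence.
grid = (len(seq_lens_cpu), max(triton.cdiv(l, BLOCK_M) for l in seq_lens_cpu))
# Here grid == (5, 3). Programs whose program_id(1) exceeds the number of blocks
# their sequence actually needs return early inside the kernel, so only the 8
# (seq, block) pairs listed above do real work. No CPU<->GPU copy is required
# to size the grid.
```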

By doing so, we no longer have to calculate m_list & offset_list, which requires CPU<->GPU copies and expensive tensor/numpy operations. Instead, we can pass forward_batch.extend_seq_len_cpu to the triton kernel as-is.

In addition to this main change, this PR includes some miscellaneous cleanups, such as removing unnecessary code/tensors and improving the if condition in the triton kernel.

Benchmarking and Profiling

causal_conv1d_fn runtime

Before: p50=313us; avg=656us
[profiler screenshot: before]

After: p50=95us; avg=108us - 3X faster
[profiler screenshot: after]


GSM8k Accuracy

| Config | Accuracy | Note |
|---|---|---|
| TP4 | 0.955 | BEFORE |
| TP4 | 0.950 | AFTER |
| TP4 MTP | 0.955 | BEFORE |
| TP4 MTP | 0.950 | AFTER |
| TP4 DP2 MTP | 0.945 | BEFORE |
| TP4 DP2 MTP | 0.940 | AFTER |

GSM8k Output token/s

| Config | BEFORE (tokens/s) | AFTER (tokens/s) | % Faster |
|---|---|---|---|
| TP4 | 1613.119 | 1701.235 | +5.5% |
| TP4 MTP | 1888.384 | 2055.889 | +8.9% |
| TP4 DP2 MTP | 1762.349 | 1811.621 | +2.8% |



zhyncs commented Sep 18, 2025

```
python3 -m sglang.launch_server --model Qwen/Qwen3-Next-80B-A3B-Instruct --tp 4
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/models/qwen3_next.py", line 861, in forward
    hidden_states, residual = layer(
                              ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/models/qwen3_next.py", line 562, in forward
    hidden_states = self.linear_attn(
                    ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/models/qwen3_next.py", line 475, in forward
    core_attn_out = forward_batch.attn_backend.forward(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py", line 541, in forward
    return self.forward_extend(
           ^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py", line 511, in forward_extend
    return self.attn_backend_list[1].forward_extend(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sgl-workspace/sglang/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py", line 340, in forward_extend
    mixed_qkv = causal_conv1d_fn_sgl(
                ^^^^^^^^^^^^^^^^^^^^^
TypeError: causal_conv1d_fn() got an unexpected keyword argument 'seq_lens_cpu'
```

