Speed up by reducing _forward_extend_unified's gpu->cpu sync bubble by ~2x faster #11536

Merged

hebiao064 merged 2 commits into sgl-project:bhe/1_stage_triton_kernel from byjiang1996:byjiang1996/1_stage_speedup on Oct 13, 2025

Conversation

@byjiang1996 (Collaborator)

Motivation

Reduce the GPU->CPU synchronization bubble inside _forward_extend_unified in the Triton attention backend: scalar values that were previously read from the GPU one at a time are now transferred in a single batched copy, making the function roughly 2x faster.
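For context, a minimal illustration (not taken from this PR's diff) of the kind of bubble being reduced: every .item() read of a CUDA scalar blocks the CPU until the GPU stream catches up, so many small reads serialize CPU and GPU work.

import torch

# Illustrative only: each .item() on a CUDA tensor forces a device sync,
# so per-element reads create many small CPU-side stalls ("bubbles").
x = torch.arange(8, device="cuda")

total = 0
for i in range(8):
    total += x[i].item()  # one GPU->CPU sync per iteration

# Reducing on the GPU and reading a single scalar needs only one sync.
total_batched = int(x.sum().item())
assert total == total_batched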

Accuracy Tests

Server launch script used for the accuracy runs (deterministic inference enabled):

from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import prepare_server_args

if __name__ == "__main__":
    # Simulate CLI arguments (excluding the script name).
    args = [
        "--trust-remote-code",
        "--model-path",
        "Qwen/Qwen3-8B",
        "--attention-backend",
        "triton",
        "--tp",
        "4",
        "--enable-deterministic-inference",
        # "--disable-radix-cache",
        # "--disable-cuda-graph",
        # "--skip-server-warmup",
    ]
    server_args = prepare_server_args(args)
    launch_server(server_args)

After - Qwen3-8B:

Accuracy: 0.945
Invalid: 0.000
Latency: 15.683 s
Output throughput: 1513.949 token/s

Prompt 0 with prefix length 1: total samples: 302, Unique samples: 1
Prompt 1 with prefix length 511: total samples: 304, Unique samples: 1
Prompt 2 with prefix length 2048: total samples: 320, Unique samples: 1
Prompt 3 with prefix length 4097: total samples: 349, Unique samples: 1

After - Qwen3-30B-A3B:

Accuracy: 0.920
Invalid: 0.000
Latency: 18.446 s
Output throughput: 1222.439 token/s


Prompt 0 with prefix length 1: total samples: 300, Unique samples: 1
Prompt 1 with prefix length 511: total samples: 326, Unique samples: 1
Prompt 2 with prefix length 2048: total samples: 316, Unique samples: 1
Prompt 3 with prefix length 4097: total samples: 333, Unique samples: 1
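
The "Unique samples" rows above come from repeatedly sampling the same prompt. As a hedged sketch of that kind of check, assuming a server launched as above on the default port (the payload follows sglang's native /generate API; the helper name is made up):

import requests

def unique_samples(prompt, n=32, url="http://localhost:30000/generate"):
    # With --enable-deterministic-inference, n samples of the same prompt
    # should collapse to exactly one unique completion.
    outputs = set()
    for _ in range(n):
        resp = requests.post(url, json={
            "text": prompt,
            "sampling_params": {"temperature": 1.0, "max_new_tokens": 64},
        })
        outputs.add(resp.json()["text"])
    return len(outputs)

print(unique_samples("The capital of France is"))  # expect 1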

Benchmarking and Profiling

For TP4 Qwen3-8B:

python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 32 --random-input-len 256 --random-output-len 256  --random-range-ratio 1 --profile

sglang/srt/layers/attention/triton_backend.py(833): _forward_extend_unified average time reduced from 3370 us to 1634 us (~2x faster).

[Profiler screenshots: traces before and after the change]
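
The screenshots are not reproduced here, but as a hedged sketch (not the PR's exact tooling; bench_serving --profile emits its own trace), per-function averages like the one above can be read from a torch.profiler run:

import torch
from torch.profiler import profile, ProfilerActivity

def workload():
    # Stand-in workload with a deliberate per-iteration GPU->CPU sync,
    # so the stall pattern shows up in the profile.
    x = torch.randn(1024, 1024, device="cuda")
    for _ in range(10):
        x = x @ x
        _ = x.sum().item()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    workload()

# Per-op averages; in a real trace, look for _forward_extend_unified.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))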


@gemini-code-assist (Contributor)

Summary of Changes

Hello @byjiang1996, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant performance optimization to the attention mechanism's _forward_extend_unified function, specifically targeting the reduction of GPU-CPU synchronization overhead. By strategically batching data transfers and leveraging CPU-side NumPy operations, the changes aim to accelerate inference speed, as demonstrated by the provided benchmarking results showing a substantial speedup.

Highlights

  • Performance Improvement: Achieved approximately a 2x speedup in the _forward_extend_unified function within the Triton attention backend.
  • Reduced GPU-CPU Synchronization: Optimized data transfer by batching all necessary scalar values to the CPU upfront, minimizing repeated synchronization calls between the GPU and CPU.
  • Vectorized Operations and CPU-side Efficiency: Utilized vectorized GPU operations for computing prefix and unified lengths, and converted CPU-bound iteration to use NumPy arrays, further reducing synchronization overhead (see the sketch after this list).
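
A minimal sketch of the batching pattern these highlights describe, assuming per-request prefix/extend lengths live in 1-D CUDA tensors; function and variable names are illustrative, not taken from the diff:

import torch

def per_request_lengths_slow(prefix_lens, extend_lens):
    # One GPU->CPU sync per .item(): 2 * batch_size round trips.
    return [
        (prefix_lens[i].item(), extend_lens[i].item())
        for i in range(prefix_lens.shape[0])
    ]

def per_request_lengths_fast(prefix_lens, extend_lens):
    # Batch everything into a single transfer, then drive the loop from
    # CPU-side NumPy arrays so iteration triggers no further syncs.
    stacked = torch.stack([prefix_lens, extend_lens]).cpu().numpy()  # one sync
    return list(zip(stacked[0].tolist(), stacked[1].tolist()))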

@gemini-code-assist (Bot) left a review comment:

Code Review

This pull request significantly improves performance by reducing GPU-CPU synchronization within the _forward_extend_unified method. The approach of batching CPU data transfers and using NumPy for loop control is effective. My review includes a couple of suggestions to enhance code clarity: one to correct a misleading comment and another to refactor a complex conditional block for better readability. Overall, this is a solid optimization.

Review threads: python/sglang/srt/layers/attention/triton_backend.py (two threads; one marked outdated).
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@byjiang1996 changed the title from "Speed up by reducing gpu->cpu sync bubble by ~2x faster" to "Speed up by reducing _forward_extend_unified's gpu->cpu sync bubble by ~2x faster" on Oct 13, 2025.
@hebiao064 merged commit 99393c2 into sgl-project:bhe/1_stage_triton_kernel on Oct 13, 2025.
1 check passed