
[Feature] improve TBO: two chunk overlap#8144

Merged
ch-wan merged 15 commits into sgl-project:main from House-West:two_chunk_overlap
Aug 6, 2025

Conversation

Contributor

@House-West House-West commented Jul 18, 2025

Motivation

mentioned in #6328

I ran some benchmarks on 2 x 8 x H800.

1. Special Case (one request with a length of 3072 per dp)

Experiment setup

SGLANG_TBO_DEBUG=1 SGL_CHUNKED_PREFIX_CACHE_THRESHOLD=1 python3 -m sglang.launch_server \
    --model-path /dev/shm/DeepSeek-V3-0324 --tp 16 --dp 16 --dist-init-addr $MASTER_IP:6676 \
    --chunked-prefill-size 65536 --max-prefill-tokens 170000 --nnodes 2 --node-rank $FED_POD_INDEX \
    --trust-remote-code --deepep-mode normal --enable-dp-attention --enable-deepep-moe \
    --enable-two-batch-overlap --disable-overlap-schedule --disable-radix-cache --disable-cuda-graph \
    --max-total-tokens 34000 --mem-fraction-static 0.90  --max-running-requests 16 \
    --tbo-token-distribution-threshold 0.48

python3 -m sglang.bench_serving \
        --backend sglang --host 127.0.0.1 --port 30000 \
        --dataset-name random \
        --num-prompts 1024 \
        --random-input 3072 --random-output 1 --random-range-ratio 1 \
        --max-concurrency 1024

For the baseline and this PR, vary --tbo-token-distribution-threshold: 0.0 disables two-chunk-overlap, while any threshold > 0 enables it.

two-chunk-overlap triggers chunked prefill, so setting SGL_CHUNKED_PREFIX_CACHE_THRESHOLD=1 can improve performance.

The bench_serving script is run 5 times, discarding the 1st run (it includes JIT compilation, etc.).

Experiment result

Throughput

  • baseline: 64820.55, 64901.24, 63900.35, 63931.12
  • ours: 72391.92, 72845.28, 71511.55, 73152.61

On average, it improves throughput by 12.56%.

2. General Case (variable-length inputs, 30–3072 tokens)

Experiment setup

SGLANG_TBO_DEBUG=1 SGL_CHUNKED_PREFIX_CACHE_THRESHOLD=1 python3 -m sglang.launch_server \
    --model-path /dev/shm/DeepSeek-V3-0324 --tp 16 --dp 16 --dist-init-addr $MASTER_IP:6676 \
    --chunked-prefill-size 65536 --max-prefill-tokens 170000 --nnodes 2 --node-rank $FED_POD_INDEX \
    --trust-remote-code --deepep-mode normal --enable-dp-attention --enable-deepep-moe \
    --enable-two-batch-overlap --disable-overlap-schedule --disable-radix-cache --disable-cuda-graph \
    --max-total-tokens 34000 --mem-fraction-static 0.90  --max-running-requests 512 \
    --tbo-token-distribution-threshold 0.48

python3 -m sglang.bench_serving \
        --backend sglang --host 127.0.0.1 --port 30000 \
        --dataset-name random \
        --num-prompts 2048 \
        --random-input 3072 --random-output 1 --random-range-ratio 0.01 \
        --max-concurrency 1024

Experiment result

Throughput

  • baseline: 71695.06, 68585.39, 75784.86, 77115.82
  • ours: 77532.36, 77059.32, 75654.72, 78041.45

On average, it improves throughput by 5.15%. The magnitude of the improvement depends on the distribution of input lengths.
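The averages quoted for both cases can be checked with a few lines of Python (numbers copied from the throughput lists above):

```python
# Verify the average throughput improvements quoted above.
def mean(xs):
    return sum(xs) / len(xs)

# Case 1: one request of length 3072 per dp
base1 = [64820.55, 64901.24, 63900.35, 63931.12]
ours1 = [72391.92, 72845.28, 71511.55, 73152.61]
gain1 = mean(ours1) / mean(base1) - 1
print(f"case 1: +{gain1:.2%}")  # +12.56%

# Case 2: variable-length inputs
base2 = [71695.06, 68585.39, 75784.86, 77115.82]
ours2 = [77532.36, 77059.32, 75654.72, 78041.45]
gain2 = mean(ours2) / mean(base2) - 1
print(f"case 2: +{gain2:.2%}")  # +5.15%
```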

Modifications

Checklist

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Summary of Changes

Hello @House-West, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant performance enhancement by implementing a 'two-chunk-overlap' strategy within the existing Two-Batch Overlap (TBO) system. This optimization aims to improve throughput for large language model serving by more efficiently distributing tokens across micro-batches, particularly beneficial for specific hardware setups. The changes include new configuration options, refined batch splitting algorithms, and general code improvements to support this new overlapping mechanism.

Highlights

  • Performance Optimization: Introduced 'two-chunk-overlap' functionality within the Two-Batch Overlap (TBO) mechanism to improve throughput, especially for specific hardware configurations like 2x8xH800. Benchmarks show up to 12.56% throughput improvement for special cases (one request with a length of 3072 per dp) and 5.15% for general variable-length inputs (seq_len 0-3k).
  • Configurability: Added a new server argument --two-batch-token-distribution-threshold (defaulting to 0.48) to control the token distribution between the two batches in micro-batch overlap. This threshold determines whether to perform a standard two-batch overlap or the new two-chunk overlap, with a value of 0 disabling two-chunk overlap.
  • Batch Splitting Logic Enhancement: Refactored and enhanced the batch splitting logic within the TBO system. The system now intelligently determines whether to split sequences by count or by token distribution (aiming for a half-sum split) based on the new two_batch_token_distribution_threshold, ensuring optimal token allocation for overlapped processing.
  • Code Refactoring for Position Computation: Centralized the compute_position logic into a single helper function in forward_batch_info.py. This change abstracts away the conditional choice between Triton and PyTorch implementations, reducing code duplication and improving maintainability.
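The splitting decision described above can be sketched roughly as follows. This is a hypothetical illustration, not the actual sglang implementation; the function name `choose_split` and its structure are invented for clarity:

```python
# Hypothetical sketch of how a token-distribution threshold could drive
# the choice between splitting a batch by request count and splitting
# one request at the half-sum token boundary (two-chunk overlap).
from typing import List, Tuple

def choose_split(extend_lens: List[int], threshold: float) -> Tuple[str, int]:
    """Return ("count", i) to split requests into [0:i) / [i:), or
    ("tokens", t) to split the token stream at offset t."""
    total = sum(extend_lens)
    # Best split point when dividing whole requests between the two
    # micro-batches: maximize the smaller side's share of tokens.
    best_i, best_ratio = 0, 0.0
    acc = 0
    for i, n in enumerate(extend_lens):
        acc += n
        ratio = min(acc, total - acc) / total
        if ratio > best_ratio:
            best_ratio, best_i = ratio, i + 1
    # If even the best by-count split leaves one micro-batch with less
    # than `threshold` of the tokens, split a request in the middle so
    # each micro-batch gets roughly half the tokens instead.
    if best_ratio < threshold:
        return ("tokens", total // 2)
    return ("count", best_i)

# One request of 3072 tokens: no by-count split is balanced, so the
# token stream is split at 1536 (threshold 0.48).
print(choose_split([3072], 0.48))       # ('tokens', 1536)
print(choose_split([1536, 1536], 0.48)) # ('count', 1)
```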

@fzyzcjy fzyzcjy self-assigned this Jul 18, 2025
@fzyzcjy fzyzcjy self-requested a review July 18, 2025 09:23
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

The code changes introduce a new feature to improve TBO by overlapping two chunks, and also includes some unit tests. There are a few typos that need to be addressed.

Collaborator

fzyzcjy commented Jul 19, 2025

The core is not reviewed yet; I first glanced at the non-core parts.

@House-West House-West force-pushed the two_chunk_overlap branch 7 times, most recently from 0625971 to 2a56702 Compare July 20, 2025 15:49
@House-West House-West requested a review from fzyzcjy July 21, 2025 03:31
Collaborator

@fzyzcjy fzyzcjy left a comment


(I will check again later)

Collaborator

fzyzcjy commented Jul 23, 2025

only that tiny nit and then ready to merge

Collaborator

fzyzcjy commented Jul 23, 2025

LGTM, now only need to wait for CI green

@House-West
Contributor Author

@zhyncs Could you review this PR? It's waiting for approval to merge.

Collaborator

ch-wan commented Jul 30, 2025

@House-West Nice work! For the first case, I'd like to know where the performance gain comes from. Is two-chunk overlap equivalent to two-batch overlap in that case? Thanks.

Contributor Author

House-West commented Jul 30, 2025

@House-West Nice work! For the first case, I'd like to know where the performance gain comes from. Is two-chunk overlap equivalent to two-batch overlap in that case? Thanks.

@ch-wan As mentioned in #6328, when each dp has only one request, two-chunk-overlap and two-batch-overlap are not equivalent.
such as batch_size = 1, extend_seq_len = [3072], extend_prefix_len = [0]

In two-batch-overlap:

  • micro batch0: extend_seq_len = [3072], extend_prefix_len = [0]
  • micro batch1 (idle batch): extend_seq_len = [0], extend_prefix_len = [0]

In two-chunk-overlap:

  • micro batch0: extend_seq_len = [1536], extend_prefix_len = [0]
  • micro batch1: extend_seq_len = [1536], extend_prefix_len = [1536]

Compared to two-batch-overlap, the latencies of the grouped GEMM, dispatch, and combine operations of the two micro batches are close to each other in two-chunk-overlap, which makes them better candidates for overlapping.
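The split in the example above can be written out as a small illustration. This is purely illustrative (the function `two_chunk_split` is invented here and is not the sglang code):

```python
# Illustrative: splitting one prefill request into two overlapped chunks.
# With two-batch overlap, a batch with a single request cannot be split
# by request count, so one micro-batch would be idle. Two-chunk overlap
# instead splits the request's tokens in half; the second chunk treats
# the first half as its already-computed prefix.
def two_chunk_split(extend_seq_len: int, extend_prefix_len: int = 0):
    half = extend_seq_len // 2
    micro0 = {"extend_seq_len": [half],
              "extend_prefix_len": [extend_prefix_len]}
    micro1 = {"extend_seq_len": [extend_seq_len - half],
              "extend_prefix_len": [extend_prefix_len + half]}
    return micro0, micro1

m0, m1 = two_chunk_split(3072)
print(m0)  # {'extend_seq_len': [1536], 'extend_prefix_len': [0]}
print(m1)  # {'extend_seq_len': [1536], 'extend_prefix_len': [1536]}
```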

@ch-wan ch-wan added the ready-to-merge The PR is ready to merge after the CI is green. label Aug 1, 2025
Collaborator

@ch-wan ch-wan left a comment


Can we add a small CI test to test_dp_attention.py?

@House-West
Contributor Author

Can we add a small CI test to test_dp_attention.py?

I saw the TBO test case in test_deepep_small. If TBO is enabled, two-chunk-overlap is enabled in most cases, so I think the two-chunk-overlap tests can reuse the TBO ones without additional test cases. I have also updated the parameter descriptions in server_arguments.md.

@House-West House-West requested a review from ch-wan August 4, 2025 06:23
@ch-wan ch-wan merged commit ca47e24 into sgl-project:main Aug 6, 2025

narutolhy pushed a commit to narutolhy/sglang that referenced this pull request Aug 17, 2025
MahmoudAshraf97 pushed a commit to MahmoudAshraf97/sglang that referenced this pull request Sep 8, 2025
@programmer-lxj

@House-West @fzyzcjy @ch-wan I'd like to ask whether you achieved performance improvements using TBO on two machines (1P1D), or whether your results were obtained without PD (prefill/decode) disaggregation. On my H20 system, performance decreased both with and without PD separation (1P1D). I've seen publicly available experimental results where TBO showed improvements, but they used many machines, such as 4P9D and 4P16D. I also tried your parameters, but performance decreased significantly. Did I do something wrong, or does TBO require a large number of machines to achieve improvements? My rough analysis suggests that at least three machines are needed before TBO's overhead and communication costs are offset by the computational overlap, and perhaps five machines to see any improvement. Looking forward to your reply!

@House-West
Contributor Author

@House-West @fzyzcjy @ch-wan I'd like to ask whether you achieved performance improvements using TBO on two machines (1P1D), or whether your results were obtained without PD (prefill/decode) disaggregation. On my H20 system, performance decreased both with and without PD separation (1P1D). I've seen publicly available experimental results where TBO showed improvements, but they used many machines, such as 4P9D and 4P16D. I also tried your parameters, but performance decreased significantly. Did I do something wrong, or does TBO require a large number of machines to achieve improvements? My rough analysis suggests that at least three machines are needed before TBO's overhead and communication costs are offset by the computational overlap, and perhaps five machines to see any improvement. Looking forward to your reply!

@programmer-lxj I tested on H800 previously, and there was no performance improvement with TBO on 1P1D (one machine for prefill, one machine for decode), because intra-node communication uses NVLink, which is relatively fast. In most cases, TBO brings performance improvements when communication takes more than 30% of end-to-end time, such as with 2P9D or 4P9D. For 1P1D, you can try SBO (single batch overlap).

@programmer-lxj

@House-West Thank you very much! I will try SBO.

@programmer-lxj

@House-West I tried using SBO, but found that the flag can only be added to the P node; adding it to the D node results in an error. I checked the code: the _combine_core function in sglang's deepep.py passes an overlap_args parameter, but the Buffer.low_latency_combine function in the deepep library doesn't accept that parameter, which causes the error. Does this mean SBO can only be enabled on the P node and not the D node? Also, with 1P1D, adding SBO only to the P node (and not the D node) resulted in a slight performance decrease, not an improvement. I wonder if there are other techniques I should be aware of.


Labels

ready-to-merge The PR is ready to merge after the CI is green.
