feat: limit peak memory usage when computing logprobs #6318

Merged
merrymercy merged 4 commits into sgl-project:main from aftersnow:reduce-peak-mem-usage on Nov 4, 2025

Conversation

@aftersnow
Contributor

@aftersnow aftersnow commented May 15, 2025

Motivation

While computing input and output token log probabilities, we frequently encounter CUDA out-of-memory (OOM) errors, even after reducing --mem-fraction-static to below 0.6.

Like this:

  File "/sglang/python/sglang/srt/layers/logits_processor.py", line 311, in forward
    logits = self._get_logits(pruned_states, lm_head, logits_metadata)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/layers/logits_processor.py", line 461, in _get_logits
    logits = logits[:, : self.config.vocab_size].float()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 6.05 GiB. GPU 0 has a total capacity of 79.10 GiB of which 5.94 GiB is free. Process 428408 has 73.15 GiB memory in use. Of the allocated memory 65.99 GiB is allocated by PyTorch, with 80.15 MiB allocated in private pools (e.g., CUDA Graphs), and 4.45 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

And this:

  File "/sglang/python/sglang/srt/layers/logits_processor.py", line 375, in forward
    input_logprobs = self.compute_temp_top_p_normalized_logprobs(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/layers/logits_processor.py", line 547, in compute_temp_top_p_normalized_logprobs
    return torch.nn.functional.log_softmax(last_logits, dim=-1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/functional.py", line 2248, in log_softmax
    ret = input.log_softmax(dim)
          ^^^^^^^^^^^^^^^^^^^^^^
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 7.28 GiB. GPU 0 has a total capacity of 79.10 GiB of which 7.18 GiB is free. Process 214835 has 71.90 GiB memory in use. Process 214837 has 14.00 MiB memory in use. Of the allocated memory 69.32 GiB is allocated by PyTorch, and 126.72 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

The issue arises when the number of rows in the log probabilities is too large, leading to excessive peak memory usage. For instance, with log probabilities of shape [10000, 150000], the single line logits = logits[:, :self.config.vocab_size].float() allocates 10000 * 150000 * 4 / 1024 / 1024 = 5722.05 MB at peak.

To mitigate this, we can divide the log probabilities into smaller chunks, thereby reducing the peak memory usage. The default chunk size is 2048, meaning we split the logprobs into chunks of 2048 rows. If the logprobs shape is [10000, 150000], we split them into ceil(10000 / 2048) = 5 chunks, reducing the peak for this allocation to 2048 * 150000 * 4 / 1024 / 1024 = 1171.88 MB.
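
For concreteness, this arithmetic can be checked with a few lines of Python (a back-of-the-envelope sketch that only counts the float32 logits buffer, ignoring other allocations):

import math

rows, vocab = 10000, 150000
chunk_size = 2048

# One full float32 copy of the logits: ~5722.05 MB
unchunked_mb = rows * vocab * 4 / 1024 / 1024

# With chunking, only one [chunk_size, vocab] buffer is live: ~1171.88 MB
num_chunks = math.ceil(rows / chunk_size)  # 5
chunked_mb = chunk_size * vocab * 4 / 1024 / 1024

print(f"unchunked: {unchunked_mb:.2f} MB")
print(f"chunked ({num_chunks} chunks): {chunked_mb:.2f} MB")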

Modifications

In logits_processor.py, we split the logprobs into multiple chunks, compute the logprobs chunk by chunk, and then gather the results. This only takes effect when the user enables logprobs in the generation request.
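
A minimal sketch of the idea (illustrative only, not the exact code in logits_processor.py; the function name and signature are assumptions):

import torch
import torch.nn.functional as F

def chunked_token_logprobs(logits: torch.Tensor, token_ids: torch.Tensor,
                           vocab_size: int, chunk_size: int = 2048) -> torch.Tensor:
    """Gather per-token logprobs chunk by chunk, so that only one
    [chunk_size, vocab_size] float32 buffer is transient at a time,
    instead of a full [num_rows, vocab_size] copy."""
    num_rows = logits.shape[0]
    out = torch.empty(num_rows, dtype=torch.float32, device=logits.device)
    for start in range(0, num_rows, chunk_size):
        end = min(start + chunk_size, num_rows)
        chunk = F.log_softmax(logits[start:end, :vocab_size].float(), dim=-1)
        out[start:end] = chunk.gather(1, token_ids[start:end].unsqueeze(1)).squeeze(1)
    return out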

Checklist

@zhaochenyang20
Collaborator

@fzyzcjy Tom, could you help on this? Thanks!

@fzyzcjy
Collaborator

fzyzcjy commented May 27, 2025

well I do not have time recently...

@zhaochenyang20
Collaborator

well I do not have time recently...

Thanks, let me find someone to review.

@aftersnow aftersnow changed the title from "fix: limit peak memory usage when computing logprobs" to "[WIP] fix: limit peak memory usage when computing logprobs" on Jun 10, 2025
@aftersnow aftersnow force-pushed the reduce-peak-mem-usage branch from 3fe18ec to 4576bb7 on July 17, 2025 09:43
@aftersnow aftersnow changed the title from "[WIP] fix: limit peak memory usage when computing logprobs" to "fix: limit peak memory usage when computing logprobs" on Jul 17, 2025
@aftersnow aftersnow force-pushed the reduce-peak-mem-usage branch from 4576bb7 to 32ce61e on July 17, 2025 09:56
@aftersnow aftersnow changed the title from "fix: limit peak memory usage when computing logprobs" to "feat: limit peak memory usage when computing logprobs" on Jul 17, 2025
@zhaochenyang20 zhaochenyang20 left a comment
Collaborator

Shall we check the accuracy with huggingface backend in the test?

@merrymercy merrymercy self-assigned this Jul 19, 2025
@merrymercy merrymercy left a comment
Contributor

This is indeed a frequent issue we have met, and the solution makes sense. However, this part of the code is very critical and error-prone. Can you write the following test cases?

  1. Assert that the input logprob, input top logprob, and model output are the same for the following cases, for a prompt of 1k length:
    • Set SGLANG_LOGITS_PROCESSER_CHUNK_SIZE to 256
    • Set SGLANG_LOGITS_PROCESSER_CHUNK_SIZE to 2048
    • The old code
  2. A batched version of the above tests where you mix requests that need input logprobs with requests that do not, similar to this one:
    def test_logprob_mixed(self):

If possible, can you refactor the code so that it does not trigger chunking when the length is less than the chunk size? That way we have a code path that is exactly the same as the old code.
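
For reference, the requested fast path could look roughly like this (a hypothetical sketch reusing the chunked_token_logprobs sketch from the PR description above; the env var names match the ones used later in this thread, but the defaults shown are assumptions):

import os
import torch
import torch.nn.functional as F

def token_logprobs(logits, token_ids, vocab_size):
    # Defaults here are assumptions, not the PR's actual defaults.
    chunk_on = os.environ.get("SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK", "True") == "True"
    chunk_size = int(os.environ.get("SGLANG_LOGITS_PROCESSER_CHUNK_SIZE", "2048"))
    if not chunk_on or logits.shape[0] <= chunk_size:
        # Exactly the old, unchunked code path for small inputs.
        logprobs = F.log_softmax(logits[:, :vocab_size].float(), dim=-1)
        return logprobs.gather(1, token_ids.unsqueeze(1)).squeeze(1)
    # chunked_token_logprobs is the sketch from the PR description above.
    return chunked_token_logprobs(logits, token_ids, vocab_size, chunk_size)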

@aftersnow
Contributor Author

aftersnow commented Jul 20, 2025

Shall we check the accuracy with huggingface backend in the test?

Hi @zhaochenyang20, thanks for your review. I added a unit test to check the accuracy, and it seems that the accuracy of the logprobs differs quite a bit from HF's. Is this acceptable?

There’s a similar difference even without using this PR.

        prompts = [
            "Hello, my name is",
            "The future of AI is",
            "The president of the United States is",
            "The capital of France is ",
        ]

        sampling_params = {
            "temperature": 1.0,
            "top_p": 1.0,
            "top_k": 10,
            "max_new_tokens": 32,
            "n": 1,
        }

The diffs are:

================
Max diff for token_logprobs: 0.06792879104614258
Max diff for top_logprobs: 0.19390583038330078
Max diff for token_ids_logprobs: 0.10184860229492188
================
Max diff for token_logprobs: 0.0969400405883789
Max diff for top_logprobs: 0.1627025604248047
Max diff for token_ids_logprobs: 0.10129737854003906
================
Max diff for token_logprobs: 0.07500886917114258
Max diff for top_logprobs: 0.17636871337890625
Max diff for token_ids_logprobs: 0.1154632568359375
================
Max diff for token_logprobs: 0.08561110496520996
Max diff for top_logprobs: 0.18017578125
Max diff for token_ids_logprobs: 0.13154220581054688

The rtol must be set to > 0.2 in order to pass the torch.allclose() assertion. Maybe my calculation of HF's log probs is not correct: I get the logits from the HF auto model and then apply log_softmax() to them, because HF seems to have no logprobs output. There shouldn't be any need to handle temperature or top_p after that, because I set temperature=1.0 and top_p=1.0.
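
For reference, the HF-side computation described above amounts to something like this (a sketch; the model id is just an example taken from the benchmarks below, and the variable names are illustrative):

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # example model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto")

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
with torch.no_grad():
    logits = model(**inputs).logits  # [1, seq_len, vocab_size]

# HF exposes raw logits, not logprobs, so apply log_softmax ourselves.
# With temperature=1.0 and top_p=1.0, no further normalization is needed.
hf_logprobs = F.log_softmax(logits.float(), dim=-1)

# Logprob of each input token given the preceding context.
targets = inputs["input_ids"][0, 1:]
hf_token_logprobs = hf_logprobs[0, :-1].gather(1, targets.unsqueeze(1)).squeeze(1)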

self.assertTrue(
    torch.allclose(
        hf_token_logprobs, srt_token_logprobs, atol=0, rtol=rtol
    )
)

self.assertTrue(
    torch.allclose(
        hf_top_logprobs,
        srt_top_logprobs,
        atol=0,
        rtol=rtol,
    )
)

self.assertTrue(
    torch.allclose(
        hf_token_ids_logprobs, srt_token_ids_logprobs, atol=0, rtol=rtol
    )
)

The test code is here: https://github.com/sgl-project/sglang/pull/6318/files#diff-74d422b689aa0e95e1425839e9d1f6dede379dd885199d718a4d16eff21ff714

@aftersnow
Contributor Author

Shall we check the accuracy with huggingface backend in the test?

sure, wip~

@aftersnow
Contributor Author

This is indeed a frequent issue we have met, and the solution makes sense. However, this part of the code is very critical and error-prone. Can you write the following test cases?

  1. Assert that the input logprob, input top logprob, and model output are the same for the following cases, for a prompt of 1k length:

    • Set SGLANG_LOGITS_PROCESSER_CHUNK_SIZE to 256
    • Set SGLANG_LOGITS_PROCESSER_CHUNK_SIZE to 2048
    • The old code
  2. A batched version of the above tests where you mix requests that need input logprobs with requests that do not, similar to this one:

    def test_logprob_mixed(self):

If possible, can you refactor the code so that it does not trigger chunking when the length is less than the chunk size? That way we have a code path that is exactly the same as the old code.

sure~

@aftersnow aftersnow force-pushed the reduce-peak-mem-usage branch from 576550d to 780f975 on July 20, 2025 16:15
@zhaochenyang20
Collaborator

@aftersnow Hey, the difference looks acceptable to me. I will ping Lianmin to review as well.

@aftersnow aftersnow changed the title from "feat: limit peak memory usage when computing logprobs" to "[WIP] feat: limit peak memory usage when computing logprobs" on Jul 21, 2025
@aftersnow aftersnow requested a review from Edwardf0t1 as a code owner August 10, 2025 17:29
@zhaochenyang20
Collaborator

The test case is too easy; it does not touch the boundaries of the modified code.

Comment on test/srt/run_suite.py, lines +114 to +116:
TestFile("test_w8a8_quantization.py", 46),
TestFile("test_reasoning_parser.py", 5),
TestFile("test_hybrid_attn_backend.py", 100),
Contributor

Do not rerun other test cases. It is probably from other merge conflicts.

@aftersnow
Contributor Author

The test case is too easy; it does not touch the boundaries of the modified code.

Yes indeed. Trying to add some tests based on @merrymercy's comments.

Actually, the tests were added by the author of this PR. A big thanks to them.

@aftersnow aftersnow force-pushed the reduce-peak-mem-usage branch 2 times, most recently from a593e56 to deefcc7 on September 18, 2025 06:03
@aftersnow aftersnow force-pushed the reduce-peak-mem-usage branch from c29f8e6 to b08b624 on September 21, 2025 04:11
@aftersnow aftersnow force-pushed the reduce-peak-mem-usage branch 2 times, most recently from 2d95803 to 0feac55 on September 27, 2025 17:45
@aftersnow aftersnow force-pushed the reduce-peak-mem-usage branch 2 times, most recently from af98376 to 71cc92d on October 26, 2025 19:01
@zhaochenyang20
Collaborator

After rebasing and testing locally following this test:

#10994

python test/srt/test_logprobs.py gen
python test/srt/test_logprobs.py test

#6318 is all set.

@aftersnow aftersnow force-pushed the reduce-peak-mem-usage branch from 71cc92d to aef68d6 on October 31, 2025 04:28
@aftersnow
Contributor Author

Functional Testing Steps:

  1. Checkout the main branch of sglang:

    git checkout main && git pull && pip install -e "python"
  2. Run the test to get the baseline:

    python test/srt/test_logprobs.py gen
  3. Checkout PR 6318:

    gh pr checkout 6318

    If you are not using GitHub CLI, you can use the following commands:

    git remote add aftersnow https://github.com/aftersnow/sglang
    git fetch aftersnow
    git checkout -b reduce-peak-mem-usage aftersnow/reduce-peak-mem-usage
  4. Run the test:

    python test/srt/test_logprobs.py test

    The test covers various combinations of chunked logprobs conditions.

@aftersnow aftersnow force-pushed the reduce-peak-mem-usage branch from aef68d6 to d9c621d on October 31, 2025 04:57
@aftersnow
Contributor Author

aftersnow commented Oct 31, 2025

Throughput Test

Test Environment

  • Single GPU: H200
  • Baseline: sglang main branch 069e490

Test Steps

Since sglang.bench_offline_throughput does not support --return-logprob, we can use this branch for testing:

git remote add aftersnow https://github.com/aftersnow/sglang
git fetch aftersnow
git checkout -b reduce-peak-mem-usage-throughput-test aftersnow/reduce-peak-mem-usage-throughput-test
git remote remove aftersnow

Test Result 1

Disable chunked logprobs:

export SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK=False

Run the test:

python3 -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10000 --return-logprob --logprob-start-len 0 --dataset-name random --random-input-len 512 --random-output-len 128

Test result: CUDA Out of Memory (OOM) error

[2025-10-31 10:36:02] Scheduler hit an exception: Traceback (most recent call last):
  ...
  torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.56 GiB. GPU 0 has a total capacity of 139.81 GiB of which 2.22 GiB is free. Including non-PyTorch memory, this process has 137.59 GiB memory in use. Of the allocated memory 134.28 GiB is allocated by PyTorch, with 210.00 MiB allocated in private pools (e.g., CUDA Graphs), and 2.39 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

[2025-10-31 10:36:02] SIGQUIT received. signum=None, frame=None. It usually means one child failed.
Killed

Enable chunked logprobs (default chunk size):

export SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK=True
export SGLANG_LOGITS_PROCESSER_CHUNK_SIZE=2048

python3 -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10000 --return-logprob --logprob-start-len 0 --dataset-name random --random-input-len 512 --random-output-len 128

Offline Throughput Benchmark Results:

Backend:                                 engine    
Successful requests:                     10000     
Benchmark duration (s):                  102.01    
Total input tokens:                      2574684   
Total generated tokens:                  644568    
Last generation throughput (tok/s):      13006.15  
Request throughput (req/s):              98.03     
Input token throughput (tok/s):          25238.92  
Output token throughput (tok/s):         6318.52   
Total token throughput (tok/s):          31557.45  

Test Result 2

Reduce memory usage to avoid OOM (reduce --mem-fraction-static to 0.8):

export SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK=False

python3 -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10000 --return-logprob --logprob-start-len 0 --dataset-name random --random-input-len 512 --random-output-len 128 --mem-fraction-static 0.8

Test result with chunked logprobs disabled:

====== Offline Throughput Benchmark Result =======
Backend:                                 engine    
Successful requests:                     10000     
Benchmark duration (s):                  100.74    
Total input tokens:                      2574684   
Total generated tokens:                  644568    
Last generation throughput (tok/s):      19272.81  
Request throughput (req/s):              99.26     
Input token throughput (tok/s):          25556.84  
Output token throughput (tok/s):         6398.11   
Total token throughput (tok/s):          31954.95  
==================================================

Test result with chunked logprobs enabled:

export SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK=True
export SGLANG_LOGITS_PROCESSER_CHUNK_SIZE=2048

python3 -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10000 --return-logprob --logprob-start-len 0 --dataset-name random --random-input-len 512 --random-output-len 128 --mem-fraction-static 0.8

====== Offline Throughput Benchmark Result =======
Backend:                                 engine    
Successful requests:                     10000     
Benchmark duration (s):                  101.46    
Total input tokens:                      2574684   
Total generated tokens:                  644568    
Last generation throughput (tok/s):      18983.11  
Request throughput (req/s):              98.56     
Input token throughput (tok/s):          25375.96  
Output token throughput (tok/s):         6352.83   
Total token throughput (tok/s):          31728.80  
==================================================

@merrymercy merrymercy merged commit d5fa019 into sgl-project:main Nov 4, 2025
16 of 86 checks passed
@merrymercy
Contributor

@aftersnow @zhaochenyang20 Really great work! The features and tests are solid. Thanks!

