feat: limit peak memory usage when computing logprobs #6318
merrymercy merged 4 commits into sgl-project:main
Conversation
@fzyzcjy tom, could you help on this? Thanks!

well, I do not have time recently...

thanks, let me find someone to review
Force-pushed from 3fe18ec to 4576bb7.
Force-pushed from 4576bb7 to 32ce61e.
zhaochenyang20 left a comment:

Shall we check the accuracy against the Hugging Face backend in the test?
This is indeed a frequent issue we have met, and the solution makes sense. However, this part of the code is very critical and error-prone. Can you write the following test cases?

- Assert that the input logprob, input top logprob, and model output are the same for the following cases, for a prompt of 1k length (see the sketch after this comment):
  - Set SGLANG_LOGITS_PROCESSER_CHUNK_SIZE to 256
  - Set SGLANG_LOGITS_PROCESSER_CHUNK_SIZE to 2048
  - The old code
- A batched version of the above tests where you mix requests that need input logprobs with requests that do not need input logprobs, similar to this one:

sglang/test/srt/test_srt_endpoint.py, line 245 in bb0e8a3

If possible, can you refactor the code so that it does not trigger chunking when the length is less than the chunk size? That way we can have a code path that is exactly the same as the old code.
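A minimal sketch of such an accuracy test, assuming sglang's offline Engine API; the environment-variable names come from this PR, while the engine entry point, response schema, and prompt are illustrative assumptions:

```python
# Hedged sketch: `sgl.Engine`, the `generate` kwargs, and the response
# schema are assumptions about sglang's offline API and may differ by
# version. The env vars are the ones introduced in this PR.
import os
import sglang as sgl

PROMPT = "The quick brown fox jumps over the lazy dog. " * 100  # roughly 1k tokens

def logprobs_with_chunk_size(chunk_size: int):
    os.environ["SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK"] = "True"
    os.environ["SGLANG_LOGITS_PROCESSER_CHUNK_SIZE"] = str(chunk_size)
    engine = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
    out = engine.generate(
        PROMPT,
        sampling_params={"temperature": 0, "max_new_tokens": 16},
        return_logprob=True,
        logprob_start_len=0,
        top_logprobs_num=5,
    )
    engine.shutdown()
    return out

a = logprobs_with_chunk_size(256)
b = logprobs_with_chunk_size(2048)
# Outputs and input logprobs should be identical across chunk sizes.
assert a["text"] == b["text"]
assert a["meta_info"]["input_token_logprobs"] == b["meta_info"]["input_token_logprobs"]
```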
Hi @zhaochenyang20, thanks for your review. I added a unit test to check the accuracy, and it seems that the logprobs differ quite a bit from HF's. Is this acceptable? There is a similar difference even without this PR. The test code is here: https://github.com/sgl-project/sglang/pull/6318/files#diff-74d422b689aa0e95e1425839e9d1f6dede379dd885199d718a4d16eff21ff714
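For context, the HF-side input-token logprobs can be computed directly with transformers; a short sketch (standard transformers API; the prompt is a placeholder, and the dtype chosen here largely determines how closely the numbers match sglang's):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16)

ids = tok("The quick brown fox jumps over the lazy dog.", return_tensors="pt").input_ids
with torch.no_grad():
    logits = model(input_ids=ids).logits.float()

# Logprob of each input token given its prefix: positions shift by one.
logprobs = torch.log_softmax(logits[:, :-1], dim=-1)
token_logprobs = logprobs.gather(-1, ids[:, 1:].unsqueeze(-1)).squeeze(-1)
print(token_logprobs)
```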
sure, wip~

sure~
Force-pushed from 576550d to 780f975.
@aftersnow Hey, the difference looks acceptable to me. I will also ping lianmin to review.

The test case is too easy; it does not touch the boundaries of the modified code.
| TestFile("test_w8a8_quantization.py", 46), | ||
| TestFile("test_reasoning_parser.py", 5), | ||
| TestFile("test_hybrid_attn_backend.py", 100), |
Do not rerun other test cases. This is probably from other merge conflicts.

Yes indeed. Trying to add some tests based on @merrymercy's comments.
Force-pushed from 0a64c0c to 78d1ce3.
Actually, the tests were added by the author of this PR. A big thanks to them.
Force-pushed from a593e56 to deefcc7.
Force-pushed from c29f8e6 to b08b624.
Force-pushed from 2d95803 to 0feac55.
Force-pushed from af98376 to 71cc92d.
After rebasing and testing locally following this test: `python test/srt/test_logprobs.py`, #6318 is all set.
Force-pushed from 71cc92d to aef68d6.
Functional Testing Steps:
Signed-off-by: Zhao Chen <zhaochen.zju@gmail.com>
Force-pushed from aef68d6 to d9c621d.
Throughput Test

Test Environment

Test Steps

Check out the test branch:

```bash
git remote add aftersnow https://github.com/aftersnow/sglang
git fetch aftersnow
git checkout -b reduce-peak-mem-usage-throughput-test aftersnow/reduce-peak-mem-usage-throughput-test
git remote remove aftersnow
```

Test Result 1

Disable chunked logprobs:

```bash
export SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK=False
```

Run the test:

```bash
python3 -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10000 --return-logprob --logprob-start-len 0 --dataset-name random --random-input-len 512 --random-output-len 128
```

Test result: CUDA Out of Memory (OOM) error.

Enable chunked logprobs (default chunk size):

```bash
export SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK=True
export SGLANG_LOGITS_PROCESSER_CHUNK_SIZE=2048
python3 -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10000 --return-logprob --logprob-start-len 0 --dataset-name random --random-input-len 512 --random-output-len 128
```

Offline Throughput Benchmark Results:

Test Result 2

Reduce memory usage to avoid OOM (reduce --mem-fraction-static to 0.8):

```bash
export SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK=False
python3 -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10000 --return-logprob --logprob-start-len 0 --dataset-name random --random-input-len 512 --random-output-len 128 --mem-fraction-static 0.8
```

Test result with chunked logprobs disabled:

Test result with chunked logprobs enabled:

```bash
export SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK=True
export SGLANG_LOGITS_PROCESSER_CHUNK_SIZE=2048
python3 -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10000 --return-logprob --logprob-start-len 0 --dataset-name random --random-input-len 512 --random-output-len 128 --mem-fraction-static 0.8
```
@aftersnow @zhaochenyang20 Really great work! The features and tests are solid. Thanks!
Motivation
While computing input and output token log probabilities, we frequently encounter CUDA out-of-memory (OOM) errors, even after reducing --mem-fraction-static to below 0.6.
The issue arises when the number of rows in the log probabilities is too large, leading to excessive peak memory usage. For instance, with log probabilities shaped [10000, 150000], the peak memory usage exceeds 10000 * 150000 * 4 / 1024 / 1024 = 5722.05 MB, solely due to the code:
logits = logits[:, :self.config.vocab_size].float()

To mitigate this, we can divide the logprobs into smaller chunks, thereby reducing the peak memory usage. The default chunk size is 2048, which means we split the logprobs into chunks of 2048 rows. If the logprobs shape is [10000, 150000], we split them into ceil(10000 / 2048) = 5 chunks, so the peak memory can be reduced to 2048 * 150000 * 4 / 1024 / 1024 = 1171.88 MB.
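The idea in minimal form; this is an illustrative sketch rather than the PR's actual code, and `chunked_token_logprobs` is a made-up helper name:

```python
import torch

def chunked_token_logprobs(logits: torch.Tensor, token_ids: torch.Tensor,
                           chunk_size: int = 2048) -> torch.Tensor:
    """Logprob of each target token, computed chunk by chunk so that the
    float32 temporary never exceeds chunk_size x vocab_size rows."""
    out = []
    for i in range(0, logits.shape[0], chunk_size):
        chunk = logits[i : i + chunk_size].float()        # at most 2048 x vocab in fp32
        logprobs = torch.log_softmax(chunk, dim=-1)
        ids = token_ids[i : i + chunk_size].unsqueeze(-1)
        out.append(logprobs.gather(-1, ids).squeeze(-1))  # keep one scalar per token
    return torch.cat(out)
```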
Modifications
In logits_processor.py, we split the logprobs into multiple chunks, compute the logprobs chunk by chunk, and then gather the results. This only takes effect when the user enables logprobs in the generation requests.
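A hedged sketch of the dispatch, including the short-input fallback requested in review; the function name and layout are illustrative, while the environment-variable names are the ones from this PR:

```python
import os
import torch

CHUNK_ENABLED = os.environ.get("SGLANG_ENABLE_LOGITS_PROCESSER_CHUNK", "True") == "True"
CHUNK_SIZE = int(os.environ.get("SGLANG_LOGITS_PROCESSER_CHUNK_SIZE", "2048"))

def top_logprobs(logits: torch.Tensor, k: int = 5):
    """Top-k logprobs per row; chunking kicks in only for long inputs."""
    if not CHUNK_ENABLED or logits.shape[0] <= CHUNK_SIZE:
        # Single-pass path, behaviorally identical to the pre-PR code.
        values, indices = torch.log_softmax(logits.float(), dim=-1).topk(k, dim=-1)
        return values, indices
    vals, idxs = [], []
    for i in range(0, logits.shape[0], CHUNK_SIZE):
        lp = torch.log_softmax(logits[i : i + CHUNK_SIZE].float(), dim=-1)
        v, ix = lp.topk(k, dim=-1)
        vals.append(v)
        idxs.append(ix)
    return torch.cat(vals), torch.cat(idxs)
```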
Checklist