
[PP] Fix recv tensor dict potential race condition #20341

Merged
ShangmingCai merged 2 commits into main from fix_pp_key_error
Mar 14, 2026

Conversation

@ShangmingCai
Collaborator

@ShangmingCai ShangmingCai commented Mar 11, 2026

Motivation

This PR adds a msg type to distinguish proxy tensor dict receiving from output tensor dict receiving without harming performance, which should help in the rare cases where PP communication encounters a race condition.

related issue:
#19686 #19750

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

Signed-off-by: Shangming Cai <csmthu@gmail.com>
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a potential race condition in the pipeline parallelism (PP) communication by introducing a robust typed messaging system for tensor dictionaries. It ensures that inter-stage tensor exchanges are correctly ordered and processed, even if messages arrive out of sequence, by categorizing and temporarily storing unexpected messages. This enhancement significantly improves the reliability and stability of the PP scheduler's communication mechanisms.

Highlights

  • Typed Messaging for Tensor Dictionaries: Introduced a msg_kind parameter to _pp_send_dict_to_next_stage to explicitly type tensor dictionaries being sent between pipeline stages, preventing potential conflicts.
  • Robust Tensor Dictionary Reception: Implemented a new _pp_recv_typed_dict function that can handle out-of-order messages by stashing unexpected message types in a defaultdict(deque) inbox until the expected message kind is received.
  • Integration of Typed Communication: Updated existing send and receive calls, specifically _pp_recv_proxy_tensors, _pp_recv_dict_from_prev_stage, and various calls within event loops, to utilize the new typed messaging system.
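The stash-and-demultiplex pattern described in these highlights can be sketched roughly as follows. This is a minimal illustration rather than the actual sglang code: the class name and the plain-dict receive callback are hypothetical, and the real implementation receives serialized tensor dictionaries over the PP process group.

```python
from collections import defaultdict, deque


class TypedDictReceiver:
    """Illustrative sketch of demultiplexing tensor dicts by message kind.

    Messages that arrive with an unexpected ``msg_kind`` are stashed in a
    per-kind inbox (a defaultdict of deques) until a caller asks for that
    kind, so two logically distinct receive paths (e.g. 'proxy' vs.
    'output') can no longer consume each other's payloads.
    """

    def __init__(self, recv_fn):
        self._recv_fn = recv_fn          # blocking receive returning a dict
        self._inbox = defaultdict(deque)  # msg_kind -> queued messages

    def recv(self, expected_kind):
        # Serve a previously stashed message of the expected kind first.
        if self._inbox[expected_kind]:
            return self._inbox[expected_kind].popleft()
        while True:
            msg = self._recv_fn()
            kind = msg.pop("msg_kind", None)
            if kind == expected_kind:
                return msg
            # Out-of-order arrival: stash it and keep waiting.
            self._inbox[kind].append(msg)
```

With this shape, a receiver expecting an 'output' dict that gets a 'proxy' dict first simply parks the proxy message and keeps reading, instead of raising a KeyError on fields it does not have.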


Changelog
  • python/sglang/srt/managers/scheduler_pp_mixin.py
    • Imported defaultdict from the collections module.
    • Added msg_kind='proxy' to _pp_send_dict_to_next_stage calls within event_loop_pp, event_loop_pp_disagg_prefill, and event_loop_pp_disagg_decode.
    • Initialized _pp_tensor_dict_inbox as a defaultdict(deque) to manage typed incoming tensor dictionaries.
    • Modified _pp_send_dict_to_next_stage to accept a msg_kind parameter and embed it within the tensor dictionary, also adding a warning for untyped messages.
    • Added a new method _pp_recv_typed_dict to receive tensor dictionaries, demultiplexing them by msg_kind and stashing out-of-order messages.
    • Updated _pp_recv_proxy_tensors to use _pp_recv_typed_dict with expected_kind='proxy'.
    • Updated _pp_recv_dict_from_prev_stage to use _pp_recv_typed_dict with expected_kind='output'.
    • Added msg_kind='output' to _pp_send_dict_to_next_stage calls within _pp_send_output_to_next_stage.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a typed messaging system for pipeline parallelism to address a potential race condition when receiving tensor dictionaries, tagging messages with msg_kind and handling out-of-order messages via _pp_recv_typed_dict. However, a medium-severity vulnerability was identified in the message stashing logic due to a lack of proper resource constraints and error handling, which could lead to Denial of Service (DoS) through memory exhaustion or infinite loops. Additionally, consider avoiding in-place modification of the tensor dictionary being sent to improve code safety.
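The bounded-stash mitigation the review calls for could look roughly like the helper below. The constant and function names are assumptions for illustration, not the PR's actual code; the point is that a hard cap turns silent memory growth (or an effectively infinite receive loop) into a loud, debuggable failure.

```python
from collections import defaultdict, deque

MAX_STASHED_PER_KIND = 64  # assumed bound; tune for the deployment


def stash_or_fail(inbox, kind, msg, limit=MAX_STASHED_PER_KIND):
    """Stash an out-of-order message instead of growing unboundedly.

    Raises RuntimeError once too many messages of one kind pile up
    without being consumed, which usually indicates a desynchronized
    peer rather than a transient reordering.
    """
    if len(inbox[kind]) >= limit:
        raise RuntimeError(
            f"PP inbox overflow for msg_kind={kind!r}: "
            f"{limit} messages stashed without being consumed"
        )
    inbox[kind].append(msg)
```
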

Comment thread: python/sglang/srt/managers/scheduler_pp_mixin.py
Comment thread: python/sglang/srt/managers/scheduler_pp_mixin.py (outdated)
Signed-off-by: Shangming Cai <csmthu@gmail.com>
@ShangmingCai
Collaborator Author

/rerun-stage stage-c-test-8-gpu-h20

@github-actions
Contributor

✅ Triggered stage-c-test-8-gpu-h20 to run independently (skipping dependencies).

@github-actions
Contributor

🔗 View workflow run

@chris0927

I tested the code and got this error:
File "/home/jovyan/LLM/wangn/glm-5-server/sglang-main/12.4/sglang/python/sglang/srt/managers/scheduler.py", line 3155, in run_scheduler_process
scheduler.event_loop_pp()
File "/home/jovyan/.conda/envs/sglang-glm5-124-sync1/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/LLM/wangn/glm-5-server/sglang-main/12.4/sglang/python/sglang/srt/managers/scheduler_pp_mixin.py", line 155, in event_loop_pp
self.self_check_during_idle()
File "/home/jovyan/LLM/wangn/glm-5-server/sglang-main/12.4/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py", line 332, in self_check_during_idle
self.check_memory()
File "/home/jovyan/LLM/wangn/glm-5-server/sglang-main/12.4/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py", line 244, in check_memory
raise_error_or_warn(
File "/home/jovyan/LLM/wangn/glm-5-server/sglang-main/12.4/sglang/python/sglang/srt/utils/common.py", line 4082, in raise_error_or_warn
raise ValueError(message)
ValueError: token_to_kv_pool_allocator memory leak detected! self.max_total_num_tokens=454912, available_size=454848, evictable_size=0, protected_size=0

@ShangmingCai
Collaborator Author

ShangmingCai commented Mar 11, 2026

@chris0927 Can you share your start commands? I cannot reproduce either bug, the key error or the memory leak.

Which release version of SGLang is this? The main branch? Can you share the commit hash as well?

@chris0927

chris0927 commented Mar 11, 2026

@ShangmingCai 0.5.9. I only modified the file sglang-0.5.9/python/sglang/srt/managers/scheduler_pp_mixin.py. Run commands:
rank 0: python -m sglang.launch_server --dist-init-addr 192.168.9.85:9999 --nnodes 2 --node-rank 0 --model-path /home/GLM-5-FP8 --served-model-name glm-5 --tp 8 --pp-size 2 --trust-remote-code --tool-call-parser glm47 --reasoning-parser glm45 --chunked-prefill-size 2048 --max-running-requests 16 --context-length 202752 --host 0.0.0.0 --port 8000 --attention-backend flashinfer

rank 1: python -m sglang.launch_server --dist-init-addr 192.168.9.85:9999 --nnodes 2 --node-rank 1 --model-path /home/GLM-5-FP8 --served-model-name glm-5 --tp 8 --pp-size 2 --trust-remote-code --tool-call-parser glm47 --reasoning-parser glm45 --chunked-prefill-size 2048 --max-running-requests 16 --context-length 202752 --host 0.0.0.0 --port 8000 --attention-backend flashinfer

@chris0927

chris0927 commented Mar 11, 2026

@ShangmingCai I pulled the file https://github.com/sgl-project/sglang/blob/aeda4a72a60b3d71c90e143fd7d2c48739012f85/python/sglang/srt/managers/scheduler_pp_mixin.py and replaced the 0.5.9 version of the file; the program is running properly now.

@ShangmingCai
Collaborator Author

ShangmingCai commented Mar 11, 2026

@ShangmingCai I pulled the file https://github.com/sgl-project/sglang/blob/aeda4a72a60b3d71c90e143fd7d2c48739012f85/python/sglang/srt/managers/scheduler_pp_mixin.py and replaced the 0.5.9 version of the file; the program is running properly now.

@chris0927 Thanks for the verification. You can also try running some stress tests on this version. If the key error does not happen again, we can merge this PR into main as soon as possible.

I will run more tests tomorrow.

@chris0927

@ShangmingCai After testing for 2 days, there were almost no 'key error' occurrences, but this error appeared:
[2026-03-13 05:37:11] Received sigquit from a child process. It usually means the child failed.
--- Logging error ---
[2026-03-13 05:37:11 PP1 TP6] Scheduler hit an exception: Traceback (most recent call last):
File "/home/sglang-0.5.9/python/sglang/srt/managers/scheduler.py", line 3167, in run_scheduler_process
scheduler.event_loop_pp()
File "/home/jovyan/.conda/envs/sglang-glm5-124-sync1/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/sglang-0.5.9/python/sglang/srt/managers/scheduler_pp_mixin.py", line 108, in event_loop_pp
result, self.launch_event = self._pp_launch_batch(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/sglang-0.5.9/python/sglang/srt/managers/scheduler_pp_mixin.py", line 1111, in _pp_launch_batch
result = self.run_batch(self.cur_batch, pp_proxy_tensors)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/sglang-0.5.9/python/sglang/srt/managers/scheduler.py", line 2368, in run_batch
batch_result = self.model_worker.forward_batch_generation(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/sglang-0.5.9/python/sglang/srt/managers/tp_worker.py", line 456, in forward_batch_generation
out = self.model_runner.forward(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/sglang-0.5.9/python/sglang/srt/model_executor/model_runner.py", line 2401, in forward
output = self._forward_raw(
^^^^^^^^^^^^^^^^^^
File "/home/sglang-0.5.9/python/sglang/srt/model_executor/model_runner.py", line 2500, in _forward_raw
ret, can_run_graph = self.forward_extend(
^^^^^^^^^^^^^^^^^^^^
File "/home/sglang-0.5.9/python/sglang/srt/model_executor/model_runner.py", line 2338, in forward_extend
self.model.forward(
File "/home/jovyan/.conda/envs/sglang-glm5-124-sync1/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/sglang-0.5.9/python/sglang/srt/models/deepseek_v2.py", line 2919, in forward
hidden_states = self.model(
^^^^^^^^^^^
File "/home/jovyan/.conda/envs/sglang-glm5-124-sync1/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/.conda/envs/sglang-glm5-124-sync1/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/sglang-0.5.9/python/sglang/srt/models/deepseek_v2.py", line 2730, in forward
hidden_states, residual = layer(
^^^^^^
File "/home/jovyan/.conda/envs/sglang-glm5-124-sync1/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/.conda/envs/sglang-glm5-124-sync1/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/sglang-0.5.9/python/sglang/srt/models/deepseek_v2.py", line 2395, in forward
hidden_states = self.self_attn(
^^^^^^^^^^^^^^^
File "/home/jovyan/.conda/envs/sglang-glm5-124-sync1/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/.conda/envs/sglang-glm5-124-sync1/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/sglang-0.5.9/python/sglang/srt/models/deepseek_v2.py", line 1366, in forward
s = self.forward_prepare(
^^^^^^^^^^^^^^^^^^^^^
File "/home/sglang-0.5.9/python/sglang/srt/models/deepseek_v2.py", line 1416, in forward_prepare
inner_state = self.forward_normal_one_shot_prepare(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/sglang-0.5.9/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mha.py", line 322, in forward_normal_one_shot_prepare
return self.forward_normal_prepare(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/sglang-0.5.9/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mha.py", line 210, in forward_normal_prepare
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/.conda/envs/sglang-glm5-124-sync1/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jovyan/.conda/envs/sglang-glm5-124-sync1/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/sglang-0.5.9/python/sglang/srt/layers/rotary_embedding.py", line 283, in forward_native
query = query.view(num_tokens, -1, self.head_size)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape '[2048, -1, 64]' is invalid for input of size 512

@Zhangmj0621

@ShangmingCai After testing for 2 days, there were almost no errors about 'key error', but this error appeared: [traceback quoted above, ending in] RuntimeError: shape '[2048, -1, 64]' is invalid for input of size 512

I have the same error; this shape mismatch issue can be reproduced with the multi-turn hicache benchmark at about 700-800 requests.

@Zhangmj0621

Zhangmj0621 commented Mar 13, 2026

@ShangmingCai Could you look at this shape mismatch issue, or should I start a new issue? If I change pp-size from 2 to 1, the issue disappears, so I think it is related to PP.

@Zhangmj0621

Hi, I have this issue here as well.

@ShangmingCai
Collaborator Author

@ShangmingCai Could you look at this shape mismatch issue, or should I start a new issue? If I change pp-size from 2 to 1, the issue disappears, so I think it is related to PP.

This should be another issue, unless it is related to recv_tensor_dict.

@ShangmingCai
Collaborator Author

/tag-and-rerun-ci

@ShangmingCai
Collaborator Author

/rerun-stage stage-c-test-4-gpu-h100

@github-actions
Contributor

✅ Triggered stage-c-test-4-gpu-h100 to run independently (skipping dependencies).

@github-actions
Contributor

🔗 View workflow run

@ShangmingCai
Collaborator Author

[screenshots: CI results]

The related CI has passed. Since we only modified the pp_mixin file, and the users I invited to try out this PR said they have never seen the key error again when using PP, this PR should be ready to merge. Ping me if it doesn't work or if you encounter a new issue with PP.

@ShangmingCai ShangmingCai merged commit 99a3b25 into main Mar 14, 2026
232 of 258 checks passed
@ShangmingCai ShangmingCai deleted the fix_pp_key_error branch March 14, 2026 05:35
whybeyoung pushed a commit to whybeyoung/sglang that referenced this pull request Mar 14, 2026
Signed-off-by: Shangming Cai <csmthu@gmail.com>
whybeyoung pushed a commit to whybeyoung/sglang that referenced this pull request Mar 14, 2026
Signed-off-by: Shangming Cai <csmthu@gmail.com>
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Mar 15, 2026
Signed-off-by: Shangming Cai <csmthu@gmail.com>
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
Signed-off-by: Shangming Cai <csmthu@gmail.com>
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026
Signed-off-by: Shangming Cai <csmthu@gmail.com>
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
Signed-off-by: Shangming Cai <csmthu@gmail.com>
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
Signed-off-by: Shangming Cai <csmthu@gmail.com>