
bugfix[schedule]: Refactor sort method and add related UT #13576

Merged
hnyls2002 merged 28 commits into sgl-project:main from SeanWeiSean:yuxwei/improveschedularpolicy
Dec 22, 2025

Conversation

@SeanWeiSean
Contributor

@SeanWeiSean SeanWeiSean commented Nov 19, 2025

Motivation

Reuse the calculation code, try to fix a token-counting issue, and build test cases against the PrefillAdder class.

Modifications

  1. Fix scheduling bugs
    Corrects errors in the preemption logic where the wrong request object was used to compute token offsets. This ensures that high-priority requests can preempt low-priority ones accurately, improving scheduler correctness and preventing resource misallocation.

  2. Refactor scheduler logic for clarity and maintainability
    The PR reorganizes scheduling functions, including priority sorting and prefix matching logic. A cleaner structure makes it easier to extend the scheduler in the future for more complex strategies (e.g., multi-level queues, mixed-priority handling).

  3. Strengthen test coverage
    Adds unit tests for edge cases in prefill, preemption, and priority handling, improving code reliability and reducing the risk of regressions in future updates.

Detail the changes made in this pull request.

  1. Introduce priority_sign for calc.
  2. Reuse RadixKey
  3. Reuse _calc_available_and_evictable_tokens
  4. Add UT

Accuracy Tests

Benchmarking and Profiling

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello @SeanWeiSean, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on improving the schedule_policy.py file by refactoring its scheduling and token management logic and addressing a critical bug. The changes aim to enhance code reusability, optimize performance by reducing object creation, and ensure correct resource allocation during request preemption.

Highlights

  • Unified Priority Scheduling Logic: Introduced a priority_sign to simplify and unify priority-based sorting and preemption logic across different scheduling functions, reducing code duplication and improving readability.
  • Optimized RadixKey Usage: Refactored the _compute_prefix_matches method to create and reuse RadixKey objects, minimizing redundant object creation within loops and enhancing efficiency.
  • Consolidated Token Calculation: Extracted duplicated token availability calculation logic into a new private method _calc_available_and_evictable_tokens, centralizing the logic and improving code maintainability for rem_total_tokens and cur_rem_tokens.
  • Critical Preemption Bug Fix: Corrected a bug in the preempt_to_schedule method where the token offset for preempted requests was incorrectly calculated using the incoming request instead of the running request being preempted, ensuring accurate resource management.
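
The unified priority logic described above could be sketched as follows. This is a minimal illustration, not the code from schedule_policy.py: the `Req` dataclass and the `priority_sign_for` helper are hypothetical, and the convention that `priority_sign` is +1 when low priority values are scheduled first (and -1 otherwise) is an assumption inferred from the diff lines quoted in the review.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Req:
    # Hypothetical stand-in for the scheduler's request type.
    rid: str
    priority: int
    wait_queue_entry_time: float


def priority_sign_for(schedule_low_priority_values_first: bool) -> int:
    # Assumed convention: +1 when low values are scheduled first, else -1.
    return 1 if schedule_low_priority_values_first else -1


def sort_by_priority_and_fcfs(waiting_queue: List[Req], priority_sign: int) -> None:
    # Sort in place by signed priority, then first-come-first-served.
    waiting_queue.sort(
        key=lambda r: (r.priority * priority_sign, r.wait_queue_entry_time)
    )


queue = [Req("a", 1, 3.0), Req("b", 5, 1.0), Req("c", 5, 2.0)]

sort_by_priority_and_fcfs(queue, priority_sign_for(False))
order_high_first = [r.rid for r in queue]

sort_by_priority_and_fcfs(queue, priority_sign_for(True))
order_low_first = [r.rid for r in queue]
```

A single integer sign lets one sort key serve both orderings, so the boolean flag no longer needs to be threaded through each sorting function.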

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

The pull request introduces several refactorings to the scheduler policy, including reusing RadixKey, simplifying priority sorting logic with a priority_sign attribute, and extracting duplicated code into a helper method. It also includes a critical bug fix in preempt_to_schedule where the wrong request was being used to calculate token offsets.

My review identifies two critical bugs in the refactored sorting methods where self was incorrectly added to @staticmethod signatures, which would cause TypeErrors at runtime. I've also pointed out some minor style issues for improved readability. The rest of the changes, including the bug fix and other refactorings, look good.

Comment on lines +252 to +254
def _sort_by_longest_output(
self,

critical

This method is decorated with @staticmethod, but self has been added as the first parameter. A static method does not receive the instance as its first argument. This will cause a TypeError at runtime because the call at line 134 provides 3 arguments, but the method now expects 4. Please remove the self parameter.

Suggested change
- def _sort_by_longest_output(
- self,
+ def _sort_by_longest_output(waiting_queue: List[Req],
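
The failure mode described in this comment can be reproduced in isolation. The sketch below uses a hypothetical `Policy` class with made-up method names. It only shows that `@staticmethod` does not bind an instance, so a leftover `self` parameter simply consumes the first positional argument and the call fails with a TypeError.

```python
from typing import List


class Policy:
    # Hypothetical class; only the decorator/signature mismatch matters here.
    @staticmethod
    def broken_sort(self, waiting_queue: List[int]) -> None:
        waiting_queue.sort()

    @staticmethod
    def fixed_sort(waiting_queue: List[int]) -> None:
        waiting_queue.sort()


queue = [3, 1, 2]

try:
    # `queue` is bound to the stray `self`; `waiting_queue` is left unfilled.
    Policy.broken_sort(queue)
    error_message = ""
except TypeError as exc:
    error_message = str(exc)

# With `self` removed, the same call works as intended.
Policy.fixed_sort(queue)
```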

Comment on lines +273 to +274
  def _sort_by_priority_and_fcfs(
- waiting_queue: List[Req], schedule_low_priority_values_first: bool
+ self, waiting_queue: List[Req], priority_sign: int

critical

Similar to _sort_by_longest_output, this @staticmethod has self as its first parameter, which is incorrect. The call at line 110 passes 2 arguments, but the method now expects 3, which will lead to a TypeError. Please remove the self parameter.

    def _sort_by_priority_and_fcfs(waiting_queue: List[Req], priority_sign: int

)
sorted_running_reqs = sorted(
self.running_batch.reqs,
key=lambda x: (x.priority * (- self.priority_sign), -x.time_stats.wait_queue_entry_time),

medium

There's an unnecessary space in (- self.priority_sign). It can be simplified to -self.priority_sign for better readability.

            key=lambda x: (x.priority * -self.priority_sign, -x.time_stats.wait_queue_entry_time),

- priority_diff = req.priority - running_req.priority
- if server_args.schedule_low_priority_values_first:
-     priority_diff *= -1
+ priority_diff = (req.priority - running_req.priority) * (- self.priority_sign)

medium

There's an unnecessary space in (- self.priority_sign). It can be simplified to -self.priority_sign for better readability.

            priority_diff = (req.priority - running_req.priority) * -self.priority_sign
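
As a sanity check on this refactor, the branch-based form and the sign-based form from the diff can be compared over a small range of inputs. This is only a sketch: the function names are hypothetical, and it assumes `priority_sign` is +1 when `schedule_low_priority_values_first` is set and -1 otherwise, which is the only convention under which the two quoted forms agree.

```python
from itertools import product


def priority_diff_branch(req_priority: int, running_priority: int, low_first: bool) -> int:
    # Original form from the diff: negate when low values are scheduled first.
    diff = req_priority - running_priority
    if low_first:
        diff *= -1
    return diff


def priority_diff_signed(req_priority: int, running_priority: int, priority_sign: int) -> int:
    # Refactored form from the diff: multiply by the negated sign.
    return (req_priority - running_priority) * -priority_sign


# Collect every input combination where the two forms disagree.
mismatches = [
    (a, b, low_first)
    for a, b, low_first in product(range(-3, 4), range(-3, 4), (False, True))
    if priority_diff_branch(a, b, low_first)
    != priority_diff_signed(a, b, 1 if low_first else -1)
]
```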

for i, running_req in enumerate(self.running_batch.reqs):
if running_req in preemptible_reqs:
self.rem_total_token_offset -= (
self._get_running_request_total_token_offset(req)

medium

This is a great bug fix. Using req instead of running_req was incorrect and would lead to wrong token offset calculations during preemption. Correcting it to running_req ensures the offset is calculated for the actual request being preempted.
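
The effect of the fix can be illustrated with a toy model of the accounting. Everything here is hypothetical: the `Req` dataclass, its `reserved_tokens` field (standing in for whatever `_get_running_request_total_token_offset` computes), and the `release_preempted` helper are assumptions made for the sketch, not the scheduler's actual code.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Req:
    # Hypothetical request; reserved_tokens stands in for the per-request offset.
    rid: str
    reserved_tokens: int


def release_preempted(offset: int, running: List[Req], preemptible: List[Req]) -> int:
    # Subtract the reservation of each running request that gets preempted.
    for running_req in running:
        if running_req in preemptible:
            # Use the preempted request's own reservation, not the incoming one's.
            offset -= running_req.reserved_tokens
    return offset


incoming = Req("new", reserved_tokens=10)
running = [Req("r1", 100), Req("r2", 40)]
total_offset = sum(r.reserved_tokens for r in running)

remaining = release_preempted(total_offset, running, preemptible=[running[1]])

# What the buggy version would have produced by using the incoming request.
buggy_remaining = total_offset - incoming.reserved_tokens
```

In this toy case the correct offset after preempting `r2` is 100, whereas subtracting the incoming request's reservation would leave 130 and overstate the tokens still held by running requests.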

@xiezhq-hermann xiezhq-hermann self-assigned this Nov 19, 2025
@SeanWeiSean
Contributor Author

test_over_preempt_success_low_priority_values_first will fail because it is written against a bug fixed by #12494.

@SeanWeiSean SeanWeiSean changed the title from "Bug Fix and Refactor for Scheduler.py" to "bugfix[schedule]: Accounting Token Issue and refactor calc method" Nov 21, 2025
@harrisonlimh
Collaborator

harrisonlimh commented Nov 28, 2025

LGTM, but I would like to have the rem_total_tokens() refactoring on @hnyls2002's radar!

Would you be able to resolve the merge conflict to run the CI beforehand?

@SeanWeiSean
Contributor Author

Thanks. To be clear, I removed the token calculation refactor and kept only the UT and a simple refactor, to make this PR cleaner and easier to track. I have also merged the latest main.

@SeanWeiSean SeanWeiSean changed the title from "bugfix[schedule]: Accounting Token Issue and refactor calc method" to "bugfix[schedule]: Refactor sort method and add related UT" Dec 3, 2025
@harrisonlimh
Collaborator

LGTM! Thank you!

Comment thread python/sglang/srt/managers/schedule_policy.py Outdated
@SeanWeiSean
Contributor Author

SeanWeiSean commented Dec 6, 2025

I can run the failing CI test (sglang/test/srt/test_eagle_infer_beta.py) on my local H20.
[image]

@harrisonlimh
Collaborator

/tag-and-rerun-ci

@JustinTong0323
Collaborator

/rerun-failed-ci

@SeanWeiSean
Contributor Author

/rerun-failed-ci

@hnyls2002 hnyls2002 merged commit 7759716 into sgl-project:main Dec 22, 2025
189 of 199 checks passed
Liwansi added a commit to iforgetmyname/sglang that referenced this pull request Dec 23, 2025
@SeanWeiSean SeanWeiSean deleted the yuxwei/improveschedularpolicy branch December 24, 2025 06:03
jiaming1130 pushed a commit to zhuyijie88/sglang that referenced this pull request Dec 25, 2025
…t#13576)

Co-authored-by: Yuxuan Wei <w1300012920@pku.edu.cn>
Co-authored-by: Yuxuan Wei <w1300012920@gmail.com>
Co-authored-by: Yuxuan Wei🚚 <yuxwei@microsoft.com>
YChange01 pushed a commit to YChange01/sglang that referenced this pull request Jan 13, 2026
…t#13576)

Co-authored-by: Yuxuan Wei <w1300012920@pku.edu.cn>
Co-authored-by: Yuxuan Wei <w1300012920@gmail.com>
Co-authored-by: Yuxuan Wei🚚 <yuxwei@microsoft.com>

7 participants