Support more models in piecewise CUDA graph #11745

Merged
ispobock merged 5 commits into sgl-project:main from narutolhy:reshape_for_piecewise_cuda_graph
Oct 24, 2025

Conversation

@narutolhy
Contributor

Motivation

This PR refines the execution conditions in the CUDA graph runner and relaxes strict output shape checks to improve compatibility with models that have different tensor layouts.

Specifically, some models (e.g., Gemma-3-1B-IT) produce tensors with the same total number of elements but different shapes, due to layout or internal reshape differences. The previous strict assertion, assert output.shape == ret.shape, would incorrectly fail in such cases.
In addition, CUDA graph execution should be allowed when logprobs are returned for the last token only, but not when input-token logprobs are required.

Modifications

1. piecewise_cuda_graph_runner.py

Updated the can_run() logic:
Previously, CUDA graph capture was skipped whenever forward_batch.return_logprob was True.
Now it is skipped only when input-id logprobs are requested.

for start_len, seq_len in zip(
    forward_batch.extend_logprob_start_lens_cpu,
    forward_batch.extend_seq_lens_cpu
):
    if start_len is not None and start_len < seq_len:
        return True

2. radix_attention.py

Replaced the strict shape check:

assert output.shape == ret.shape

with a relaxed element-count check:

assert output.numel() == ret.numel()
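
As a minimal illustration of why the element-count invariant suffices here (an editor's sketch with made-up shapes, not the actual radix_attention tensors):

import torch

# Same 1024 elements, different layouts (shapes are illustrative only).
output = torch.randn(1, 4, 256)
ret = torch.empty(4, 256)

assert output.numel() == ret.numel()  # the relaxed check from this PR
ret.copy_(output.reshape(ret.shape))  # safe whenever element counts match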

Accuracy Tests

Verified correctness on multiple models including:

Qwen3-0.6B

Gemma-3-1B-IT

Outputs match the baseline within FP16 tolerance.

Logprob return path remains functionally unchanged.
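
A tolerance check of this kind might look like the following sketch (tensor values and tolerances are illustrative; the PR's actual test lives in test/srt/test_piecewise_cuda_graph.py):

import torch

eager_out = torch.randn(4, 8).half()                               # baseline path
graph_out = (eager_out.float() + 1e-3 * torch.randn(4, 8)).half()  # graph path
torch.testing.assert_close(graph_out, eager_out, rtol=1e-2, atol=1e-2)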

@gemini-code-assist
Contributor

Summary of Changes

Hello @narutolhy, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly improves the compatibility and flexibility of the CUDA graph runner by addressing limitations related to tensor shape strictness and logprob handling. By relaxing output shape checks to focus on element count and refining the conditions under which CUDA graphs can be utilized with logprobs, it enables support for a wider array of models, such as Gemma-3-1B-IT, ensuring efficient execution without compromising numerical accuracy.

Highlights

  • Relaxed Output Shape Check: The strict assertion for exact tensor shape matching (output.shape == ret.shape) in radix_attention.py has been replaced with a more flexible check ensuring only the total number of elements (output.numel() == ret.numel()) is identical. This allows models with different tensor layouts but the same data size to function correctly.
  • Refined CUDA Graph can_run() Logic: The piecewise_cuda_graph_runner.py now permits CUDA graph execution even when return_logprob is enabled, provided that logprobs are not requested for input IDs (i.e., only for the last token). This enhances compatibility for models that return logprobs selectively.
  • Explicit return_logprob Handling in replay_prepare: The replay_prepare function in piecewise_cuda_graph_runner.py now explicitly sets return_logprob=False when invoking the model runner's forward pass during graph replay. This ensures consistent behavior with the updated can_run logic for logprob handling within CUDA graphs.

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces two main changes: relaxing the output shape check in radix_attention to support models with varying tensor layouts, and updating the piecewise_cuda_graph_runner to allow CUDA graph execution when only last-token logprobs are requested.

The change in radix_attention.py to check numel() instead of shape is a good improvement for broader model compatibility.

However, I've found a critical issue in piecewise_cuda_graph_runner.py. The new logic in can_run to check for input logprob requests is inverted and also doesn't handle cases where logprobs are not requested at all, which could lead to a TypeError. This would either cause a crash or result in silently dropping requested input logprobs when using CUDA graphs. Please see my detailed comment for a suggested fix.

Comment on lines +257 to +262
for start_len, seq_len in zip(
    forward_batch.extend_logprob_start_lens_cpu,
    forward_batch.extend_seq_lens_cpu,
):
    if start_len is not None and start_len < seq_len:
        return True
Contributor


critical

There are two issues with this new logic:

  1. forward_batch.extend_logprob_start_lens_cpu can be None, which will cause zip to raise a TypeError. You should check if forward_batch.return_logprob is True before iterating, similar to the old logic.
  2. The logic inside the loop is inverted. If input-id logprobs are requested (start_len < seq_len), CUDA graph execution should be skipped, so can_run should return False. The current code returns True, which would lead to incorrect behavior as the requested logprobs would not be computed in the CUDA graph path.
Suggested change

-for start_len, seq_len in zip(
-    forward_batch.extend_logprob_start_lens_cpu,
-    forward_batch.extend_seq_lens_cpu,
-):
-    if start_len is not None and start_len < seq_len:
-        return True
+if forward_batch.return_logprob:
+    for start_len, seq_len in zip(
+        forward_batch.extend_logprob_start_lens_cpu,
+        forward_batch.extend_seq_lens_cpu,
+    ):
+        if start_len is not None and start_len < seq_len:
+            return False
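
For context, here is a self-contained sketch of the suggested control flow (an editor's illustration; the function name and standalone parameters are stand-ins inferred from the snippet above, not the actual sglang signature):

from typing import List, Optional

def logprob_part_can_run(
    return_logprob: bool,
    extend_logprob_start_lens_cpu: Optional[List[Optional[int]]],
    extend_seq_lens_cpu: Optional[List[int]],
) -> bool:
    # Only inspect per-request start lengths when logprobs are requested;
    # otherwise the CPU lists may be None and zip() would raise TypeError.
    if return_logprob:
        for start_len, seq_len in zip(
            extend_logprob_start_lens_cpu, extend_seq_lens_cpu
        ):
            # start_len < seq_len means logprobs are wanted for input
            # tokens, which the piecewise CUDA graph path cannot serve,
            # so the runner must fall back to eager execution.
            if start_len is not None and start_len < seq_len:
                return False
    return True

# Last-token-only logprobs (start_len == seq_len) stay on the graph path;
# an input-token request forces the eager fallback.
assert logprob_part_can_run(True, [8], [8]) is True
assert logprob_part_can_run(True, [3], [8]) is False
assert logprob_part_can_run(False, None, None) is True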

@narutolhy
Contributor Author

Hi @Oasis-Git, thank you for providing the piecewise_cuda_graph feature; it is very helpful. I relaxed the usage conditions so that it applies to more scenarios and models. Please take a look. Thank you.

    seq_lens_sum=forward_batch.seq_lens_sum,
    encoder_lens=forward_batch.encoder_lens,
-   return_logprob=forward_batch.return_logprob,
+   return_logprob=False,
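
(This is the replay_prepare change summarized in the bot's third highlight above: during graph replay, the runner now passes return_logprob=False regardless of the batch flag.)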
Collaborator


For logprobs support, @Oasis-Git is working on that.

Contributor Author


Thanks, I see the plan. I'm only temporarily supporting returning logprobs for non-input tokens. Originally, the can_run function disallowed piecewise CUDA graphs whenever a batch requested logprobs at all. Since it seems only input-token logprobs are actually unsupported, I relaxed the condition a bit to also support returning the logprob of the last token in prefill-only scenarios.
Same as line 223: it is all False here and can be changed to True later. I am only using it temporarily, so I set it to False for now. Thanks.

@ispobock
Collaborator

@narutolhy Thanks for the contribution! We have a Slack channel #piecewise-cuda-graph to discuss this feature; feel free to join if you are interested.

@Oasis-Git
Collaborator

Hi @narutolhy

Thanks for your contribution. Could you please add the following to your PR:

  1. MMLU Test/Benchmark Output
  2. Unit Test for MMLU in file: https://github.com/sgl-project/sglang/blob/main/test/srt/test_piecewise_cuda_graph.py

@narutolhy
Contributor Author

OK, I will add them.

@narutolhy
Contributor Author

Hi @Oasis-Git, I have added the unit test for MMLU in https://github.com/sgl-project/sglang/blob/main/test/srt/test_piecewise_cuda_graph.py.
I also added tests for the Gemma model so that I can cover scenarios where the shapes differ but the total element count is the same.

There is no problem when using the unsloth/gemma-3-1b-it model, but an error occurs when using the unsloth/gemma-3-4b-it model:
AssertionError: Input addresses for cudagraphs are different during replay. Expected [140472240193024, 140472239365632, 140472239366144, 140383149409792, 140472239366656, 140466951553024, 140472239364608, 140472239365120], got [140472240193024, 140472239365632, 140472239366144, 140399219879424, 140472239366656, 140466951553024, 140472239364608, 140472239365120]

The value of the fourth pointer is inconsistent; I don't know if you've encountered this before.
That's why I used unsloth/gemma-3-1b-it for the MMLU test. Since it's a small model, the test thresholds were lowered.
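
For readers unfamiliar with this assertion: CUDA graph replay requires every captured input tensor to stay at the same device address as at capture time, so a single relocated buffer trips it. A minimal editor's sketch of that capture/replay contract (illustrative only; it does not diagnose the Gemma-3-4B issue's root cause):

import torch

static_in = torch.zeros(8, device="cuda")  # capture-time input buffer
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = static_in * 2             # captured computation

static_in.copy_(torch.arange(8.0, device="cuda"))  # reuse the SAME buffer
g.replay()                                         # OK: address unchanged

# Allocating a fresh input tensor and routing it into the captured region,
# instead of copying into static_in, is exactly the kind of address change
# that the "Input addresses for cudagraphs are different" assertion reports.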

Please review it.
Thank you

@Oasis-Git
Collaborator

Oasis-Git commented Oct 20, 2025

I left some comments in the Slack channel.

@ispobock ispobock merged commit 1801cd1 into sgl-project:main Oct 24, 2025
58 of 68 checks passed