[Fix] fix enable_deterministic_inference failed with enable_dp_attention #11023

Open

GuoweiWangU wants to merge 3 commits into sgl-project:main from GuoweiWangU:fix_dp_deterministic_inference

Conversation

GuoweiWangU (Contributor) commented Sep 28, 2025

Motivation

--enable-deterministic-inference failed to produce deterministic results when running MoE models with --enable-dp-attention.

  1. Test results of python3 -m sglang.test.test_deterministic --test-mode mixed

Before this PR:

Prompt 1: total samples: 469, Unique samples: 3
Prompt 2: total samples: 565, Unique samples: 1
Long prompt: total samples: 241, Unique samples: 5

After fix:

Prompt 1: total samples: 481, Unique samples: 1
Prompt 2: total samples: 564, Unique samples: 1
Long prompt: total samples: 230, Unique samples: 1

  2. Test results of python3 -m sglang.test.test_deterministic --test-mode prefix

In the current version, testing prefix mode with enable_dp_attention triggers a 400 Bad Request error due to flush_cache in send_prefix, ultimately leading to a memory-leak error (ValueError: token_to_kv_pool_allocator memory leak detected!). The specific cause is unknown and unrelated to this pull request. The following results were obtained only after commenting out flush_cache.

Before this PR:

Prompt 0 with prefix length 1: total samples: 296, Unique samples: 2
Prompt 1 with prefix length 511: total samples: 297, Unique samples: 2
Prompt 2 with prefix length 2048: total samples: 333, Unique samples: 3
Prompt 3 with prefix length 4097: total samples: 349, Unique samples: 5

After fix:

Prompt 0 with prefix length 1: total samples: 319, Unique samples: 1
Prompt 1 with prefix length 511: total samples: 290, Unique samples: 1
Prompt 2 with prefix length 2048: total samples: 335, Unique samples: 1
Prompt 3 with prefix length 4097: total samples: 331, Unique samples: 1

Modifications

Forced DpPaddingMode to SUM_LEN when launching with --enable-deterministic-inference.
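
For reference, the shape of the change (mirroring the diff hunks quoted in the review threads below) is roughly the following. This is a self-contained sketch, not the actual patch: the enum is trimmed down, _DETERMINISTIC_INFERENCE stands in for a flag derived from the server argument, and the decode-time fallback is only a placeholder.

    # Sketch only, not the actual patch: force SUM_LEN whenever deterministic
    # inference is enabled, instead of only for extend batches.
    from enum import Enum, auto
    from typing import List

    _DETERMINISTIC_INFERENCE = True  # assumed flag derived from --enable-deterministic-inference

    class DpPaddingMode(Enum):
        SUM_LEN = auto()
        MAX_LEN = auto()

        @classmethod
        def get_dp_padding_mode(
            cls, is_extend_in_batch: bool, global_num_tokens: List[int]
        ) -> "DpPaddingMode":
            if _DETERMINISTIC_INFERENCE or is_extend_in_batch:
                # SUM_LEN keeps the gathered token layout identical across runs.
                return cls.SUM_LEN
            # Placeholder decode-time fallback; the real selection logic differs.
            return cls.MAX_LEN

    print(DpPaddingMode.get_dp_padding_mode(False, [4, 4]))  # DpPaddingMode.SUM_LEN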

Checklist

gemini-code-assist (Contributor) commented

Summary of Changes

Hello @GuoweiWangU, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves an issue where enabling deterministic inference failed to produce consistent results when dp_attention was also active, particularly with Mixture-of-Experts (MoE) models. The fix ensures that the DpPaddingMode is consistently set to SUM_LEN when deterministic inference is requested, thereby guaranteeing predictable and reproducible output for these configurations.

Highlights

  • Deterministic Inference Fix: Ensures deterministic results for MoE models when --enable-deterministic-inference is used in conjunction with --enable-dp-attention by explicitly setting DpPaddingMode to SUM_LEN.

gemini-code-assist (Bot) left a comment

Code Review

This pull request addresses an issue where deterministic inference fails when DP attention is enabled. The fix correctly forces DpPaddingMode to SUM_LEN in deterministic mode. The changes are logical and well targeted. I've added a couple of suggestions to refactor the new conditional checks for improved conciseness and maintainability.

Comment thread on python/sglang/srt/layers/dp_attention.py (outdated):

        cls, is_extend_in_batch, global_num_tokens: List[int]
    ) -> DpPaddingMode:
-       if is_extend_in_batch:
+       if _DETERMINISTIC_INFERENCE or is_extend_in_batch:

Collaborator commented:

Is there any reason for fixing the padding mode to SUM_LEN rather than MAX_LEN?

GuoweiWangU (Contributor, Author) replied Sep 29, 2025:

Given that SUM_LEN is already used for extend batches (is_extend_in_batch), we chose SUM_LEN throughout the entire process for consistency. Even if get_dp_padding_mode and get_default_mode_in_cuda_graph are forced to MAX_LEN, the test results are still not deterministic (a small sketch after the numbers below illustrates how the two modes differ):

Prompt 0 with prefix length 1: total samples: 314, Unique samples: 4
Prompt 1 with prefix length 511: total samples: 300, Unique samples: 13
Prompt 2 with prefix length 2048: total samples: 344, Unique samples: 5
Prompt 3 with prefix length 4097: total samples: 317, Unique samples: 25
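
To make the SUM_LEN / MAX_LEN distinction concrete, here is a purely illustrative sketch of how the buffer gathered across DP ranks is sized under each mode; the helper function and the token counts are made up for illustration and are not taken from the SGLang code.

    # Illustrative only: intuition for how the two padding modes size the buffer
    # gathered across DP ranks. Names and numbers are hypothetical, not SGLang code.
    from typing import List

    def gather_buffer_len(global_num_tokens: List[int], mode: str) -> int:
        if mode == "SUM_LEN":
            # Every rank contributes exactly its own token count.
            return sum(global_num_tokens)
        if mode == "MAX_LEN":
            # Every rank is padded up to the largest per-rank token count.
            return max(global_num_tokens) * len(global_num_tokens)
        raise ValueError(mode)

    tokens_per_rank = [3, 7, 5, 4]
    print(gather_buffer_len(tokens_per_rank, "SUM_LEN"))  # 19: depends only on the per-rank counts
    print(gather_buffer_len(tokens_per_rank, "MAX_LEN"))  # 28: padding varies with the largest rank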

Fridge003 (Collaborator) commented Sep 28, 2025

@GuoweiWangU Also, can you provide the results of testing determinism with prefix mode?

python3 -m sglang.test.test_deterministic --test-mode prefix

GuoweiWangU (Contributor, Author) replied:

Results updated.

Comment thread:

    # TODO(kkhuang-amd): noqa, temporary work-around for rocm 7.0.0 alpha
    # it can be safely removed later, once RCCL fixed
-   if _USE_ROCM700A_WA:
+   if _DETERMINISTIC_INFERENCE or _USE_ROCM700A_WA:

Collaborator commented:

SUM_LEN is likely to hurt performance here. Could you double check if the change in this line is needed?

GuoweiWangU (Contributor, Author) replied:

This PR fixes the deterministic-inference problem. The change here ensures that SUM_LEN is used in all computation paths (the existing extend path already uses SUM_LEN), so some performance loss is unavoidable.

Fridge003 (Collaborator) replied Oct 16, 2025:

Yes, we can temporarily apply this strategy, and try to optimize its performance later. @ch-wan
