
Fix cutlass moe accuracy drop caused by attention UB from DP padding mode#10414

Merged
zhyncs merged 5 commits into sgl-project:main from fzyzcjy:feat/hack_pad
Sep 14, 2025
Conversation


@fzyzcjy fzyzcjy commented Sep 14, 2025

Motivation

Bug description

  • cutlass_moe with nvfp4 produces bad accuracy
  • switching to other MoE backends does fix the accuracy (e.g. cutedsl_moe and fp8 are all good)
  • switching attention backends does not fix the accuracy (e.g. trtllm_mla, cutlass_mla, and flashinfer are all bad)

Direct cause

  • the attention core function is handed padded inputs together with non-padded seq lens, which is potentially undefined behavior (UB)
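The mismatch above can be illustrated with a minimal sketch (plain Python with stand-in names and numbers, not the actual sglang data structures):

```python
# Hypothetical illustration of the mismatch described above: the batch is
# padded up to the DP "max" size, but seq_lens still reflects the
# unpadded per-request token counts.
seq_lens = [2, 4]                  # per-request lengths, 6 real tokens total
padded_len = 8                     # batch padded up to the DP max size
hidden_states = [[0.0] * 4 for _ in range(padded_len)]  # [padded_len, dim]

num_real_tokens = sum(seq_lens)                          # 6
num_padded_rows = len(hidden_states) - num_real_tokens   # 2 extra rows

# An attention kernel that trusts seq_lens may still read or write these
# extra rows, so their garbage values can leak into real outputs.
print(num_real_tokens, num_padded_rows)  # 6 2
```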

Root cause

  • the DP padding mode is set to max instead of sum, which introduces padding tokens
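The difference between the two padding modes can be sketched as follows (stand-in numbers, not the actual sglang configuration):

```python
# Hypothetical sketch of the two DP padding modes mentioned above.
local_num_tokens = [3, 5]     # real tokens on each DP rank

# "max" mode: every rank pads its local batch up to the global maximum,
# so ranks with fewer tokens end up carrying real padding tokens.
max_len = max(local_num_tokens)                              # 5
padding_tokens = sum(max_len - n for n in local_num_tokens)  # 2

# "sum" mode: ranks work on the concatenated total instead, so no
# padding tokens are introduced.
total_tokens = sum(local_num_tokens)                         # 8

print(max_len, padding_tokens, total_tokens)  # 5 2 8
```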

Potentially affected

  • anyone using DP attention plus any MoE backend that relies on an all-gather (or all-reduce / reduce-scatter) based mechanism, because that is what makes the max padding mode actually pad tokens

Accuracy test and issue that this PR fixes

Old text:

I will run more experiments when I have time to understand what is going on and provide a better fix; for now this is just a temporary workaround to check whether the accuracy issue is caused by this.

Ongoing work is on branch feat/hack_ac6131.

EDIT: a better workaround is as follows. It seems the attention metadata does not account for the padding, even though we do pass padded inputs to it.

[image: screenshot of the proposed change]

EDIT: maybe this instead; I will cherry-pick it after running end-to-end long benchmarks.

[image: screenshot of the proposed change]

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

(cherry picked from commit 0f6f939)

@gemini-code-assist gemini-code-assist Bot left a comment


Summary of Changes

Hello @fzyzcjy, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a temporary workaround to address an observed accuracy drop, specifically within the DeepseekV2 model. The changes involve tracking the original number of tokens before padding and then explicitly zeroing out the padded sections of intermediate tensors (hidden_states and residual) during the model's forward pass in 'extend' mode. This is a diagnostic or short-term fix to isolate or mitigate potential issues related to padding effects on model accuracy, indicated by the [DO NOT MERGE] tag and hack_ prefixes.

Highlights

  • Temporary Field Addition: A new attribute, hack_num_tokens_before_pad, has been added to the ForwardBatch class in forward_batch_info.py to temporarily store the count of actual tokens before any padding is applied.
  • Padding Length Tracking: The prepare_mlp_sync_batch method now records the original length of input_ids into the hack_num_tokens_before_pad field right before the padding operation.
  • Padded Tensor Zeroing: A helper function hack_setpad was introduced in deepseek_v2.py to explicitly zero out the padded portions of hidden_states and residual tensors. This function is conditionally applied at various stages within the DeepseekV2ForCausalLM's forward method when the forward_mode is set to 'extend'.
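The zeroing helper described in the highlights above can be sketched roughly as follows (a hypothetical reconstruction: the real hack_setpad in deepseek_v2.py operates on torch tensors and its exact signature may differ; here rows are plain Python lists):

```python
from types import SimpleNamespace

def hack_setpad(forward_batch, hidden_states, residual):
    """Zero out the padded tail rows of hidden_states and residual.

    Hypothetical reconstruction of the helper described above; n is the
    number of real tokens recorded before padding was applied.
    """
    n = forward_batch.hack_num_tokens_before_pad
    if n is None:
        return
    for rows in (hidden_states, residual):
        if rows is None:
            continue
        for i in range(n, len(rows)):
            rows[i] = [0.0] * len(rows[i])

# Usage: 6 real tokens padded to 8 rows; the last 2 rows get zeroed.
batch = SimpleNamespace(hack_num_tokens_before_pad=6)
hs = [[1.0, 1.0] for _ in range(8)]
res = [[2.0, 2.0] for _ in range(8)]
hack_setpad(batch, hs, res)
print(hs[5], hs[6], res[7])  # [1.0, 1.0] [0.0, 0.0] [0.0, 0.0]
```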


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a temporary workaround to address an accuracy drop by zeroing out padding in hidden states and residuals for the Deepseek V2 model during the extend phase. While the fix seems correct, the implementation introduces some code quality issues. Specifically, the use of 'hack' in variable and function names should be avoided, and there is significant code duplication in deepseek_v2.py that should be refactored for better maintainability. My review includes suggestions to address these points.

Comment thread python/sglang/srt/models/deepseek_v2.py Outdated
Comment thread python/sglang/srt/model_executor/forward_batch_info.py Outdated
@fzyzcjy fzyzcjy changed the title [DO NOT MERGE] Temporary workaround for accuracy drop [DO NOT MERGE] Fix accuracy drop Sep 14, 2025
@fzyzcjy fzyzcjy changed the title [DO NOT MERGE] Fix accuracy drop Fix accuracy drop Sep 14, 2025
(cherry picked from commit 32c5c5c)
(cherry picked from commit e3b827e)
@fzyzcjy fzyzcjy changed the title Fix accuracy drop Fix accuracy drop caused by attention UB Sep 14, 2025
(cherry picked from commit 975beb7)
@fzyzcjy fzyzcjy changed the title Fix accuracy drop caused by attention UB Fix cutlass moe accuracy drop caused by attention UB Sep 14, 2025
@fzyzcjy fzyzcjy changed the title Fix cutlass moe accuracy drop caused by attention UB Fix cutlass moe accuracy drop caused by attention UB from DP padding mode Sep 14, 2025
@zhyncs zhyncs merged commit 72dfa96 into sgl-project:main Sep 14, 2025
154 of 167 checks passed
@wenscarl wenscarl mentioned this pull request Sep 15, 2025
