Skip to content

[AMD] Fp8 prefill integration with radix cache path for dpsk models#20187

Merged
HaiShaw merged 8 commits intosgl-project:mainfrom
1am9trash:fp8_prefill_integration_with_radix_cache_path
Mar 10, 2026
Merged

[AMD] Fp8 prefill integration with radix cache path for dpsk models#20187
HaiShaw merged 8 commits intosgl-project:mainfrom
1am9trash:fp8_prefill_integration_with_radix_cache_path

Conversation

@1am9trash
Copy link
Copy Markdown
Collaborator

@1am9trash 1am9trash commented Mar 9, 2026

Motivation

Previously, fp8 prefill attention on dpsk models did not cover the radix-cache path.

Modifications

This PR enables fp8 prefill attention for the radix-cache path.
To reduce extra element-wise casts, it also uses fused_gemm_afp4wfp4_split_cat, following the same design principle as the existing fp8 prefill path.

(Not cover fp8 fused gemm yet.)

Accuracy Tests

Tested with radix-cache on/off and fp8/bf16 prefill.
image

Benchmarking and Profiling

Compared fp8 vs. bf16 prefill on dspk-r1-mxfp4 with 70k/200 and radix-cache on.

server cmd:

SGLANG_AITER_FP8_PREFILL_ATTN=1 python3 -m sglang.launch_server --model-path DeepSeek-R1-MXFP4-Preview/ --tensor-parallel-size 8 --trust-remote-code --chunked-prefill-size 131072 --host 0.0.0.0 --port 9000 --log-requests --mem-fraction-static 0.8 --max-running-requests 64 --kv-cache-dtype fp8_e4m3 --attention-backend aiter --speculative-algorithm EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4

client cmd:

python3 -m sglang.bench_serving --host localhost --port 9000 --model DeepSeek-R1-MXFP4-Preview/ --dataset-name random --random-input 70000 --random-output 200 --random-range-ratio 1.0 --num-prompt [cc * 8] --max-concurrency [cc]

Speed on MI355 (cc=1-16):

  • total tok/s: 1.17×-1.26×
  • ttft: 1.07×-1.29×

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request extends the functionality of FP8 prefill attention to include the radix-cache path for dpsk models, addressing a previous limitation. By integrating this capability and optimizing the K/V generation process with a fused GEMM operation, the changes aim to enhance the overall efficiency and performance of the attention mechanism, leading to faster inference speeds.

Highlights

  • FP8 Prefill Attention for Radix Cache: Enabled FP8 prefill attention for the radix-cache path, which was previously unsupported in dpsk models.
  • Optimized K/V Generation: Implemented fused_gemm_afp4wfp4_split_cat to reduce extra element-wise casts during K/V generation, aligning with existing FP8 prefill optimizations.
  • Performance Improvements: Demonstrated significant speedups on MI355, with total tokens/second increasing by 1.17x-1.26x and time-to-first-token by 1.07x-1.29x compared to BF16 prefill.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • python/sglang/srt/layers/attention/aiter_backend.py
    • Added a new method mla_fp8_prefill_attn to encapsulate the logic for FP8 prefill attention, including type casting and kernel calls.
    • Updated the init_forward_metadata method to correctly retrieve and utilize kv_indptr for metadata generation.
    • Modified the call to make_mla_prefill_ps_meta_data to use kv_indptr instead of qo_indptr for accurate metadata setup.
    • Adjusted the calculation of total_s to directly use forward_batch.seq_lens_sum for efficiency.
    • Refactored the forward_extend method to delegate FP8 prefill attention computation to the newly introduced mla_fp8_prefill_attn method.
    • Implemented conditional logic in forward_extend to use fused_gemm_afp4wfp4_split_cat for optimized K/V generation when FP8 prefill is active and weights are in torch.uint8 format.
Activity
  • No specific activity (comments, reviews, or progress updates) has been recorded for this pull request yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request integrates Fp8 prefill with the radix cache path for dpsk models, which is a valuable performance enhancement. The changes are well-structured, primarily involving the addition of a new mla_fp8_prefill_attn function by refactoring existing code and extending it to handle the radix cache path. A new optimization for MXFP4 weights using a fused kernel is also introduced, which is a nice touch. The logic appears sound. I have one minor suggestion to improve code conciseness.

Comment thread python/sglang/srt/layers/attention/aiter_backend.py
@1am9trash 1am9trash changed the title [AMD] Fp8 prefill integration with radix cache path for dpsk models [DO NOT MERGE][AMD] Fp8 prefill integration with radix cache path for dpsk models Mar 9, 2026
@1am9trash 1am9trash changed the title [DO NOT MERGE][AMD] Fp8 prefill integration with radix cache path for dpsk models [AMD] Fp8 prefill integration with radix cache path for dpsk models Mar 9, 2026
@HaiShaw
Copy link
Copy Markdown
Collaborator

HaiShaw commented Mar 10, 2026

/tag-and-rerun-ci

@HaiShaw
Copy link
Copy Markdown
Collaborator

HaiShaw commented Mar 10, 2026

@1am9trash Let's extend fp8 prefill attention to (in follow-up PRs):

  • other MHA models:
  • gfx942

@HaiShaw HaiShaw merged commit 6407891 into sgl-project:main Mar 10, 2026
143 of 164 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants