perf(sgl-kernel): expose get_scheduler_metadata for FA3 decode optimization#21103
Conversation
Register `mha_fwd_get_scheduler_metadata` as a torch op in the `sgl_kernel` namespace. This function precomputes FA3 tile scheduling metadata so that the `prepare_varlen_num_blocks` kernel does not need to run per-layer during decode. The C++ symbol already exists in `flash_ops.so` (compiled from sgl-attn) but was not exposed as a torch op.

Changes:
- `sgl_flash_kernel_ops.h`: declare `mha_fwd_get_scheduler_metadata`
- `flash_extension.cc`: register the `sgl_kernel.get_scheduler_metadata` torch op
- `flash_attn.py`: add a Python `get_scheduler_metadata()` wrapper

This is part 1 of 2: sgl-kernel changes only (backward compatible). Part 2 (sglang Python changes) will use this op when available.
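The per-layer saving described above can be illustrated with a small, self-contained sketch. This is not the real FA3 kernel: `compute_scheduler_metadata` is a hypothetical stand-in for what `mha_fwd_get_scheduler_metadata` precomputes, and all names below are illustrative. The point is the call pattern: metadata is computed once per decode step and reused by every layer, instead of being recomputed inside each layer's attention call.

```python
# Illustrative sketch only (hypothetical names); the real metadata is
# produced by a CUDA kernel inside flash_ops.so, not by Python code.
def compute_scheduler_metadata(seqlens, block_size=128):
    """Hypothetical stand-in: number of KV blocks needed per sequence."""
    return [(s + block_size - 1) // block_size for s in seqlens]

class DecodeStep:
    """Compute scheduling metadata once per decode step, reuse per layer."""

    def __init__(self, num_layers, seqlens):
        self.num_layers = num_layers
        # Computed once per step (the pattern this PR enables) ...
        self.metadata = compute_scheduler_metadata(seqlens)

    def run(self):
        layer_calls = 0
        for _ in range(self.num_layers):
            # ... instead of recomputing inside each layer's attention call.
            _ = self.metadata
            layer_calls += 1
        return layer_calls
```

With 32 layers, the metadata computation runs once instead of 32 times per decode step.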
Code Review
This pull request exposes a new operation `get_scheduler_metadata` to optimize FA3 decoding by precomputing tile scheduling. The changes to the C++ source and header files are well-implemented and consistent with the existing codebase. However, the new Python wrapper for this operation in `sgl-kernel/python/sgl_kernel/flash_attn.py` is incomplete. It omits several parameters that are available in the C++ operation. This could lead to incorrect behavior or limit the utility of the function, especially since its docstring suggests it can be used with functions that rely on these missing parameters. I've provided a suggestion to make the Python wrapper's signature complete.
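One common way to avoid the class of bug flagged here is to have the wrapper forward every parameter explicitly, so a parameter missing from the Python signature fails loudly rather than being silently dropped. The sketch below uses illustrative parameter names, not the actual C++ signature, and `raw_op` stands in for the registered torch op.

```python
def make_wrapper(raw_op):
    """Build a Python wrapper that forwards *every* parameter to raw_op.

    raw_op stands in for torch.ops.sgl_kernel.get_scheduler_metadata;
    the parameter names below are illustrative, not the real signature.
    """
    def get_scheduler_metadata(batch_size, max_seqlen_q, max_seqlen_k,
                               *, num_heads=1, page_size=None, causal=False):
        # Forward all parameters explicitly so none are silently dropped.
        return raw_op(batch_size, max_seqlen_q, max_seqlen_k,
                      num_heads=num_heads, page_size=page_size, causal=causal)
    return get_scheduler_metadata
```

Keyword-only defaults keep the wrapper backward compatible for existing callers while still exposing the full parameter set.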
/tag-and-rerun-ci
Force-pushed cbfb43d to f216b9d
Force-pushed f216b9d to 9781ad5
/rerun-failed-ci retry
Qiaolin-Yu left a comment
LGTM. CI tracked in #20943.
Summary

Register `mha_fwd_get_scheduler_metadata` as a torch op in the `sgl_kernel` namespace. This function precomputes FA3 tile scheduling metadata so that the `prepare_varlen_num_blocks` kernel does not need to run per-layer during decode. The C++ symbol already exists in `flash_ops.so` (compiled from sgl-attn) but was not exposed as a torch op.

Changes (3 files, sgl-kernel only):
- `sgl-kernel/include/sgl_flash_kernel_ops.h`: declare `mha_fwd_get_scheduler_metadata`
- `sgl-kernel/csrc/flash_extension.cc`: register the `sgl_kernel.get_scheduler_metadata` torch op
- `sgl-kernel/python/sgl_kernel/flash_attn.py`: add a Python `get_scheduler_metadata()` wrapper

This is part 1 of 2: sgl-kernel changes only (backward compatible, no behavior change).
Part 2 (sglang Python changes to use this op): #21104
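Because part 1 is backward compatible, the part-2 Python changes can feature-detect the new op at runtime and fall back to the per-layer path on older sgl-kernel builds. A minimal sketch of such a guard (the namespace object and fallback behavior are assumptions, not code from this PR):

```python
def resolve_scheduler_metadata_op(ops_namespace):
    """Return the new op if this sgl-kernel build exposes it, else None.

    ops_namespace stands in for torch.ops.sgl_kernel; a None result lets
    the caller fall back to the existing per-layer metadata path.
    """
    return getattr(ops_namespace, "get_scheduler_metadata", None)
```

This keeps newer sglang Python code working against both old and new kernel builds.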