[export] set enable_gqa in export flash->math decomp #158604
Closed
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158604
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures)
As of commit ebeb0a9 with merge base 82f8e04:
FLAKY - The following job failed but was likely due to flakiness present on trunk.
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Contributor
This pull request was exported from Phabricator. Differential Revision: D78524147
92e4661 to 8e1d5d9 (Compare)
facebook-github-bot pushed a commit that referenced this pull request on Jul 18, 2025
Summary:
For `scaled_dot_product_attention(..., enable_gqa=True)`:

- the Math backend passes the flag through, performing the extra [KV broadcast](https://github.com/pytorch/pytorch/blob/6e07d6a0ff386d99d8c2f1d25978b0683988a4cb/aten/src/ATen/native/transformers/attention.cpp#L902) if set to True
- the Flash backend has no flag, and relies on correct indexing in the C++ kernel
- Export used to default to Math for `enable_gqa=True`, but #157893 landed and enabled Flash. At the same time, there's an export-only [decomp](https://github.com/pytorch/pytorch/blob/6e07d6a0ff386d99d8c2f1d25978b0683988a4cb/torch/_decomp/decompositions.py#L4968) redirecting flash -> math, calling with `enable_gqa` unset, because that info isn't available. This led to https://fb.workplace.com/groups/1028545332188949/posts/1264609398582540 crashing, calling the Math non-GQA variant with GQA inputs.

This change assumes GQA for seqlen mismatches in the export decomp, setting `enable_gqa = <q seqlen> != <kv seqlen>`, relying on prior backend checks to raise on invalid input shapes.

Test Plan: test_export

Rollback Plan:

Differential Revision: D78524147
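For readers less familiar with the GQA handling referenced above, here is a minimal sketch of what the Math path's KV broadcast amounts to conceptually, assuming a `[batch, heads, seq, head_dim]` layout and a recent PyTorch with the `enable_gqa` kwarg. This is an illustration, not the actual ATen implementation linked above:

```python
import torch
import torch.nn.functional as F

def sdpa_math_gqa_sketch(q, k, v, enable_gqa=False):
    # q: [B, Hq, Sq, E]; k, v: [B, Hkv, Skv, E]
    if enable_gqa and q.size(-3) != k.size(-3):
        # GQA: each KV head serves a contiguous group of query heads, so
        # repeat the KV heads until the head counts match before the matmuls.
        groups = q.size(-3) // k.size(-3)
        k = k.repeat_interleave(groups, dim=-3)
        v = v.repeat_interleave(groups, dim=-3)
    attn = torch.softmax(q @ k.transpose(-2, -1) / q.size(-1) ** 0.5, dim=-1)
    return attn @ v

# GQA-shaped inputs: 8 query heads sharing 2 KV heads.
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 2, 16, 64)
v = torch.randn(1, 2, 16, 64)
out = sdpa_math_gqa_sketch(q, k, v, enable_gqa=True)
ref = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)
```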
Contributor
This pull request was exported from Phabricator. Differential Revision: D78524147
angelayi approved these changes on Jul 23, 2025
Summary:
For `scaled_dot_product_attention(..., enable_gqa=True)`:

- the Math backend passes the flag through, performing the extra [KV broadcast](https://github.com/pytorch/pytorch/blob/6e07d6a0ff386d99d8c2f1d25978b0683988a4cb/aten/src/ATen/native/transformers/attention.cpp#L902) if set to True
- the Flash backend has no flag, and relies on correct indexing in the C++ kernel
- Export used to default to Math for `enable_gqa=True`, but #157893 landed and enabled Flash. At the same time, there's an export-only [decomp](https://github.com/pytorch/pytorch/blob/6e07d6a0ff386d99d8c2f1d25978b0683988a4cb/torch/_decomp/decompositions.py#L4968) redirecting flash -> math, calling with `enable_gqa` unset, because that info isn't available. This led to https://fb.workplace.com/groups/1028545332188949/posts/1264609398582540 crashing, calling the Math non-GQA variant with GQA inputs.

This change assumes GQA for seqlen mismatches in the export decomp, setting `enable_gqa = <q seqlen> != <kv seqlen>`, relying on prior backend checks to raise on invalid input shapes.

Test Plan: test_export

Rollback Plan:

Reviewed By: angelayi

Differential Revision: D78524147
8e1d5d9 to ebeb0a9 (Compare)
Contributor
This pull request was exported from Phabricator. Differential Revision: D78524147
drisspg approved these changes on Jul 24, 2025
Contributor
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged)
Collaborator
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
yangw-dev pushed a commit that referenced this pull request on Aug 1, 2025
Differential Revision: D78524147

For `scaled_dot_product_attention(..., enable_gqa=True)`:

- the Math backend passes the flag through, performing the extra [KV broadcast](https://github.com/pytorch/pytorch/blob/6e07d6a0ff386d99d8c2f1d25978b0683988a4cb/aten/src/ATen/native/transformers/attention.cpp#L902) if set to True
- the Flash backend has no flag, and relies on correct indexing in the C++ kernel
- Export used to default to Math for `enable_gqa=True`, but #157893 landed and enabled Flash. At the same time, there's an export-only [decomp](https://github.com/pytorch/pytorch/blob/6e07d6a0ff386d99d8c2f1d25978b0683988a4cb/torch/_decomp/decompositions.py#L4968) redirecting flash -> math, calling with `enable_gqa` unset, because that info isn't available. This led to https://fb.workplace.com/groups/1028545332188949/posts/1264609398582540 crashing, calling the Math non-GQA variant with GQA inputs.

This change assumes GQA for seqlen mismatches in the export decomp, setting `enable_gqa = <q seqlen> != <kv seqlen>`, relying on prior backend checks to raise on invalid input shapes.

Pull Request resolved: #158604
Approved by: https://github.com/angelayi, https://github.com/drisspg
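To make the fix concrete, here is a hypothetical sketch of the flash -> math redirection with the flag inferred from a query/KV shape mismatch. The function name, the `[batch, heads, seq, head_dim]` layout, and the choice to compare head counts are assumptions for illustration; this is not the actual code in `torch/_decomp/decompositions.py`:

```python
import torch
import torch.nn.functional as F

def flash_to_math_decomp_sketch(query, key, value, dropout_p=0.0, is_causal=False):
    # Hypothetical export-only redirection: the flash op carries no enable_gqa
    # flag, so infer it from a query/KV shape mismatch (here: head counts) and
    # forward to the GQA-aware math reference. Earlier backend checks are
    # relied on to have already rejected genuinely invalid shapes.
    enable_gqa = query.size(-3) != key.size(-3)
    return F.scaled_dot_product_attention(
        query, key, value,
        dropout_p=dropout_p,
        is_causal=is_causal,
        enable_gqa=enable_gqa,
    )

# GQA inputs (8 query heads, 2 KV heads) now take the GQA-aware math path
# instead of hitting the non-GQA variant and crashing.
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 2, 16, 64)
v = torch.randn(1, 2, 16, 64)
out = flash_to_math_decomp_sketch(q, k, v)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```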
Differential Revision: D78524147
For `scaled_dot_product_attention(..., enable_gqa=True)`:

- the Math backend passes the flag through, performing the extra KV broadcast if set to True
- the Flash backend has no flag, and relies on correct indexing in the C++ kernel
- Export used to default to Math for `enable_gqa=True`, but [CPU] Support GQA for flash attention #157893 landed and enabled Flash. At the same time, there's an export-only decomp redirecting flash -> math, calling with `enable_gqa` unset, because that info isn't available. This led to https://fb.workplace.com/groups/1028545332188949/posts/1264609398582540 crashing, calling the Math non-GQA variant with GQA inputs.

This change assumes GQA for seqlen mismatches in the export decomp, setting `enable_gqa = <q seqlen> != <kv seqlen>`, relying on prior backend checks to raise on invalid input shapes.
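As a quick illustration of the scenario this change targets, the sketch below exports a module that calls SDPA with `enable_gqa=True` on GQA-shaped inputs (8 query heads sharing 2 KV heads). The module name, shapes, and dimensions are made up for the example, and the exact decomposition captured in the graph depends on the PyTorch version:

```python
import torch
import torch.nn.functional as F

class GQAAttention(torch.nn.Module):
    def forward(self, q, k, v):
        # enable_gqa=True allows the query to have more heads than key/value.
        return F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

# [batch, heads, seq, head_dim]: 8 query heads grouped onto 2 KV heads.
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 2, 16, 64)
v = torch.randn(1, 2, 16, 64)

ep = torch.export.export(GQAAttention(), (q, k, v))
print(ep.graph_module.graph)  # inspect which SDPA variant/decomp was captured
```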