Skip to content

[export] set enable_gqa in export flash->math decomp#158604

Closed
pianpwk wants to merge 1 commit intomainfrom
export-D78524147
Closed

[export] set enable_gqa in export flash->math decomp#158604
pianpwk wants to merge 1 commit intomainfrom
export-D78524147

Conversation

@pianpwk
Copy link
Contributor

@pianpwk pianpwk commented Jul 17, 2025

Differential Revision: D78524147

For scaled_dot_product_attention(..., enable_gqa=True):

This assumes GQA for seqlen mismatches in the export decomp, setting enable_gqa = <q seqlen> != <kv seqlen>, relying on prior backend checks to raise on invalid input shapes.

@pytorch-bot
Copy link

pytorch-bot bot commented Jul 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158604

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit ebeb0a9 with merge base 82f8e04 (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D78524147

facebook-github-bot pushed a commit that referenced this pull request Jul 18, 2025
Summary:

For `scaled_dot_product_attention(..., enable_gqa=True)`:
- the Math backend passes the flag through, performing the extra [KV broadcast](https://github.com/pytorch/pytorch/blob/6e07d6a0ff386d99d8c2f1d25978b0683988a4cb/aten/src/ATen/native/transformers/attention.cpp#L902) if set to True
- the Flash backend has no flag, and relies on correct indexing in the C++ kernel
- Export used to default to Math for `enable_gqa=True`, but #157893 landed and enabled Flash. At the same time, there's an export-only [decomp](https://github.com/pytorch/pytorch/blob/6e07d6a0ff386d99d8c2f1d25978b0683988a4cb/torch/_decomp/decompositions.py#L4968) redirecting flash -> math, calling with `enable_gqa` unset, because that info isn't available. This led to https://fb.workplace.com/groups/1028545332188949/posts/1264609398582540 crashing, calling the Math non-GQA variant, with GQA inputs.

This assumes GQA for seqlen mismatches in the export decomp, setting `enable_gqa = <q seqlen> != <kv seqlen>`, relying on prior backend checks to raise on invalid input shapes.

Test Plan:
test_export

Rollback Plan:

Differential Revision: D78524147
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D78524147

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 23, 2025
Summary:

For `scaled_dot_product_attention(..., enable_gqa=True)`:
- the Math backend passes the flag through, performing the extra [KV broadcast](https://github.com/pytorch/pytorch/blob/6e07d6a0ff386d99d8c2f1d25978b0683988a4cb/aten/src/ATen/native/transformers/attention.cpp#L902) if set to True
- the Flash backend has no flag, and relies on correct indexing in the C++ kernel
- Export used to default to Math for `enable_gqa=True`, but #157893 landed and enabled Flash. At the same time, there's an export-only [decomp](https://github.com/pytorch/pytorch/blob/6e07d6a0ff386d99d8c2f1d25978b0683988a4cb/torch/_decomp/decompositions.py#L4968) redirecting flash -> math, calling with `enable_gqa` unset, because that info isn't available. This led to https://fb.workplace.com/groups/1028545332188949/posts/1264609398582540 crashing, calling the Math non-GQA variant, with GQA inputs.

This assumes GQA for seqlen mismatches in the export decomp, setting `enable_gqa = <q seqlen> != <kv seqlen>`, relying on prior backend checks to raise on invalid input shapes.

Test Plan:
test_export

Rollback Plan:

Reviewed By: angelayi

Differential Revision: D78524147
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D78524147

@facebook-github-bot
Copy link
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

yangw-dev pushed a commit that referenced this pull request Aug 1, 2025
Differential Revision: D78524147

For `scaled_dot_product_attention(..., enable_gqa=True)`:
- the Math backend passes the flag through, performing the extra [KV broadcast](https://github.com/pytorch/pytorch/blob/6e07d6a0ff386d99d8c2f1d25978b0683988a4cb/aten/src/ATen/native/transformers/attention.cpp#L902) if set to True
- the Flash backend has no flag, and relies on correct indexing in the C++ kernel
- Export used to default to Math for `enable_gqa=True`, but #157893 landed and enabled Flash. At the same time, there's an export-only [decomp](https://github.com/pytorch/pytorch/blob/6e07d6a0ff386d99d8c2f1d25978b0683988a4cb/torch/_decomp/decompositions.py#L4968) redirecting flash -> math, calling with `enable_gqa` unset, because that info isn't available. This led to https://fb.workplace.com/groups/1028545332188949/posts/1264609398582540 crashing, calling the Math non-GQA variant, with GQA inputs.

This assumes GQA for seqlen mismatches in the export decomp, setting `enable_gqa = <q seqlen> != <kv seqlen>`, relying on prior backend checks to raise on invalid input shapes.

Pull Request resolved: #158604
Approved by: https://github.com/angelayi, https://github.com/drisspg
@github-actions github-actions bot deleted the export-D78524147 branch August 24, 2025 02:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants