Add torch.compile compatibility to FP8 SDPA using FA3 #172622
howardzhang-cv wants to merge 4 commits into gh/howardzhang-cv/7/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/172622
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 7cb5426 with merge base 32642ba.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
    sdpa_constraint,
    warn=False,
)
make_fallback(
Does the constraint work for the FP8 V layout?
I think there is actually a section in this file on the layout constraints needed for inductor; we should also apply it for this overload.
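For context, a minimal sketch of the registration under discussion, mirroring the existing flash-attention entries in torch/_inductor/lowering.py; the .low_p overload name comes from this PR, and wiring sdpa_constraint to it reflects the review suggestion above rather than a confirmed diff:

```python
# Sketch only: register the new overload as an inductor fallback with the
# same SDPA layout constraint the default overload uses, so inductor
# realizes inputs with the strides the FA3 kernel expects.
make_fallback(
    aten._scaled_dot_product_flash_attention.low_p,  # overload added by this PR
    sdpa_constraint,  # layout constraint suggested in the review above
    warn=False,
)
```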
    )
    # Directly call the internal flash attention operator which has descale support
    result = torch._scaled_dot_product_flash_attention(
    # Use the .low_p OpOverload directly for better torch.compile compatibility
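As a hedged sketch of the call-site change described here, assuming the new overload takes the FP8 query/key/value plus per-tensor descale factors (the full signature is not visible in this excerpt):

```python
# Sketch only: invoke the OpOverload directly instead of a python builtin
# function call, so torch.compile traces a single ATen op with descale
# support. Every argument name after value is an assumption, not the
# merged signature.
result = torch.ops.aten._scaled_dot_product_flash_attention.low_p(
    query_fp8,
    key_fp8,
    value_fp8,
    descale_q,  # hypothetical FP8 descale scalars
    descale_k,
    descale_v,
)
```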
    seed = torch.empty((2,), dtype=torch.uint64, device="meta")
    offset = torch.empty((), dtype=torch.uint64, device="meta")

    return (
How much of this is reusable from the other meta funcs? Can we just call them directly here?
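One hedged answer to the reuse question: delegate the .low_p meta registration to the existing flash-attention meta function instead of duplicating its shape logic; the decorator follows the torch/_meta_registrations.py pattern, and the delegate's name is assumed:

```python
# Sketch only: reuse the default overload's meta function (which already
# produces the meta seed/offset tensors shown in the diff above) for the
# new .low_p overload.
@register_meta(aten._scaled_dot_product_flash_attention.low_p)
def meta__scaled_dot_product_flash_attention_low_p(query, key, value, *args, **kwargs):
    return meta__scaled_dot_product_flash_attention(query, key, value, *args, **kwargs)
```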
    seqused_k: Tensor | None = None,
    alibi_slopes: Tensor | None = None,
):
    print(
Summary:

- Added meta registration for new scaled_dot_product_flash_attention.low_p overload
- Added inductor lowering fallback for new overload
- Directly call op overload in _scaled_dot_product_attention_fp8 instead of python builtin function call

ghstack-source-id: 9605c0f
Pull-Request: pytorch/pytorch#172622
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Summary:

- Added meta registration for new scaled_dot_product_flash_attention.low_p overload
- Added inductor lowering fallback for new overload
- Directly call op overload in _scaled_dot_product_attention_fp8 instead of python builtin function call

ghstack-source-id: baa029c
Pull-Request: pytorch/pytorch#172622
Stack from ghstack (oldest at bottom):
Summary:

- Added meta registration for new scaled_dot_product_flash_attention.low_p overload
- Added inductor lowering fallback for new overload
- Directly call op overload in _scaled_dot_product_attention_fp8 instead of python builtin function call
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo
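For completeness, a minimal sketch of exercising this path end to end under torch.compile; the top-level location of _scaled_dot_product_attention_fp8, the shapes, and the dtype conversion are illustrative assumptions, and FA3 additionally requires a Hopper-class GPU:

```python
import torch

def attn(q, k, v):
    # assumed entry point modified by this PR; the actual module path may differ
    return torch._scaled_dot_product_attention_fp8(q, k, v)

# With the meta registration and inductor fallback from this PR, this should
# compile without graph breaks at the SDPA call.
compiled_attn = torch.compile(attn, fullgraph=True)

q, k, v = (
    torch.randn(2, 8, 1024, 128, device="cuda", dtype=torch.bfloat16).to(
        torch.float8_e4m3fn
    )
    for _ in range(3)
)
out = compiled_attn(q, k, v)
```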