
Add torch.compile compatibility to FP8 SDPA using FA3#172622

Closed
howardzhang-cv wants to merge 4 commits into gh/howardzhang-cv/7/base from gh/howardzhang-cv/7/head

Conversation

@howardzhang-cv
Contributor

howardzhang-cv commented Jan 16, 2026

Stack from ghstack (oldest at bottom):

Summary:

- Added meta registration for the new scaled_dot_product_flash_attention.low_p overload (sketched just below)
- Added an inductor lowering fallback for the new overload
- Directly call the op overload in _scaled_dot_product_attention_fp8 instead of going through the Python builtin function
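
For context, a minimal sketch of what such a meta registration might look like, modeled on the surrounding meta funcs in torch/_meta_registrations.py; the overload's parameter list and output tuple here are assumptions, not the verified schema:

```python
import torch
from torch._meta_registrations import register_meta

aten = torch.ops.aten

# Sketch only: the .low_p overload exists once this PR is applied, and its
# real signature carries descale arguments plus a larger output tuple than
# shown here.
@register_meta(aten._scaled_dot_product_flash_attention.low_p)
def meta_sdpa_flash_low_p(query, key, value, *args, **kwargs):
    # The attention output mirrors the query's metadata.
    out = torch.empty_like(query)
    # Assumed: fp32 logsumexp over (batch, heads, seq_len).
    logsumexp = query.new_empty(query.shape[:3], dtype=torch.float32)
    # RNG state placeholders, as in the snippet quoted in a thread below.
    seed = torch.empty((2,), dtype=torch.uint64, device="meta")
    offset = torch.empty((), dtype=torch.uint64, device="meta")
    return (out, logsumexp, seed, offset)
```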

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo

[ghstack-poisoned]
@pytorch-bot

pytorch-bot Bot commented Jan 16, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/172622

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7cb5426 with merge base 32642ba:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

howardzhang-cv added a commit that referenced this pull request Jan 16, 2026
Summary:
Added meta registration for the new scaled_dot_product_flash_attention.low_p overload
Added an inductor lowering fallback for the new overload
Directly call the op overload in _scaled_dot_product_attention_fp8 instead of going through the Python builtin function


ghstack-source-id: 9605c0f
Pull-Request: #172622
howardzhang-cv added the release notes: nn label Jan 16, 2026
Comment thread torch/_inductor/lowering.py Outdated
    sdpa_constraint,
    warn=False,
)
make_fallback(
Contributor

does the constraint work for the FP8 V layout?

Contributor

I think there is actually a section in this file on the layout constraints needed for inductor that we should also apply for this overload
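
For concreteness, a sketch of the fallback registration pattern under discussion, following the make_fallback call quoted above; whether sdpa_constraint is the right layout constraint for FP8's V tensor is exactly the open question in this thread:

```python
# Sketch of the fallback registration, as it would appear inside
# torch/_inductor/lowering.py where make_fallback, sdpa_constraint, and
# aten are already in scope; the .low_p overload exists only with this
# PR applied.
make_fallback(
    aten._scaled_dot_product_flash_attention.low_p,
    sdpa_constraint,  # reuses the existing SDPA layout constraint;
    warn=False,       # whether it also covers FP8's V layout is the question here
)
```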

Comment thread torch/nn/functional.py Outdated
)
-    # Directly call the internal flash attention operator which has descale support
-    result = torch._scaled_dot_product_flash_attention(
+    # Use the .low_p OpOverload directly for better torch.compile compatibility
Contributor

hmm weird this helps?
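
One plausible reading of why it helps: the Python builtin binding dispatches across overloads at trace time and can miss a freshly added one, while a named OpOverload is a single schema dynamo can record straight into the FX graph. A sketch under that assumption, with the argument list abbreviated:

```python
import torch

# Hypothetical wrapper; the real .low_p signature takes more parameters
# (e.g. descale tensors), abbreviated here for the sketch.
def _fp8_sdpa(query, key, value):
    # Before: torch._scaled_dot_product_flash_attention(...), a Python-level
    # binding that resolves its overload at trace time.
    # After: one explicit OpOverload, recorded directly into the FX graph.
    return torch.ops.aten._scaled_dot_product_flash_attention.low_p(
        query, key, value
    )

compiled = torch.compile(_fp8_sdpa, fullgraph=True)
```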

Comment thread torch/_meta_registrations.py Outdated
seed = torch.empty((2), dtype=torch.uint64, device="meta")
offset = torch.empty((), dtype=torch.uint64, device="meta")

return (
Contributor

how much of this is reusable from the other meta funcs, can we just call them directly here?
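
One way to get the reuse being asked about: on meta tensors, calling the default overload dispatches to its already-registered meta function, so the new registration can delegate instead of duplicating shape logic. A sketch, assuming the shared leading arguments line up:

```python
import torch
from torch._meta_registrations import register_meta

aten = torch.ops.aten

@register_meta(aten._scaled_dot_product_flash_attention.low_p)
def meta_sdpa_flash_low_p(query, key, value, *shared_args, **shared_kwargs):
    # On meta tensors this call reaches the default overload's existing
    # meta kernel, so the shape logic stays in one place. Any .low_p-only
    # arguments (e.g. descales) would need to be peeled off first; that is
    # an assumption of this sketch.
    return aten._scaled_dot_product_flash_attention.default(
        query, key, value, *shared_args, **shared_kwargs
    )
```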

Comment thread torch/_meta_registrations.py Outdated
seqused_k: Tensor | None = None,
alibi_slopes: Tensor | None = None,
):
print(
Contributor

ping ping

SergeyTyshkevich pushed a commit to SergeyTyshkevich/chart2 that referenced this pull request Jan 19, 2026
Summary:
Added meta registration for the new scaled_dot_product_flash_attention.low_p overload
Added an inductor lowering fallback for the new overload
Directly call the op overload in _scaled_dot_product_attention_fp8 instead of going through the Python builtin function


ghstack-source-id: 9605c0f
Pull-Request: pytorch/pytorch#172622
mikaylagawarecki removed their request for review January 20, 2026 15:59
@howardzhang-cv
Contributor Author

@pytorchbot merge

pytorch-bot added the ciflow/trunk label Jan 22, 2026
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

suncapitalllc007-star pushed a commit to suncapitalllc007-star/pytorch that referenced this pull request Jan 25, 2026
Summary:
Added meta registration for the new scaled_dot_product_flash_attention.low_p overload
Added an inductor lowering fallback for the new overload
Directly call the op overload in _scaled_dot_product_attention_fp8 instead of going through the Python builtin function

ghstack-source-id: baa029c
Pull-Request: pytorch/pytorch#172622
github-actions deleted the gh/howardzhang-cv/7/head branch February 23, 2026 02:22