[ONNX] Fix scaled_dot_product_attention with float scale #135594
titaiwangms wants to merge 1 commit into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135594
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures) As of commit f982235 with merge base dfb2b66:
- FLAKY - The following job failed but was likely due to flakiness present on trunk.
- BROKEN TRUNK - The following job failed but was present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
), "is_causal and attn_mask cannot be set at the same time"
assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"

scale = symbolic_helper._maybe_get_const(scale, "f")
```
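For context on what the `scale` argument being fixed here does: a minimal pure-Python sketch (no torch, toy 2-d vectors) of how `scale` enters the attention-weight computation. The helper name `sdpa_weights` and the toy inputs are illustrative, not part of the PR; PyTorch's documented default is `scale = 1/sqrt(head_dim)`, and this PR concerns exporting the case where the caller passes an explicit Python float instead.

```python
import math

def sdpa_weights(q, k, scale=None):
    """Attention weights softmax((q . k_i) * scale) for one query vector.

    Hypothetical helper for illustration only; mirrors the role of the
    float `scale` argument of scaled_dot_product_attention.
    """
    if scale is None:
        scale = 1.0 / math.sqrt(len(q))  # PyTorch's documented default
    logits = [sum(qi * ki for qi, ki in zip(q, key)) * scale for key in k]
    m = max(logits)  # subtract max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

q = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0]]
w_default = sdpa_weights(q, keys)             # scale = 1/sqrt(2)
w_custom = sdpa_weights(q, keys, scale=2.0)   # explicit float scale
```

A larger float scale sharpens the softmax toward the best-matching key, which is why silently dropping a user-supplied scale during export would change the model's numerics.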
Just making sure: does it work for fp16?
The fix seems to work with my FP16 case; I see no warnings either.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@titaiwangms did you want to cherry-pick this as well? It can go under category (2) or (3), I think.

Sure! Looks like people need this.
Fixes #125158. Pull Request resolved: #135594. Approved by: https://github.com/justinchuby (cherry picked from commit e48ee2c)