Add stride check for attn_mask on non-cpu device #158424
CaoE wants to merge 4 commits into pytorch:main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158424
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures) As of commit ebc237b with merge base a5e6881:
BROKEN TRUNK - The following job failed but was also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull Request Overview
This PR adds a stride check for attention mask tensors on non-CPU devices to fix issue #158374. The change ensures that fused attention kernels properly validate that the attention mask has a stride of 1 in the last dimension when running on GPU devices, while allowing more flexibility on CPU.
- Updates the stride validation logic to include attention mask stride checking for non-CPU devices
- Adds comprehensive test coverage for attention masks with non-contiguous strides
- Improves error messaging to include attention mask stride information in debug output
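The device-dependent condition described above can be sketched in plain Python. This is an illustrative mirror of the logic (the helper name is hypothetical; the actual check lives in `sdp_utils_cpp.h`), not the PR's implementation:

```python
def mask_last_dim_stride_ok(mask_strides, is_cpu):
    """Illustrative sketch of the validation: fused attention kernels on
    non-CPU devices require the attention mask to have stride 1 in its
    last dimension, while the CPU path accepts arbitrary strides."""
    mask_stride_equal_1 = mask_strides[-1] == 1
    # On CPU any stride is accepted; elsewhere the last dim must be unit-stride.
    return True if is_cpu else mask_stride_equal_1

# A mask sliced along its last dimension (e.g. mask[..., ::2]) has last-dim
# stride 2, so it is rejected on GPU but still accepted on CPU:
print(mask_last_dim_stride_ok((8, 4, 2), is_cpu=False))  # False
print(mask_last_dim_stride_ok((8, 4, 2), is_cpu=True))   # True
```

The ternary mirrors the `bool mask_stride_check = is_cpu ? true : mask_stride_equal_1;` line quoted in the review comments below.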
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| aten/src/ATen/native/transformers/sdp_utils_cpp.h | Adds device-specific stride validation for attention masks and enhances error messaging |
| test/inductor/test_fused_attention.py | Adds test case for attention mask with non-unit stride in last dimension |
Comments suppressed due to low confidence (2)
aten/src/ATen/native/transformers/sdp_utils_cpp.h:511
- [nitpick] The variable name 'mask_stride_check' is ambiguous. Consider renaming to 'mask_stride_valid' or 'is_mask_stride_compatible' to better indicate it represents a boolean validation result.
bool mask_stride_check = is_cpu ? true : mask_stride_equal_1;
aten/src/ATen/native/transformers/sdp_utils_cpp.h:514
- [nitpick] The variable name 'epilogue_message' is unclear. Consider renaming to 'additional_error_info' or 'mask_error_details' to better describe its purpose of providing additional error message content.
std::ostringstream epilogue_message;
drisspg left a comment:
Can you add a test here (pytorch/test/test_transformers.py, line 1498 in 4805a6e) and ensure that this error is raised?
Added a test for this error message.
test/test_transformers.py (Outdated)

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Efficient Attention was not built for this system")
    @parametrize("kernel", [SDPBackend.EFFICIENT_ATTENTION])
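The failing condition such a test exercises is a mask whose last dimension is not unit-stride. As a torch-free illustration, slicing the last dimension of a contiguous tensor multiplies its last-dim stride, which is exactly what the new check rejects on GPU (helper name is hypothetical, for illustration only):

```python
def strides_after_slice_last_dim(shape, step):
    """Compute element strides of a contiguous (row-major) tensor of `shape`
    after slicing its last dimension with the given step, i.e. t[..., ::step].
    Any step > 1 produces a non-unit last-dim stride."""
    strides = [1] * len(shape)
    # Row-major: each stride is the product of all trailing dimension sizes.
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    strides[-1] *= step  # slicing scales the last-dim stride by the step
    return tuple(strides)

# A (2, 4, 8) mask sliced as mask[..., ::2] keeps its outer strides, but its
# last-dim stride becomes 2, failing the stride-1 requirement on GPU:
print(strides_after_slice_last_dim((2, 4, 8), 2))  # (32, 8, 2)
```

This matches what `torch.ones(2, 4, 8)[..., ::2].stride()` would report, which is the kind of mask the added test constructs to trigger the error.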
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Add stride check for attn_mask on non-cpu device (#158424)

Fixes #158374
Pull Request resolved: #158424
Approved by: https://github.com/Valentine233, https://github.com/drisspg, https://github.com/atalman
Fixes #158374
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben