Skip to content

Add stride check for attn_mask on non-cpu device#158424

Closed
CaoE wants to merge 4 commits intopytorch:mainfrom
CaoE:fix_sdpa
Closed

Add stride check for attn_mask on non-cpu device#158424
CaoE wants to merge 4 commits intopytorch:mainfrom
CaoE:fix_sdpa

Conversation

@CaoE
Copy link
Collaborator

@CaoE CaoE commented Jul 16, 2025

@pytorch-bot
Copy link

pytorch-bot bot commented Jul 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158424

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit ebc237b with merge base a5e6881 (image):

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@CaoE CaoE added ciflow/trunk Trigger trunk jobs on your pull request topic: not user facing topic category labels Jul 16, 2025
@CaoE CaoE requested review from Valentine233 and Copilot July 16, 2025 08:00
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds a stride check for attention mask tensors on non-CPU devices to fix issue #158374. The change ensures that fused attention kernels properly validate that the attention mask has a stride of 1 in the last dimension when running on GPU devices, while allowing more flexibility on CPU.

  • Updates the stride validation logic to include attention mask stride checking for non-CPU devices
  • Adds comprehensive test coverage for attention masks with non-contiguous strides
  • Improves error messaging to include attention mask stride information in debug output

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

File Description
aten/src/ATen/native/transformers/sdp_utils_cpp.h Adds device-specific stride validation for attention masks and enhances error messaging
test/inductor/test_fused_attention.py Adds test case for attention mask with non-unit stride in last dimension
Comments suppressed due to low confidence (2)

aten/src/ATen/native/transformers/sdp_utils_cpp.h:511

  • [nitpick] The variable name 'mask_stride_check' is ambiguous. Consider renaming to 'mask_stride_valid' or 'is_mask_stride_compatible' to better indicate it represents a boolean validation result.
  bool mask_stride_check = is_cpu ? true : mask_stride_equal_1;

aten/src/ATen/native/transformers/sdp_utils_cpp.h:514

  • [nitpick] The variable name 'epilogue_message' is unclear. Consider renaming to 'additional_error_info' or 'mask_error_details' to better describe its purpose of providing additional error message content.
      std::ostringstream epilogue_message;

@CaoE CaoE requested a review from Valentine233 July 16, 2025 08:42
@CaoE CaoE marked this pull request as ready for review July 16, 2025 08:43
Copy link
Contributor

@drisspg drisspg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test here:

class TestSDPAFailureModes(NNTestCase):

and ensure that this error is raised

@CaoE
Copy link
Collaborator Author

CaoE commented Jul 17, 2025

Can you add a test here:

class TestSDPAFailureModes(NNTestCase):

and ensure that this error is raised

Added a test for this error message.


@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Efficient Attention was not built for this system")
@parametrize("kernel", [SDPBackend.EFFICIENT_ATTENTION])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you also add cudnn

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

@CaoE
Copy link
Collaborator Author

CaoE commented Jul 18, 2025

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

atalman pushed a commit that referenced this pull request Jul 18, 2025
tvukovic-amd pushed a commit to ROCm/pytorch that referenced this pull request Aug 20, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

torch 2.8 RC regression - part 1

7 participants