
[ROCm] Sdpa dropout fix #174708

Open
alugorey wants to merge 5 commits into pytorch:main from alugorey:sdpa_dropout_fix

Conversation

@alugorey
Contributor

@alugorey alugorey commented Feb 10, 2026

Fixes a device-side memory access fault when dropout is used with the CK SDPA backend. There was a bug where the philox seed/offset tensors were not being placed on the GPU correctly, resulting in a memory access fault.

This PR is an attempted reland of #154864, which was reverted a couple of times due to internal tests breaking. That happened because the output tensors for mha_fwd were mistakenly changed, causing dynamo to complain. In the newest commit, we changed the output tensors back to what they were originally while preserving the bug fix for the device-side memory assertion. This allowed the dynamo checks to succeed and avoided the original error.

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @jataylo @hongxiayang @naromero77amd @pragupta @jerrymannil @xinyazhang

@pytorch-bot

pytorch-bot bot commented Feb 10, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/174708

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit bdddea3 with merge base e45dfba:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/rocm-mi300 and module: rocm labels Feb 10, 2026
@pytorch-bot

pytorch-bot bot commented Feb 10, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@meta-codesync

meta-codesync bot commented Feb 10, 2026

@drisspg has imported this pull request. If you are a Meta employee, you can view this in D92865015.

@bdhirsh bdhirsh requested a review from jeffdaily February 10, 2026 23:20
@bdhirsh bdhirsh added the triaged label Feb 10, 2026
@alugorey
Contributor Author

@eqy @malfet

Can we get eyes on this? The failures are unrelated. This was the PR that failed some Meta internal tests; @drisspg gave us a reproducer last week and this snapshot of the commit fixed it.

@malfet
Contributor

malfet commented Feb 23, 2026

@alugorey can you provide a bit more description on the PR: what exactly it fixes, and what changes were made after the initial attempt got reverted?

@alugorey
Contributor Author

Hi @malfet, I updated the PR description. Please take a look.

@alugorey
Contributor Author

@malfet one more look? :)

@drisspg
Contributor

drisspg commented Feb 27, 2026

@claude explain to me what this PR is doing and what was broken before. Does the regression test adequately test this?

@claude

claude bot commented Feb 27, 2026

Claude finished @drisspg's task in 1m 17s


Analysis Complete

  • Understand the bug that was fixed
  • Review the code changes
  • Analyze the regression test
  • Provide summary and assessment

What This PR Does

This PR fixes a device-side memory access fault when using dropout with the CK (Composable Kernel) SDPA backend on ROCm/AMD GPUs.

The Bug (Before)

In mha_fwd_ck.hip, the philox seed/offset tensors used for dropout RNG were being created incorrectly:

// OLD (buggy) code - mha_fwd_ck.hip:287-298
auto rng_state = at::empty({2}, opts.dtype(at::kLong));  // Created on CPU!
auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

// Later passed to GPU kernel - MEMORY FAULT!
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);

The issue: opts comes from q.options() without explicitly specifying device, and the at::empty call wasn't forcing CUDA placement. When rng_state_ptr was dereferenced on the GPU, it was actually pointing to CPU memory, causing the memory access fault.

The Fix

The fix explicitly creates rng_state on the GPU:

// NEW (fixed) code - mha_fwd_ck.hip:291-293
auto rng_state_options = at::TensorOptions().dtype(at::kUInt64).device(at::kCUDA);
auto rng_state = at::zeros({2}, rng_state_options.dtype(at::kUInt64));

Additionally:

  1. The backward pass (mha_bwd_ck.hip) now directly reads philox seed/offset from GPU tensors rather than copying to local variables
  2. The return tensors seed_t and offset_t are now explicitly created on CUDA with the correct dtype (kUInt64)

Why Previous Reverts

The PR description mentions this is a reland of #154864 which was reverted due to internal test failures. The original fix accidentally changed the output tensor structure of mha_fwd. This version preserves the original return signature while fixing the memory bug.


Regression Test Assessment

The test test_flash_attention_vs_math_ref_grads was modified to parametrize over sdpa_backend (["aotriton", "ck"]):

Strengths:

  • ✅ Tests both forward and backward passes with dropout (dropout_p parameter includes non-zero values)
  • ✅ Tests the CK backend explicitly via torch.backends.cuda.preferred_rocm_fa_library(sdpa_backend)
  • ✅ Properly handles CK's different dropout mask format (lines 3866-3870)
  • ✅ Tests gradient correctness against a math reference

Concerns:

  • ⚠️ The test only runs if PLATFORM_SUPPORTS_CK_SDPA is true, which requires the CK backend to be available. This means CI without ROCm won't exercise this path.
  • ⚠️ The test doesn't explicitly verify that the dropout values are on the GPU or that no memory fault occurs - it relies on the test passing without crashing.
  • ⚠️ The dropout_p parametrization includes [0.0, 0.5] but only 0.5 exercises the fixed code path.

Is it adequate?

For a functional regression test, yes - if the bug were present, the test would crash with a memory access fault when running with dropout_p > 0 on the CK backend. However, a more targeted unit test that specifically verifies tensor device placement would be more robust for catching this class of bug in the future.
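A minimal sketch of such a targeted device check follows. The helper name is hypothetical (not a PyTorch API); it only assumes the tensors expose `.device.type` as torch.Tensor does, and it uses stand-in objects so the sketch runs without a GPU.

```python
from types import SimpleNamespace

def assert_rng_state_on(seed_t, offset_t, expected="cuda"):
    # Hypothetical helper: assert the philox seed/offset tensors live on the
    # expected device, instead of relying on the absence of a memory fault.
    for name, t in (("philox seed", seed_t), ("philox offset", offset_t)):
        if t.device.type != expected:
            raise AssertionError(
                f"{name} tensor is on {t.device.type}, expected {expected}")

# Stand-ins so the sketch runs anywhere; in the real test these would be the
# seed/offset tensors returned by SDPA with dropout under the CK backend.
gpu_like = SimpleNamespace(device=SimpleNamespace(type="cuda"))
cpu_like = SimpleNamespace(device=SimpleNamespace(type="cpu"))

assert_rng_state_on(gpu_like, gpu_like)      # passes silently
# assert_rng_state_on(cpu_like, gpu_like)    # would raise AssertionError
```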


Summary

Aspect                  Assessment
Bug identified          ✅ Clear: philox RNG tensors not on GPU
Fix correctness         ✅ Correctly forces CUDA device placement
Test coverage           ⚠️ Adequate for regression, but only exercised on ROCm with CK available
Risk of re-regression   Low: the fix is straightforward and explicit

@drisspg
Contributor

drisspg commented Feb 27, 2026

I don't really trust Claude here, since opts should already have the correct device. Can you explain why your code changes fix the issue?

@jeffdaily
Collaborator

I don't really trust Claude here, since opts should already have the correct device. Can you explain why your code changes fix the issue?

@alugorey IIRC it was a Tensor lifetime issue. Previously the seed/offset referenced pointers to Tensors that went out of scope, but this PR guarantees proper lifetime management of seed/offset now.

drisspg
drisspg previously approved these changes Mar 5, 2026
Contributor

@drisspg drisspg left a comment


Okay, let's try again

@alugorey
Contributor Author

alugorey commented Mar 6, 2026

Hi @drisspg, Jeff's assessment was right; it was a tensor lifetime issue. This is ready to land 👍

@jeffdaily
Collaborator

@pytorchbot merge -f "unrelated lint infra, other lint is passing; others are known flaky"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorch-bot

pytorch-bot bot commented Mar 6, 2026

This PR needs a release notes: label


@pytorchmergebot pytorchmergebot added the Reverted and ci-no-td labels Mar 6, 2026
@pytorch-bot pytorch-bot bot dismissed drisspg’s stale review March 6, 2026 21:48

This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.

@pytorch-bot pytorch-bot bot removed the ciflow/rocm-mi300 label Mar 6, 2026

@alugorey
Contributor Author

alugorey commented Mar 9, 2026

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing label Mar 9, 2026
@jithunnair-amd jithunnair-amd added the ciflow/rocm-mi300, ciflow/rocm-mi355, and ciflow/trunk labels Mar 9, 2026
jeffdaily
jeffdaily previously approved these changes Mar 10, 2026
@jeffdaily
Collaborator

@pytorchbot merge -f "resolved the doc errors that caused the prior revert, all other CI is passing"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@liangel-02
Contributor

@pytorchbot revert -m "sorry, need to revert this due to #172246 being reverted" -c ghfirst

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@alugorey your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Mar 10, 2026
This reverts commit 3f60bc4.

Reverted #174708 on behalf of https://github.com/liangel-02 due to "sorry, need to revert this due to #172246 being reverted"
@pytorch-bot pytorch-bot bot dismissed jeffdaily’s stale review March 10, 2026 20:26

This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.


Labels

  • ci-no-td (Do not run TD on this PR)
  • ciflow/rocm-mi300 (Trigger "default" config CI on ROCm MI300)
  • ciflow/rocm-mi355 (Trigger "default" config CI on ROCm MI355 runners)
  • ciflow/trunk (Trigger trunk jobs on your pull request)
  • Merged
  • module: rocm (AMD GPU support for Pytorch)
  • open source
  • Reverted
  • topic: not user facing (topic category)
  • triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
