
[cuDNN][varlen] Fixes for cuDNN varlen SDPA#172108

Open
eqy wants to merge 11 commits into pytorch:main from eqy:wipvarlen

Conversation

@eqy
Collaborator

@eqy eqy commented Jan 9, 2026

Also requires a newer cuDNN backend version, so it wouldn't pass tests yet

@liangel-02 Would it be OK to use the repro you shared as a test here or is that still considered private?

cc @csarofeen @ptrblck @xwang233 @nWEIdia @msaroufim @jerryzh168 @tinglvv

@eqy eqy requested review from drisspg and liangel-02 January 9, 2026 19:19
@eqy eqy added the module: cudnn Related to torch.backends.cudnn, and CuDNN support label Jan 9, 2026
@eqy eqy requested a review from syed-ahmed as a code owner January 9, 2026 19:19
@eqy eqy added the module: cuda Related to torch.cuda, and CUDA support in general label Jan 9, 2026
@eqy eqy requested a review from Aidyn-A as a code owner January 9, 2026 19:19
@eqy eqy added open source release notes: cudnn module: sdpa All things related to torch.nn.functional.scaled_dot_product_attention labels Jan 9, 2026
@pytorch-bot

pytorch-bot bot commented Jan 9, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/172108

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit f21f5be with merge base bfed04b:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@liangel-02
Contributor

@liangel-02 Would it be OK to use the repro you shared as a test here or is that still considered private?

go for it :)

@eqy
Collaborator Author

eqy commented Jan 9, 2026

@pytorchmergebot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased wipvarlen onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout wipvarlen && git pull --rebase)

@bdhirsh bdhirsh added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jan 12, 2026
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
@unittest.skipIf(torch.backends.cudnn.version() is None or torch.backends.cudnn.version() < 91800, "cuDNN 9.18.0.64 is needed for correct support")
@parametrize("dtype", [torch.bfloat16, torch.half])
def test_cudnn_attention_varlen(self, dtype):

can move this test to test/test_varlen_attention.py

from typing import Any, Optional

@torch.library.custom_op("torch_attn::_varlen_attn", mutates_args={})
def _varlen_attn(

i think we can test by setting _should_use_cudnn to True in pytorch/torch/nn/attention/varlen.py and modifying test/test_varlen_attention.py to avoid duplication


sorry to clarify, i meant we can set _should_use_cudnn = True in the test (i wrote the wrong file lol)


_varlen_attn.register_autograd(_backward, setup_context=_setup_context)

VarlenShape = namedtuple(

This named tuple way of doing things is deprecated. Can you use the subclass API instead that preserves typing info
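The class-based replacement being suggested here is `typing.NamedTuple`. A minimal sketch of the difference (the field names below are illustrative, not the actual `VarlenShape` definition):

```python
from collections import namedtuple
from typing import NamedTuple

# Functional API: works at runtime, but carries no per-field type information.
VarlenShapeOld = namedtuple("VarlenShapeOld", ["batch", "max_seqlen", "embed_dim"])

# Subclass API: identical runtime behavior, plus annotations that type
# checkers and IDEs can see.
class VarlenShape(NamedTuple):
    batch: int
    max_seqlen: int
    embed_dim: int

shape = VarlenShape(batch=2, max_seqlen=128, embed_dim=64)
# shape is still a plain tuple: shape == (2, 128, 64)
```

Both spellings produce an immutable tuple subclass, so call sites that index or unpack the result are unaffected by the switch.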

@eqy
Copy link
Collaborator Author

eqy commented Jan 29, 2026

Sure, just FYI we need to revisit this after we upgrade to cuDNN 9.18+, as otherwise these tests are expected to fail

auto rag_k_off = cum_seqlen_kv.mul(k.stride(-3));
auto rag_v_off = cum_seqlen_kv.mul(v.stride(-3));
auto rag_o_off = cum_seqlen_q.mul(o.stride(-3));
auto rag_dq_off = cum_seqlen_q.mul(dQ.stride(-3));
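The offsets above all follow one pattern: in the packed (THD) layout, `cum_seqlen` holds each batch element's starting token index, so multiplying it by the tensor's sequence stride gives per-element start offsets into the ragged buffer. A pure-Python sketch of that arithmetic (the helper name and the numbers are illustrative):

```python
def ragged_offsets(cum_seqlen, seq_stride):
    """Starting offset, in elements, of each batch element in a packed buffer."""
    return [tokens * seq_stride for tokens in cum_seqlen]

# Two sequences of lengths 3 and 5 (cumulative: 0, 3, 8), where each token
# row occupies num_heads * head_dim = 4 * 16 = 64 elements:
offsets = ragged_offsets([0, 3, 8], 64)  # [0, 192, 512]
```

This is why each tensor needs its own offset computation: the multiplier is that tensor's own `stride(-3)`, which the gradient tensors do not necessarily share with their inputs.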

@drisspg
Contributor

drisspg commented Feb 5, 2026

@claude can you do a review of this PR? Also it looks like we pass through strides for all tensors but dO is there a possibility the stride info for dO is wrong or are we guaranteed it will match q based off of construction?

@claude

claude bot commented Feb 5, 2026

Claude finished @drisspg's task in 1m 25s


PR Review: cuDNN varlen SDPA Fixes


Summary

This PR fixes several issues with the cuDNN variable-length (varlen) SDPA backward pass:

  1. Separate ragged offsets for dO and output tensors - Previously, dO was incorrectly using RAG_O_OFF_ instead of its own ragged offset
  2. Correct stride handling for gradient tensors (dQ, dK, dV) - Previously using strides from input tensors (q, k, v) instead of the output gradient tensors
  3. Fix softmaxstats/logsumexp layout - cuDNN expects T,H,1 but torch/FA convention is H,T - now properly handled via unsqueeze/transpose
  4. Enable cuDNN path on Hopper (SM 9.x) and Blackwell (SM 10.x) - _should_use_cudnn now returns True for these architectures
  5. Use explicit kwargs in Python API calls - Improves readability and reduces positional argument errors

Detailed Review

aten/src/ATen/native/cudnn/MHA.cpp

1. New UIDs for separate ragged offsets (lines 420-429)
The addition of RAG_DQ_OFF, RAG_DK_OFF, RAG_DV_OFF, and RAG_DO_OFF is the right approach. This allows the gradient tensors to have their own ragged offset computations independent of the input tensors.

2. dO stride handling (lines 1296-1306)
This is the key fix for @drisspg's question:

auto do_strides = dO.strides();
auto DO_ = mha_graph->tensor(fe::graph::Tensor_attributes()
                                 .set_ragged_offset(RAG_DO_OFF_)  // Changed from RAG_O_OFF_
                                 .set_uid(DO)
                                 .set_name("DO")
                                 .set_dim({b, h_q, s_q, d_v})
                                 .set_stride(
                                     {INT_MAX,
                                      do_strides[strideidx0],
                                      do_strides[strideidx1],
                                      do_strides[strideidx2]}));

To answer your specific question: Yes, this PR now correctly passes through strides for dO. Previously at line 1259 (old code), DO_ was using RAG_O_OFF_ instead of its own ragged offset. The PR adds RAG_DO_OFF_ and also captures do_strides = dO.strides() to set the correct stride information for the dO tensor.

3. Gradient tensor strides (lines 1308-1335)
The change from using input tensor strides (q_strides, k_strides, v_strides) to output gradient tensor strides (dq_strides, dk_strides, dv_strides) is correct. The gradients may have different memory layouts than their corresponding inputs, and cuDNN needs accurate stride information.
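To see why reading strides from the gradients themselves matters: two tensors with the same logical shape can have different strides depending on how their buffers were allocated. A pure-Python stride computation illustrating this (real code would just query `Tensor.strides()`; the shapes are made up):

```python
def contiguous_strides(shape):
    """Row-major strides, in elements, for a contiguous tensor of `shape`."""
    strides, step = [], 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))

# A (T, H, D) = (8, 4, 16) tensor allocated contiguously:
q_strides = contiguous_strides((8, 4, 16))   # (64, 16, 1)
# The same logical shape backed by a buffer laid out (H, T, D) instead:
alt = contiguous_strides((4, 8, 16))         # (128, 16, 1)
dq_strides = (alt[1], alt[0], alt[2])        # (16, 128, 1), not q's strides
```

Passing `q_strides` where `dq_strides` was required would make cuDNN read and write the gradient buffer at the wrong addresses, which is exactly the class of bug this change removes.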

4. softmaxstats layout fix (lines 1524-1532, 1811-1818)
The fix for H,T vs T,H,1 layout is important for compatibility with the Flash Attention convention:

// cuDNN wants T, H, 1, but torch/FA convention is H, T
softmaxstats = at::empty({h_q, q.size(0)}, q.options().dtype(kFloat));
// ...
if (softmaxstats.dim() == 2) {
    softmaxstats_ = softmaxstats.unsqueeze(-1).transpose(0, 1);
}
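The `unsqueeze(-1).transpose(0, 1)` in the snippet is a zero-copy reinterpretation of the (H, T) buffer as (T, H, 1). A list-based sketch of the same index mapping (`h_t_to_t_h_1` is an illustrative helper, not a PyTorch API; the real operation only changes metadata, never copies):

```python
def h_t_to_t_h_1(stats):
    """Map stats[h][t] to out[t][h][0], mirroring unsqueeze(-1).transpose(0, 1)."""
    num_heads, seq_len = len(stats), len(stats[0])
    return [[[stats[h][t]] for h in range(num_heads)] for t in range(seq_len)]

# H = 2 heads, T = 3 tokens:
out = h_t_to_t_h_1([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
# out has logical shape (3, 2, 1); e.g. out[1][0][0] == 2.0, out[1][1][0] == 5.0
```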

aten/src/ATen/native/transformers/cuda/attention_backward.cu

The removal of max_seqlen_batch_q and max_seqlen_batch_k in favor of using the passed-in max_q and max_k parameters is correct. For nested tensors, query.size(2) doesn't represent the max sequence length - it's the head dimension in THD format.

torch/nn/attention/varlen.py

1. Enable cuDNN for Hopper/Blackwell (lines 29-35)

def _should_use_cudnn(device_index: int) -> bool:
    major_cap = torch.cuda.get_device_capability(device_index)[0]
    if major_cap == 9 or major_cap == 10:
        return True
    return False

2. Additional checks before using cuDNN (lines 67-69)

head_dim_cudnn_ok = query.shape[-1] % 8 == 0 and value.shape[-1] % 8 == 0
window_size_cudnn_ok = window_size[0] == -1 and window_size[1] == -1
use_cudnn = window_size_cudnn_ok and query.is_cuda and _should_use_cudnn(query.device.index) and head_dim_cudnn_ok

These checks are important - cuDNN has restrictions on head dimensions (must be divisible by 8) and doesn't support sliding window attention.
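Combining the quoted conditions, the dispatch decision reduces to a small predicate. A standalone sketch mirroring the snippets above (the function name is illustrative, and the capability check is the one from `_should_use_cudnn`):

```python
def cudnn_varlen_eligible(head_dim_q, head_dim_v, window_size, is_cuda, cap_major):
    head_dim_ok = head_dim_q % 8 == 0 and head_dim_v % 8 == 0  # must divide by 8
    no_sliding_window = window_size == (-1, -1)                # no window support
    arch_ok = cap_major in (9, 10)                             # Hopper / Blackwell
    return is_cuda and head_dim_ok and no_sliding_window and arch_ok

cudnn_varlen_eligible(64, 64, (-1, -1), True, 9)   # True
cudnn_varlen_eligible(64, 64, (128, 0), True, 9)   # False: sliding window set
cudnn_varlen_eligible(72, 40, (-1, -1), True, 8)   # False: Ampere (SM 8.x)
```

When the predicate is false, the varlen op falls through to the non-cuDNN backend rather than erroring.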

3. Explicit kwargs in backward call (lines 314-331)
The switch to explicit keyword arguments is good for maintainability.


Potential Issues / Suggestions

  1. Redundant use_cudnn assignment: In _varlen_attn_backward, line 303 assigns use_cudnn and then line 306 immediately reassigns it with the full condition. The first assignment on line 303 is dead code:
use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index)  # Line 303 - dead code
head_dim_cudnn_ok = query.shape[-1] % 8 == 0 and value.shape[-1] % 8 == 0
window_size_cudnn_ok = window_size[0] == -1 and window_size[1] == -1
use_cudnn = window_size_cudnn_ok and query.is_cuda and _should_use_cudnn(query.device.index) and head_dim_cudnn_ok  # Line 306 - overwrites


  2. philox_offset reuse: In the backward Python call, philox_offset=rng_state is passed with a comment "should be unused". While this works since dropout is disabled, it might be cleaner to pass a separate empty tensor or document this more explicitly in the function signature.

  3. Review comment from @Skylion007: The use of NamedTuple for AuxRequest is noted as deprecated. Consider using the class-based approach instead.


Answer to @drisspg's Question

Also it looks like we pass through strides for all tensors but dO is there a possibility the stride info for dO is wrong or are we guaranteed it will match q based off of construction?

This PR fixes exactly this issue. Before this PR, dO was:

  1. Using RAG_O_OFF_ (output's ragged offset) instead of its own RAG_DO_OFF_
  2. The stride information was implicitly derived

After this PR:

  • dO gets its own ragged offset (RAG_DO_OFF_) computed from cum_seqlen_q.mul(dO_.stride(-3))
  • dO explicitly captures and uses do_strides = dO.strides()

The gradient output tensor's strides are not guaranteed to match q's layout. In fact, there's already handling at lines 1821-1824 that may call permute_to_matching_layout(o, dO_) if the innermost stride isn't 1. This ensures the layout is compatible with cuDNN's requirements.


Overall, this is a solid bug fix PR that addresses real issues with stride/offset handling in the cuDNN varlen backward pass. The changes are well-structured and address the core problem systematically.
Branch: wipvarlen-53282b5c

@liangel-02
Contributor

#172108 (comment)

can we still enable the tests in this PR in test_varlen_attention.py?

@eqy
Collaborator Author

eqy commented Feb 5, 2026

@liangel-02 sorry if I am misunderstanding the question, but I thought this PR is enabling the tests in that it will exercise cuDNN. The remaining issue is that we also need #174310 for the required cuDNN backend fix

@liangel-02
Contributor

liangel-02 commented Feb 6, 2026

@eqy oh i meant can we have a test case where we set _should_use_cudnn = True in test_varlen_attention.py? but i can also add that in a separate PR once the bwd is fixed so whatever's easier !

@eqy
Collaborator Author

eqy commented Feb 6, 2026

@liangel-02 Did you mean torch/nn/attention/varlen.py?
I updated that to

@lru_cache(maxsize=8)
def _should_use_cudnn(device_index: int) -> bool:
    """Cache device capability check to avoid repeated CUDA calls."""
    major_cap = torch.cuda.get_device_capability(device_index)[0]
    if major_cap == 9 or major_cap == 10:
        return True
    return False

Or should this be parametrized in the test as well to sweep both options?

@liangel-02
Contributor

@eqy yeah i think it'd be good to parametrize in the test as well so that the tests can run for both backends
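The sweep liangel-02 describes can be structured without duplicating the test body. A minimal `unittest` sketch (the `FakeVarlen` stand-in is purely illustrative so the snippet is self-contained; the real test would patch `_should_use_cudnn` in `torch.nn.attention.varlen`):

```python
import unittest
from unittest import mock

class FakeVarlen:
    """Stand-in for the varlen module, used here so the sketch is runnable."""
    @staticmethod
    def _should_use_cudnn(device_index):
        return False

class VarlenBackendSweep(unittest.TestCase):
    def test_both_backends(self):
        # One test body, run once per backend via subTest.
        for use_cudnn in (False, True):
            with self.subTest(use_cudnn=use_cudnn):
                with mock.patch.object(FakeVarlen, "_should_use_cudnn",
                                       return_value=use_cudnn):
                    # The real test would run varlen attention here and
                    # compare against a reference implementation.
                    self.assertEqual(FakeVarlen._should_use_cudnn(0), use_cudnn)
```

A failure in either backend then reports which `use_cudnn` value triggered it, instead of hiding one path behind the other.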

pytorchmergebot pushed a commit that referenced this pull request Feb 11, 2026
Currently being tested internally, currently looks OK

also needed for #172108

Pull Request resolved: #174310
Approved by: https://github.com/Skylion007, https://github.com/ngimel, https://github.com/malfet
eqy added a commit to eqy/pytorch that referenced this pull request Feb 24, 2026
atalman pushed a commit that referenced this pull request Feb 25, 2026
…175672)

* [WINDOWS][cuDNN] Fix cuDNN version mismatch in Windows (#175547)

Authored with claude code
Previous PRs such as #174310 updated cuDNN versions for Linux builds but neglected to do so for Windows.

Claude wrote all of the lintrunner additions for consistency checking
Pull Request resolved: #175547
Approved by: https://github.com/Skylion007, https://github.com/atalman, https://github.com/malfet

* [cuDNN] Upgrade cuDNN to 9.19 for 12.8 and 13.0 wheels (#174310)

Currently being tested internally, currently looks OK

also needed for #172108

Pull Request resolved: #174310
Approved by: https://github.com/Skylion007, https://github.com/ngimel, https://github.com/malfet
@drisspg
Contributor

drisspg commented Mar 13, 2026

Did we pull in that version yet?

@eqy
Collaborator Author

eqy commented Mar 13, 2026

Not done yet, I'm not sure this is fixed as of 9.20
