[cuDNN][varlen] Fixes for cuDNN varlen SDPA#172108
eqy wants to merge 11 commits into pytorch:main
Conversation
Helpful Links: See artifacts and rendered test results at hud.pytorch.org/pr/172108
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure as of commit f21f5be with merge base bfed04b. This comment was automatically generated by Dr. CI and updates every 15 minutes.
go for it :)
@pytorchmergebot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased 69784d3 to 067383e.
test/test_transformers.py
Outdated
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
@unittest.skipIf(torch.backends.cudnn.version() is None or torch.backends.cudnn.version() < 91800, "cuDNN 9.18.0.64 is needed for correct support")
@parametrize("dtype", [torch.bfloat16, torch.half])
def test_cudnn_attention_varlen(self, dtype):
can move this test to test/test_varlen_attention.py
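As an aside, the 91800 version gate in the snippet above presumably encodes cuDNN 9.18.0 using the cuDNN 9.x scheme MAJOR * 10000 + MINOR * 100 + PATCH. A small sketch of that decoding; the encoding scheme is an assumption inferred from the skip message, not confirmed in this thread:

```python
def decode_cudnn_version(v: int) -> tuple[int, int, int]:
    # Assumed cuDNN 9.x encoding: MAJOR * 10000 + MINOR * 100 + PATCH.
    return v // 10000, (v // 100) % 100, v % 100

# The gate in the skipIf above: 91800 -> (9, 18, 0), i.e. cuDNN 9.18.0.
print(decode_cudnn_version(91800))  # (9, 18, 0)
```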
test/test_transformers.py
Outdated
from typing import Any, Optional

@torch.library.custom_op("torch_attn::_varlen_attn", mutates_args={})
def _varlen_attn(
i think we can test by setting _should_use_cudnn to True in pytorch/torch/nn/attention/varlen.py and modifying test/test_varlen_attention.py to avoid duplication
sorry to clarify, i meant we can set _should_use_cudnn = True in the test (i wrote the wrong file lol)
test/test_transformers.py
Outdated
|
|
_varlen_attn.register_autograd(_backward, setup_context=_setup_context)

VarlenShape = namedtuple(
This named tuple way of doing things is deprecated. Can you use the subclass API instead, which preserves typing info?
Sure. Just FYI, we need to revisit this after we upgrade to cuDNN 9.18+, as otherwise these tests are expected to fail.
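The subclass API the reviewer is referring to is presumably typing.NamedTuple, which keeps per-field type information for checkers while remaining a plain tuple. A minimal sketch; the field names here are illustrative, not necessarily the PR's actual VarlenShape fields:

```python
from typing import NamedTuple

class VarlenShape(NamedTuple):
    # Illustrative fields; the PR's actual VarlenShape may differ.
    batch_size: int
    max_seq_len: int
    num_heads: int
    head_dim: int

shape = VarlenShape(batch_size=2, max_seq_len=128, num_heads=16, head_dim=64)
# Still a plain tuple, but with named, typed fields.
print(shape.head_dim, isinstance(shape, tuple))  # 64 True
```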
auto rag_k_off = cum_seqlen_kv.mul(k.stride(-3));
auto rag_v_off = cum_seqlen_kv.mul(v.stride(-3));
auto rag_o_off = cum_seqlen_q.mul(o.stride(-3));
auto rag_dq_off = cum_seqlen_q.mul(dQ.stride(-3));
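For context, these ragged offsets locate each sequence's first row inside the packed varlen tensor: a cumulative-seqlen prefix sum multiplied by the stride of the sequence dimension (dim -3). A pure-Python sketch of the same arithmetic, with illustrative sizes and a [total_tokens, num_heads, head_dim] layout assumed:

```python
# Packed varlen batch of 3 sequences with lengths 3, 5, 2:
# cumulative seqlens are the exclusive prefix sums.
cum_seqlen_q = [0, 3, 8, 10]

# Stride (in elements) of the sequence dimension for a contiguous
# [total_tokens, num_heads, head_dim] layout; illustrative value.
seq_stride = 16 * 64  # num_heads * head_dim

# Element offset of each sequence's start, mirroring
# rag_q_off = cum_seqlen_q.mul(q.stride(-3)) in the C++ snippet.
rag_q_off = [c * seq_stride for c in cum_seqlen_q]
print(rag_q_off)  # [0, 3072, 8192, 10240]
```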
@claude can you do a review of this PR? Also, it looks like we pass through strides for all tensors but dO; is there a possibility the stride info for dO is wrong, or are we guaranteed it will match q based off of construction?
Claude finished @drisspg's task in 1m 25s. PR Review: cuDNN varlen SDPA Fixes
Summary: This PR fixes several issues with the cuDNN variable-length (varlen) SDPA backward pass.
Detailed Review
can we still enable the tests in this PR in
@liangel-02 sorry if I am misunderstanding the question, but I thought this PR is enabling the tests, in that it will exercise cuDNN. The issue that remains is that we also need #174310 for the required cuDNN backend fix.
@eqy oh i meant can we have a test case where we set _should_use_cudnn = True
@liangel-02 Did you mean torch/nn/attention/varlen.py? Or should this be parametrized in the test as well to sweep both options?
@eqy yeah i think it'd be good to parametrize in the test as well so that the tests can run for both backends |
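Sweeping both dtypes and both backend settings amounts to running the test body over the cross-product of parameters. A backend-free sketch of that sweep using itertools; the _should_use_cudnn flag name follows the thread, and PyTorch's own @parametrize decorator would generate an equivalent grid:

```python
import itertools

# Parameter grid per the discussion: two dtypes x cuDNN on/off.
dtypes = ["bfloat16", "float16"]
use_cudnn_options = [False, True]

cases = list(itertools.product(dtypes, use_cudnn_options))
for dtype, use_cudnn in cases:
    # In the real test this would set _should_use_cudnn = use_cudnn
    # and run the varlen attention check for the given dtype.
    print(f"test_varlen_attention dtype={dtype} cudnn={use_cudnn}")

print(len(cases))  # 4
```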
Currently being tested internally; currently looks OK. Also needed for #172108. Pull Request resolved: #174310. Approved by: https://github.com/Skylion007, https://github.com/ngimel, https://github.com/malfet
…175672)
* [WINDOWS][cuDNN] Fix cuDNN version mismatch in Windows (#175547). Authored with Claude Code. Previous PRs such as #174310 updated cuDNN versions for Linux builds but neglected to do so for Windows. Claude wrote all of the lintrunner additions for consistency checking. Pull Request resolved: #175547. Approved by: https://github.com/Skylion007, https://github.com/atalman, https://github.com/malfet
* [cuDNN] Upgrade cuDNN to 9.19 for 12.8 and 13.0 wheels (#174310). Currently being tested internally; currently looks OK. Also needed for #172108. Pull Request resolved: #174310. Approved by: https://github.com/Skylion007, https://github.com/ngimel, https://github.com/malfet
Did we pull in that version yet?
Not done yet; I'm not sure this is fixed as of 9.20.
Also requires a newer cuDNN backend version, so it wouldn't pass tests yet.
@liangel-02 Would it be OK to use the repro you shared as a test here or is that still considered private?
cc @csarofeen @ptrblck @xwang233 @nWEIdia @msaroufim @jerryzh168 @tinglvv