
[cuDNN][SDPA] cuDNN SDPA refactor/cleanup, nested tensor backward, test priority bump for sm90, sm100#149282

Closed
eqy wants to merge 31 commits into pytorch:main from eqy:cudnnsdparefactor

Conversation

@eqy
Collaborator

@eqy eqy commented Mar 16, 2025

@eqy eqy added open source topic: not user facing topic category module: sdpa All things related to torch.nn.functional.scaled_dot_product_attention labels Mar 16, 2025
@eqy eqy requested a review from syed-ahmed as a code owner March 16, 2025 21:09
@pytorch-bot

pytorch-bot bot commented Mar 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149282

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 1a19605 with merge base d7a855d:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

  • pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, linux.12xlarge, unstable) (gh) (#158876)
    /var/lib/jenkins/workspace/xla/torch_xla/csrc/runtime/BUILD:476:14: Compiling torch_xla/csrc/runtime/xla_util_test.cpp failed: (Exit 1): gcc failed: error executing CppCompile command (from target //torch_xla/csrc/runtime:xla_util_test) /usr/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections ... (remaining 229 arguments skipped)

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@eqy
Collaborator Author

eqy commented Mar 16, 2025

@pytorchmergebot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Tried to rebase and push PR #149282, but it was already up to date. Try rebasing against main by issuing:
@pytorchbot rebase -b main

@eqy eqy changed the title [cuDNN][SDPA] cuDNN SDPA refactor/cleanup [WIP][cuDNN][SDPA] cuDNN SDPA refactor/cleanup Mar 17, 2025
@pytorch pytorch deleted a comment from pytorch-bot bot Mar 17, 2025
@eqy
Collaborator Author

eqy commented Mar 17, 2025

@pytorchmergebot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased cudnnsdparefactor onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout cudnnsdparefactor && git pull --rebase)

@linux-foundation-easycla

linux-foundation-easycla bot commented Mar 17, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

Collaborator

@Skylion007 Skylion007 Mar 18, 2025


The update method for mhagraphcache should probably use perfect forwarding where the update method is defined, instead of taking an lvalue reference, and std::move should be used throughout the file to remove extra copies.

Suggested change:
- mhagraphcache.update(key, mha_graph);
+ mhagraphcache.update(key, std::move(mha_graph));

@eqy eqy force-pushed the cudnnsdparefactor branch from 3848e20 to bd4432a Compare April 7, 2025 23:31
@eqy eqy requested review from albanD and soulitzer as code owners April 15, 2025 00:44
@eqy eqy changed the title [WIP][cuDNN][SDPA] cuDNN SDPA refactor/cleanup [cuDNN][SDPA] cuDNN SDPA refactor/cleanup, nested tensor backward, test priority bump for sm90, sm100 Apr 28, 2025
@eqy eqy requested review from drisspg and jbschlosser April 28, 2025 21:58
@jerryzh168 jerryzh168 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Apr 29, 2025
@eqy eqy force-pushed the cudnnsdparefactor branch from b6a75a5 to 0baac8a Compare April 30, 2025 18:42
@eqy eqy force-pushed the cudnnsdparefactor branch from c43594b to 0d08279 Compare May 8, 2025 17:44
Contributor


Should we guard against consumer GPUs? I guess that's handled in the dispatch.

Contributor


Maybe add a NB: the SDPA API is transposed vs. the cuDNN layout.

Contributor

@drisspg drisspg left a comment


any benchmarks?

@eqy eqy force-pushed the cudnnsdparefactor branch from 74069f9 to 1a19605 Compare August 6, 2025 18:51
@facebook-github-bot
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this in D79744374.

@facebook-github-bot
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: inductor / linux-jammy-cpu-py3.9-gcc11-inductor / build

Details for Dev Infra team Raised by workflow job

@izaitsevfb
Contributor

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 4 checks: pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, linux.12xlarge, unstable), inductor / linux-jammy-cpu-py3.9-gcc11-inductor / build, trunk / linux-jammy-rocm-py3.10 / test (distributed, 1, 1, linux.rocm.gpu.4), Meta Internal-Only Changes Check

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Aug 8, 2025
…st priority bump for `sm90`, `sm100` (pytorch#149282)

cleanup tuple/tensor boilerplate in cuDNN SDPA, preparation for nested/ragged tensor backward

Pull Request resolved: pytorch#149282
Approved by: https://github.com/drisspg

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
pytorchmergebot pushed a commit that referenced this pull request Aug 15, 2025
Opt-in for now, but basically uses the variable-sequence length/ragged path for the common case of BSHD layout to avoid recompiling for different sequence lengths.

Built on top of #149282

Tested using a primitive fuzzer, seems at least as stable as default path (with recompilation) on B200 (50000+ cases tested without any failures)

Pull Request resolved: #155958
Approved by: https://github.com/drisspg
can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
…#155958)

Opt-in for now, but basically uses the variable-sequence length/ragged path for the common case of BSHD layout to avoid recompiling for different sequence lengths.

Built on top of pytorch#149282

Tested using a primitive fuzzer, seems at least as stable as default path (with recompilation) on B200 (50000+ cases tested without any failures)

Pull Request resolved: pytorch#155958
Approved by: https://github.com/drisspg
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
…st priority bump for `sm90`, `sm100` (pytorch#149282)

cleanup tuple/tensor boilerplate in cuDNN SDPA, preparation for nested/ragged tensor backward

Pull Request resolved: pytorch#149282
Approved by: https://github.com/drisspg

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
…#155958)

Opt-in for now, but basically uses the variable-sequence length/ragged path for the common case of BSHD layout to avoid recompiling for different sequence lengths.

Built on top of pytorch#149282

Tested using a primitive fuzzer, seems at least as stable as default path (with recompilation) on B200 (50000+ cases tested without any failures)

Pull Request resolved: pytorch#155958
Approved by: https://github.com/drisspg

Labels

ci-no-td Do not run TD on this PR ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged module: inductor module: sdpa All things related to torch.nn.functional.scaled_dot_product_attention open source Reverted topic: not user facing topic category triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development


9 participants