
[WIP][cuDNN][cuDNN V8 API] Add experimental cuDNN MHA/Flash Attention support #101916

Closed
eqy wants to merge 8 commits into pytorch:main from eqy:cudnnmha3

Conversation

@eqy
Collaborator

@eqy eqy commented May 19, 2023

Initial implementation of the cuDNN Flash Attention forward pass; the current major restrictions are:

  • Only the packed data layout (i.e., Q, K, and V as chunks of the same tensor) is supported
  • Only SM 8.0 and SM 9.0 are supported
  • Only head dimension and sequence length divisible by 64 are supported

Gated by the TORCH_CUDNN_MHA_ENABLED=1 environment variable.
The plan is to eventually drop the pattern matching against strides as cuDNN's support matrix improves.
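
A minimal usage sketch under the restrictions above; the shapes, the packed-QKV construction, and the assumption that the environment variable is read at dispatch time are illustrative rather than taken from this PR:

```python
import os
import torch
import torch.nn.functional as F

# Opt in to the experimental cuDNN attention backend (assumption: the
# variable is checked at dispatch time, so setting it here is sufficient).
os.environ["TORCH_CUDNN_MHA_ENABLED"] = "1"

batch, heads, seq_len, head_dim = 4, 8, 128, 64  # seq_len and head_dim divisible by 64

# Packed layout: Q, K, and V are chunks of one contiguous tensor.
qkv = torch.randn(batch, seq_len, 3, heads, head_dim, device="cuda", dtype=torch.half)
q, k, v = qkv.unbind(dim=2)                       # views into the same storage
q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> (batch, heads, seq_len, head_dim)

# On SM 8.0 / SM 9.0 this call may route to the cuDNN kernel; otherwise
# it falls back to the existing SDPA backends.
out = F.scaled_dot_product_attention(q, k, v)
```

Whether a given stride pattern is matched depends on the checks in this PR, so the chunked layout above is a best guess at what the pattern matcher accepts.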

CC @ngimel @ptrblck

cc @csarofeen @ptrblck @xwang233 @ngimel

@pytorch-bot

pytorch-bot Bot commented May 19, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/101916

Note: Links to docs will display an error until the docs builds have been completed.

❌ 18 New Failures, 2 Pending

As of commit 32a8ef8 with merge base 38e73b3:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@eqy eqy added the module: cudnn and module: cuda labels May 19, 2023
eqy added 8 commits May 23, 2023 01:11
build

wip

debug checkin

debug wip

some caching and philox seed stuff

caching boilerplate

broken wip

fix

refactor a bit check in

cleanup
@eellison
Contributor

Do you have any benchmarks of this compared to any of the other implementations?

@ngimel ngimel requested a review from drisspg June 12, 2023 20:13
@drisspg
Contributor

drisspg commented Jun 30, 2023

@eqy was also curious if you managed to get any benchmarks for this work?

@eqy
Collaborator Author

eqy commented Jun 30, 2023

Sorry, I haven't had a chance to benchmark this since other cuDNN issues are being prioritized at the moment, but the expectation is that it does not offer higher performance for now.

@drisspg
Contributor

drisspg commented Jun 30, 2023

@eqy no worries, I saw that the cuDNN implementation was mentioned here: https://developer.nvidia.com/blog/breaking-mlperf-training-records-with-nvidia-h100-gpus/ and was very curious how it compares.

@github-actions
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

Labels

module: cuda, module: cudnn, open source, Stale, topic: not user facing
