[WIP][cuDNN][cuDNN V8 API] Add experimental cuDNN MHA/Flash Attention support #101916
eqy wants to merge 8 commits into pytorch:main from
Conversation
Do you have any benchmarks of this compared to any of the other implementations?
@eqy I was also curious whether you managed to get any benchmarks for this work?
Sorry, I haven't had a chance to benchmark this because other cuDNN issues are being prioritized at the moment, but it is not expected to offer higher performance for now.
@eqy no worries. I saw that the cuDNN implementation was mentioned here: https://developer.nvidia.com/blog/breaking-mlperf-training-records-with-nvidia-h100-gpus/ and was very curious how it compares.
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as stale.
Initial implementation of forward-pass cuDNN Flash Attention; the current major restrictions are:
- Gated by the `TORCH_CUDNN_MHA_ENABLED=1` environment variable.

The plan is to eventually avoid pattern-matching against strides as the support matrix on the cuDNN side improves.
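A minimal sketch of the env-var gating described above, assuming the backend is only tried when `TORCH_CUDNN_MHA_ENABLED=1` is set; the helper name `use_cudnn_mha` is illustrative and not the PR's actual API:

```python
import os

def use_cudnn_mha() -> bool:
    """Hypothetical gate: try the experimental cuDNN MHA path only when
    TORCH_CUDNN_MHA_ENABLED=1 is set in the environment (default: off)."""
    return os.environ.get("TORCH_CUDNN_MHA_ENABLED", "0") == "1"

# Opt in explicitly, as the PR description requires:
os.environ["TORCH_CUDNN_MHA_ENABLED"] = "1"
print(use_cudnn_mha())  # True
```

A caller would fall back to the existing SDPA implementations whenever this gate (or the stride pattern-match mentioned above) rejects the inputs.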
CC @ngimel @ptrblck
cc @csarofeen @ptrblck @xwang233 @ngimel