[CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference #115663
eqy wants to merge 52 commits into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/115663
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 3 Unrelated Failures as of commit 4fc3337 with merge base 3ab0894.
NEW FAILURE - The following job has failed.
BROKEN TRUNK - The following jobs failed but were present on the merge base.
👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Also feel free to ping me when you think I should do a review
@pytorchmergebot rebase

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.

Successfully rebased fc7cf93 to 7a1e8d9.
@drisspg sorry for the delay, I think this should be ready for review now
} // namespace c10

#define AT_CUDNN_FRONTEND_CHECK(EXPR, ...) \
Is this a result of frontend changing to 9.0?
This is kind of just the "new" way that cuDNN-frontend has been doing error reporting, and it applies to the recent frontend versions that are supported on 8.9.x. It is a bit annoying in that there is now a mechanism for getting an error string, and it is different from the previous convention of returning CUDNN_STATUS_SUCCESS.
aten/src/ATen/native/cudnn/MHA.cpp (outdated)

// .set_dim({b, 1, s_q, s_kv})
// .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1}));
auto seed = mha_graph->tensor(fe::graph::Tensor_attributes()
                .set_name("Seed")
Another dumb Q: are these names meaningful? like do they need to match an api for calling the function?
Not a dumb question! The names are potentially meaningful in the future, in that we would be able to reference I/O tensors by name rather than by holding onto an explicit reference as we currently do to supply the "variant pack" (the map used to specify the data pointers for each I/O tensor upon invocation). In either case they do not really mean anything to cuDNN; they are just for the user's (our) housekeeping.
aten/src/ATen/native/cudnn/MHA.cpp (outdated)

if (cudnnGetVersion() >= 8904) {
  // scaled_dot_product_flash_attention_options.set_alibi_mask(true);
}
Does this only support certain bias types?
I would need to check with the cuDNN folks; this doc https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#flash-fused-multi-head-att-fprop doesn't seem to be very clear about this.
aten/src/ATen/native/cudnn/MHA.cpp (outdated)

cudnnHandle_t handle = getCudnnHandle();
o = at::empty_strided({b, h, s_q, d}, {s_q * h * d, d, h * d, 1}, q.options());
if (return_softmaxstats) {
  // TODO(eqy): fix strides
I assume this stride fix is only needed for backward support, right? So it's okay for this PR?
I think this comment is out-of-date, but I should verify rather than fix the strides here
One other random question: if I wanted to play around with this early, are there any specific build options I should be using? And any specific version of cuDNN I need to have?

I don't think you would need any specific build options---cuDNN >= 8.9 would probably work best. The only other requirement should be cuDNN frontend 1.0, but that tag was updated in the …
@pytorchmergebot rebase

Successfully rebased 153620a to 4fc3337.
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorchbot merge -i

Landed internally as D53716382

Merge started. Your change will be merged while ignoring the following 4 checks: periodic / buck-build-test / buck-build-test (default, 1, 1, ubuntu-latest), periodic / win-vs2019-cuda11.8-py3 / test (default, 2, 4, windows.g5.4xlarge.nvidia.gpu), periodic / linux-focal-rocm5.7-py3.8 / test (distributed, 2, 2, linux.rocm.gpu), periodic / linux-focal-cuda11.8-py3.9-gcc9 / test (multigpu, 1, 1, linux.g5.12xlarge.nvidia.gpu). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
We should probably update our cuDNN binaries, as most of the recent changes have been improving this flash attention kernel.
I would like to ask a question, and I hope it doesn't cause any inconvenience: …

Yes, although mainly because the graph rebuild would incur JIT compilation in addition to some CPU overhead.

Thank you very much for your prompt response! May I ask the reason why cache is not used in …
#113713

Going to clean up some of the checks and will remove draft status after.

Can be tested on SM80+ with TORCH_CUDNN_MHA_ENABLED=1.

CC @drisspg @ptrblck
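For anyone who wants to try it, a minimal invocation sketch of the opt-in above (the script name is hypothetical; requires an SM80+ GPU and a cuDNN-enabled PyTorch build):

```shell
# Opt in to the experimental cuDNN SDPA backend (off by default),
# then run any workload that calls scaled_dot_product_attention.
TORCH_CUDNN_MHA_ENABLED=1 python my_sdpa_script.py
```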