Add FA3 to SDPA #172040
Conversation
✅ No failures as of commit 84e1f19 with merge base 32642ba. (Automated CI status comment.)
Summary:
- Added support for flash attention v3 to SDPA
- Only supports fp8 forward pass for now
- Created new `_fa3.py` mirroring the `_fa4.py` registration
- Modified `torch/nn/attention/__init__.py` and `_registry.py` to add FA3 support
- Added C++ hook to expose FA3 activation (to allow for the fp8 dtype)

Test Plan (mirroring FA4):
Install the FA3 library: https://github.com/Dao-AILab/flash-attention/tree/main

```python
activate_flash_attention_impl("FA3")
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    _ = F.scaled_dot_product_attention(q, k, v)
```

ghstack-source-id: 5dab717
Pull-Request: #172040
Summary:
- Added support for flash attention v3 to SDPA
- Only supports fp8 forward pass for now
- Created new `_fa3.py` mirroring the `_fa4.py` registration
- Modified `torch/nn/attention/__init__.py` and `_registry.py` to add FA3 support
- Added C++ hook to expose FA3 activation (to allow for the fp8 dtype)
- Added new descale_q, descale_k, descale_v parameters to the public API for SDPA

Test Plan (mirroring FA4):
Install the FA3 library: https://github.com/Dao-AILab/flash-attention/tree/main

```python
activate_flash_attention_impl("FA3")
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    _ = F.scaled_dot_product_attention(q, k, v)
```

ghstack-source-id: cc8fd45
Pull-Request: #172040
I added the new descale arguments to the user-facing API for SDPA. I'm not sure this is the best approach, since there are a couple of options. This one adds new optional arguments, but because the function signature changes, the arguments also have to be added to the FA2 and FA4 implementations, as well as the XPU and NestedTensor implementations (where they are all unused). If we do go this route, should I add a user warning? Right now there is only a warning in the SDPA docstring. Note that this approach also required changing AOTInductor's fallback_ops.py, since it adds a default argument to one of the fallback ops for AOT compiling. I'm not sure whether that's something we want to avoid (it looks like this is the first op to do that).
Changed this into a different design using overloaded functions. I removed the descale params from the original public SDPA function and created a new scaled_dot_product_attention_fp8 function which includes them, registered as a `.low_p` overload in native_functions.yaml. I still needed to add the new ops to fallback_ops.py, though. A side effect is that the descale params are now required rather than optional (optional params don't work for overloaded functions, since they would make the two signatures count as the same). I think this is a bit cleaner: any new params we need in the future for low-precision attention can just be added to the overloaded op. Otherwise we would have to add each argument everywhere (FA2, FA4, XPU, NestedTensor, etc.).
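Not PyTorch's real dispatcher, but a toy sketch of why the overload route avoids the optional-argument problem: each named overload (like the `.low_p` one above) gets its own distinct schema, so the fp8 variant can require descale arguments without touching the base signature. All names below are illustrative.

```python
# Toy sketch of named-overload dispatch (illustrative only, not the real
# PyTorch dispatcher). Overloads are keyed by (op, overload) so their
# schemas never collide, even though one takes extra required args.
OPS = {}

def register(name, overload="default"):
    def wrap(fn):
        OPS[(name, overload)] = fn
        return fn
    return wrap

@register("sdpa")
def sdpa_default(q, k, v):
    # base op: no descale parameters in its schema
    return ("default", q, k, v)

@register("sdpa", overload="low_p")
def sdpa_low_p(q, k, v, descale_q, descale_k, descale_v):
    # overload: descale args are *required*, keeping the schema unambiguous
    return ("low_p", q * descale_q, k * descale_k, v * descale_v)

def call(name, *args, overload="default"):
    return OPS[(name, overload)](*args)
```

Calling `call("sdpa", q, k, v)` hits the base op, while `call(..., overload="low_p")` hits the fp8-style overload; making the descale args optional on a single op would instead leave two calls with identical positional signatures.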
Review thread on `def scaled_dot_product_attention_fp8(`:
Talked offline: we should keep this private/unexposed for now while we work on the API, and it can be added to torchao later.
I added the leading "_" and removed it from functional.pyi.in so it doesn't get picked up by wildcard imports and the like. Is that enough, or should we move this into an experimental folder somewhere?
I think we should probably put this into a _prototype-like folder in the attention repo and not in the main nn.functional just yet.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Merge failed. Reason: 1 job failed: trunk / linux-jammy-rocm-py3.10 / test (distributed, 3, 3, linux.rocm.gpu.gfx942.4)
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Summary:
- Added meta registration for the new scaled_dot_product_flash_attention.low_p overload
- Added an inductor lowering fallback for the new overload
- Directly call the op overload in _scaled_dot_product_attention_fp8 instead of a Python builtin function call

Pull Request resolved: #172622
Approved by: https://github.com/drisspg
ghstack dependencies: #172040
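A meta registration lets the compiler trace the op without running the real kernel: it computes only output shapes and dtypes. Below is a torch-free, illustrative sketch of what such a kernel returns for an SDPA-shaped op; the function name and return layout are assumptions, not the actual registration (which uses torch.library machinery).

```python
# Toy sketch of a "meta" kernel: it computes only output metadata
# (shape/dtype) and never touches real data, which is what a meta
# registration must provide so torch.compile can trace the overload.
# Illustrative only; the real version is registered via torch.library.

def sdpa_meta(q_shape, k_shape, v_shape, dtype="float8"):
    batch, heads, q_len, head_dim = q_shape
    # attention output keeps the query's sequence length and value's head dim
    out_shape = (batch, heads, q_len, v_shape[-1])
    # logsumexp is kept per query position for the backward pass
    lse_shape = (batch, heads, q_len)
    return {"out": out_shape, "lse": lse_shape, "dtype": dtype}
```

For the (8, 16, 1024, 128) tensors used elsewhere in this stack, the meta kernel would report an output of the same shape plus a (8, 16, 1024) logsumexp.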
Summary:
- Added support for flash attention v3 to SDPA
- Only supports fp8 forward pass for now
- Created new `_fa3.py` mirroring the `_fa4.py` registration
- Modified `torch/nn/attention/__init__.py` and `_registry.py` to add FA3 support
- Added C++ hook to expose FA3 activation (to allow for the fp8 dtype)
- Added new descale_q, descale_k, descale_v parameters to the public API for SDPA

Test Plan (mirroring FA4):
Install the FA3 library: https://github.com/Dao-AILab/flash-attention/tree/main

```python
activate_flash_attention_impl("FA3")
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    _ = torch.nn.attention._scaled_dot_product_attention_quantized._scaled_dot_product_attention_quantized(q, k, v)
```

ghstack-source-id: 410506e
Pull-Request: pytorch/pytorch#172040
This PR enables KV caching and paging for `varlen_attn`. Users need to have Flash Attention 3 installed and activated via `activate_flash_attn_impl("FA3")`, which was enabled by #172040. Note that KV cache and paging functionality is not supported for the backward pass.
**Summary**
Minimal changes were made to the public Python API other than extending the function signature to take in additional positional args `k_cache`, `v_cache`, `cache_seqlens`, `cache_batch_idx`, `page_table`. The private op `_varlen_attn` still calls `torch.ops.aten._flash_attention_forward`.
KV caching/paging is only supported in FA3, but the PyTorch `_flash_attention_forward` uses FA2, so `activate_flash_attn_impl("FA3")` is used to override the dispatcher to redirect to the FA3 implementation. I extended the function signature of `_flash_attention_forward` but it will throw an error if the code ever reaches that path.
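The override described above is essentially a registry with an activation switch: registered backends share one entry point, and activating FA3 redirects dispatch to it. Below is a minimal, torch-free sketch of that pattern; all names are stand-ins for the real registration in `_registry.py`, not the actual PyTorch code.

```python
# Minimal sketch of an impl registry with an activation switch, in the
# spirit of activate_flash_attn_impl("FA3"). Illustrative only.

_IMPLS = {}
_ACTIVE = "FA2"  # default backend, mirroring _flash_attention_forward

def register_impl(name, fn):
    _IMPLS[name] = fn

def activate_flash_attn_impl(name):
    global _ACTIVE
    if name not in _IMPLS:
        raise ValueError(f"unknown flash attention impl: {name}")
    _ACTIVE = name

def flash_attention_forward(*args, use_kv_cache=False):
    if use_kv_cache and _ACTIVE != "FA3":
        # mirrors the described behavior: KV caching requires FA3,
        # so the FA2 path errors if the code ever reaches it
        raise RuntimeError("KV caching requires FA3")
    return _IMPLS[_ACTIVE](*args)

register_impl("FA2", lambda *a: ("FA2", a))
register_impl("FA3", lambda *a: ("FA3", a))
```

With the default backend active, a KV-cache call raises; after `activate_flash_attn_impl("FA3")`, the same call dispatches to the FA3 stand-in.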
**Testing**
Two tests were added. The first one tests the basic functionality and ensures that an error is thrown if KV caching is used without FA3.
The second test checks correctness across various page sizes and sequence lengths. It calls `varlen_attn` twice, once with the `kv_cache` and once without (`combine_cache_and_new` is a helper function that merges the cached tensors with the given "new" ones so that we can simulate `varlen_attn` being called without a KV cache). The outputs from these two calls should be equal.
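The structure of that correctness test can be sketched with a toy stand-in for attention: run once through the cached path, once on the flattened inputs, and assert equality. `toy_attn` and this `combine_cache_and_new` are hypothetical simplifications, not the real `varlen_attn` or the helper from the test.

```python
# Toy sketch of the cache-vs-no-cache equality test described above.
# "Attention" is replaced with a trivial reduction so only the test
# structure is shown; combine_cache_and_new is a hypothetical stand-in.

def toy_attn(q, kv):
    # stand-in for varlen_attn: each output depends on q and all of kv
    return [x + sum(kv) for x in q]

def toy_attn_with_cache(q, kv_cache, kv_new):
    # cached path: kv is the cache plus the newly appended entries
    return toy_attn(q, kv_cache + kv_new)

def combine_cache_and_new(kv_cache, kv_new):
    # hypothetical helper: flatten cache + new into one kv sequence
    return kv_cache + kv_new

q = [1.0, 2.0]
kv_cache, kv_new = [3.0, 4.0], [5.0]
out_cached = toy_attn_with_cache(q, kv_cache, kv_new)
out_plain = toy_attn(q, combine_cache_and_new(kv_cache, kv_new))
```

Both paths see the same effective KV sequence, so the outputs must match exactly; the real test does the same comparison on `varlen_attn` outputs across page sizes and sequence lengths.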
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo
[ghstack-poisoned]
## Summary
- Added FA4 fp8 implementation using the same FA3 pathway as #172040, via the torch.ops.aten._scaled_dot_product_flash_attention.quantized overload
- The user interaction path is the same, using _scaled_dot_product_attention_quantized.py. This is meant to be used only with the corresponding torchao API.

## Results
Results for this path are in pytorch/ao#3947 and pytorch/ao#3960. tl;dr: the speedup depends on sequence length and on how much of the model's compute attention accounts for, but we get about a **1.4x speed-up** at 128k sequence length in a single attention layer. In Llama 3, we see about a **1.23x speed-up** on the full model runtime at 128k sequence length, with a 0.06 increase in perplexity on the WikiText2 dataset.

## Important Note
This depends on the Dao-AILab/flash-attention#2109 PR in the flash attention library, which needs to land before this path will work.

[ghstack-poisoned]
## Summary
- Added FA4 fp8 implementation using the same FA3 pathway as #172040, via the torch.ops.aten._scaled_dot_product_flash_attention.quantized overload
- The user interaction path is the same, using _scaled_dot_product_attention_quantized.py. This is meant to be used only with the corresponding torchao API.

## Results
Results for this path are in pytorch/ao#3947 and pytorch/ao#3960. tl;dr: the speedup depends on sequence length and on how much of the model's compute attention accounts for, but we get about a **1.33x speed-up** at 128k sequence length in a single attention layer. In Llama 3, we see about a **1.20x speed-up** on the full model runtime at 128k sequence length, with a 0.06 increase in perplexity on the WikiText2 dataset.

[ghstack-poisoned]
Stack from ghstack (oldest at bottom):
Summary:
- Added support for flash attention v3 to SDPA
- Only supports fp8 forward pass for now
- Created new `_fa3.py` mirroring the `_fa4.py` registration
- Modified `torch/nn/attention/__init__.py` and `_registry.py` to add FA3 support
- Added C++ hook to expose FA3 activation (to allow for the fp8 dtype)
Design:
- Created `_fa3.py`, which mirrors the same registration framework as `_fa4.py` (wiring up the new FA3 impl with the aten ops `_flash_attention_forward`, `_flash_attention_backward`, `_scaled_dot_product_flash_attention`, `_scaled_dot_product_flash_attention_backward`)
- Added a hook in `sdp_utils.cpp` and a new function to check for low-precision dtypes in flash attention (since fp8 was not allowed before)
- `restore_flash_attention_impl`
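The dtype check mentioned above gates which inputs the flash backend accepts: fp8 was previously rejected, and with FA3 active it becomes a permitted low-precision dtype. The real check lives in C++ in `sdp_utils.cpp`; below is only a Python sketch of that gating logic, with an assumed function name (the dtype names follow torch's fp8 types).

```python
# Illustrative Python sketch of the dtype gate added for flash attention
# (the actual check is C++ in sdp_utils.cpp). Previously fp8 inputs were
# always rejected; with the FA3 impl activated they are allowed.

STANDARD_DTYPES = {"float16", "bfloat16"}
LOW_PRECISION_DTYPES = {"float8_e4m3fn", "float8_e5m2"}

def can_use_flash_attention(dtype, fa3_active):
    if dtype in STANDARD_DTYPES:
        return True
    # fp8 inputs are only accepted when the FA3 impl is activated
    return fa3_active and dtype in LOW_PRECISION_DTYPES
```

This keeps the default backend behavior unchanged while letting the FA3 activation unlock the fp8 forward path.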
Install the FA3 library from https://github.com/Dao-AILab/flash-attention/tree/main and follow the directions for the "FlashAttention-3 beta release"; tl;dr: clone the repo, cd into hopper (where the FA3 implementation lives), and run python setup.py install.
Following the same steps as FA4 to keep it consistent:
It also works with torch.compile
Some quick runtime results:
I ran some very quick tests with the pytorch profiler to test the runtime of fp8 using FA3 versus bf16 using FA4. Mostly just to catch any easy mistakes slowing it down. These are the results when running on (8, 16, 1024, 128) shape tensors:
FA3 eager:
sdpa_profiler_fa3_eager.txt
FA3 compile:
sdpa_profiler_fa3_compile.txt
FA4 eager:
sdpa_profiler_fa4_eager.txt
FA4 compile:
sdpa_profiler_fa4_compile.txt
Some quick accuracy results:
I ran a very quick SQNR test on (8, 16, 1024, 128) shape tensors 100 times and got an average SQNR of 25.55 dB.
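SQNR here measures the fp8 output against a higher-precision reference: signal power over quantization-noise power, in dB. A self-contained sketch of the metric (the actual tensors and backends from the experiment are not reproduced):

```python
import math

def sqnr_db(reference, quantized):
    # signal-to-quantization-noise ratio in dB:
    #   10 * log10( sum(x^2) / sum((x - x_q)^2) )
    signal = sum(x * x for x in reference)
    noise = sum((x - q) * (x - q) for x, q in zip(reference, quantized))
    return 10.0 * math.log10(signal / noise)
```

An average around 25.5 dB means the quantization-error power is roughly 350x smaller than the signal power, which is in the expected range for fp8 attention outputs.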
This is still a draft, so I'll do a quick memory-usage test next. Then I'll run some tests on the diffusers library's stable diffusion models.