
Add FA3 to SDPA #172040

Closed
howardzhang-cv wants to merge 12 commits into gh/howardzhang-cv/6/base from gh/howardzhang-cv/6/head

Conversation

howardzhang-cv (Contributor) commented Jan 9, 2026

Stack from ghstack (oldest at bottom):

Summary:

  • Added support for Flash Attention v3 (FA3) to SDPA
  • Only the fp8 forward pass is supported for now
  • Created a new _fa3.py mirroring the _fa4.py registration
  • Modified torch/nn/attention/__init__.py and _registry.py to add FA3 support
  • Added a C++ hook to expose FA3 activation (to allow the fp8 dtype)

Design:

  • The design follows the same basic structure as the FA4 implementation.
  • I added a _fa3.py that mirrors the registration framework of _fa4.py, wiring the new FA3 impl up to the aten ops _flash_attention_forward, _flash_attention_backward, _scaled_dot_product_flash_attention, and _scaled_dot_product_flash_attention_backward.
  • I also added a C++ hook to expose the FA3 activation to sdp_utils.cpp, plus a new function to check for low-precision dtypes in flash attention (fp8 was not previously allowed).
  • Note that only the fp8 forward pass is supported for now (a short usage sketch follows this list). We can add fp16/bf16 support later, but fp16/bf16 performance is better with FA4 anyway. The backward pass currently throws an error: "FA3 does not support backward pass. Either: 1. Use torch.no_grad() for inference. 2. Unregister FA3 before training: restore_flash_attention_impl"
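
A minimal usage sketch of the fp8 forward-only path (a sketch, not the PR's exact test code: the activate_flash_attention_impl import path and the float8_e4m3fn dtype are assumptions based on the FA4 flow and the fp8-only support described above):

import torch
import torch.nn.functional as F
# Import path assumed to mirror the FA4 registration flow (torch/nn/attention/__init__.py).
from torch.nn.attention import SDPBackend, activate_flash_attention_impl, sdpa_kernel

activate_flash_attention_impl("FA3")

# fp8 inputs on a Hopper GPU; shapes match the profiling runs below, the e4m3 dtype is an assumption.
q, k, v = (
    torch.randn(8, 16, 1024, 128, device="cuda").to(torch.float8_e4m3fn) for _ in range(3)
)

# Forward only: run under no_grad, since the FA3 path raises the error quoted above on backward.
with torch.no_grad(), sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)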

Test Plan:

Install the FA3 library from https://github.com/Dao-AILab/flash-attention/tree/main, following the directions for the "FlashAttention-3 beta release" (tl;dr: clone the repo, cd into hopper (where the FA3 implementation lives), and run python setup.py install).

Following the same steps as FA4 to keep it consistent:

activate_flash_attention_impl("FA3")
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    _ = F.scaled_dot_product_attention(q, k, v)

It also works with torch.compile:

activate_flash_attention_impl("FA3")

def sdpa_fn(q, k, v):
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        return F.scaled_dot_product_attention(q, k, v)

compiled_fn = torch.compile(sdpa_fn, fullgraph=True)
_ = compiled_fn(q, k, v)

Some quick runtime results:

I ran some quick tests with the PyTorch profiler to compare the runtime of fp8 using FA3 versus bf16 using FA4, mostly to catch any obvious mistakes slowing things down (a rough profiling harness is sketched after the result files below). These are the results when running on (8, 16, 1024, 128) shaped tensors:
FA3 eager:
sdpa_profiler_fa3_eager.txt
FA3 compile:
sdpa_profiler_fa3_compile.txt
FA4 eager:
sdpa_profiler_fa4_eager.txt
FA4 compile:
sdpa_profiler_fa4_compile.txt
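
For reference, a rough sketch of the profiling harness for numbers like these (not the exact script used; swap the input dtype to fp8 with FA3 activated, or bf16 with FA4, per the runs above):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.profiler import ProfilerActivity, profile

q, k, v = (
    torch.randn(8, 16, 1024, 128, device="cuda", dtype=torch.bfloat16) for _ in range(3)
)

def run():
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        return F.scaled_dot_product_attention(q, k, v)

# Warm up before profiling so compile/initialization cost isn't counted.
for _ in range(10):
    run()
torch.cuda.synchronize()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(100):
        run()
    torch.cuda.synchronize()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))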

Some quick accuracy results:

I ran a very quick SQNR test on (8, 16, 1024, 128) shaped tensors, repeated 100 times, and got an average SQNR of 25.55 (the metric is sketched below).
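
For reference, a minimal sketch of the SQNR metric (assuming the usual 10 * log10(signal power / error power) definition in dB, comparing the fp8 FA3 output against a higher-precision SDPA reference; not the exact script used):

import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> float:
    # SQNR in dB: 10 * log10(||ref||^2 / ||ref - test||^2)
    ref = ref.float()
    err = ref - test.float()
    return (10 * torch.log10(ref.pow(2).sum() / err.pow(2).sum())).item()

# e.g. average sqnr_db(out_reference, out_fp8_fa3) over 100 random (8, 16, 1024, 128) inputs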

This is still a draft, so I'll do a quick memory test next. Then I'll run some tests on the Stable Diffusion models in the diffusers library.

[ghstack-poisoned]
pytorch-bot (Bot) commented Jan 9, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/172040

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 84e1f19 with merge base 32642ba:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

howardzhang-cv added a commit that referenced this pull request Jan 9, 2026
Summary: Added support for flash attention v3 to SDPA
Only supports fp8 forward pass for now
Created new _fa3.py mirroring _fa4.py registration
modified torch/nn/attention/__init__.py and _registry.py to add FA3 support
Added C++ hook to expose FA3 activation (to allow for fp8 dtype)

Test Plan:
Mirroring fa4:
Install FA3 library: https://github.com/Dao-AILab/flash-attention/tree/main
activate_flash_attention_impl("FA3")
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        _ = F.scaled_dot_product_attention(q, k, v)


ghstack-source-id: 5dab717
Pull-Request: #172040
howardzhang-cv marked this pull request as draft January 9, 2026 00:14
[ghstack-poisoned]
howardzhang-cv added a commit that referenced this pull request Jan 9, 2026
Summary: Added support for flash attention v3 to SDPA
Only supports fp8 forward pass for now
Created new _fa3.py mirroring _fa4.py registration
modified torch/nn/attention/__init__.py and _registry.py to add FA3 support
Added C++ hook to expose FA3 activation (to allow for fp8 dtype)

Test Plan:
Mirroring fa4:
Install FA3 library: https://github.com/Dao-AILab/flash-attention/tree/main
activate_flash_attention_impl("FA3")
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        _ = F.scaled_dot_product_attention(q, k, v)

ghstack-source-id: ac3abec
Pull-Request: #172040
howardzhang-cv added the release notes: nn label Jan 9, 2026
[ghstack-poisoned]
howardzhang-cv added a commit that referenced this pull request Jan 13, 2026
Summary: Added support for flash attention v3 to SDPA
Only supports fp8 forward pass for now
Created new _fa3.py mirroring _fa4.py registration
modified torch/nn/attention/__init__.py and _registry.py to add FA3 support
Added C++ hook to expose FA3 activation (to allow for fp8 dtype)
Added new descale_q, descale_k, descale_v parameters to public API for SDPA

Test Plan:
Mirroring fa4:
Install FA3 library: https://github.com/Dao-AILab/flash-attention/tree/main
activate_flash_attention_impl("FA3")
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        _ = F.scaled_dot_product_attention(q, k, v)

ghstack-source-id: cc8fd45
Pull-Request: #172040
howardzhang-cv (Contributor, Author) commented:

I added the new descale arguments to the user-facing API for SDPA. I'm not sure this is the best approach, though, since there are a couple of options. This one adds new optional arguments, but because the function signature changes, I also need to add these arguments to the FA2 and FA4 implementations as well as the XPU and NestedTensor implementations (where they are all unused). If we do want to go this route, should I add a user warning? Right now there is only a warning in the SDPA docstring.

Note that this method also required changing AOTInductor's fallback_ops.py, since it adds a default argument to one of the fallback ops for AOT compiling. I'm not sure whether that's something we want to avoid (it looks like this would be the first op to do that). A generic sketch of what the descale factors represent is below.
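
For context, a generic illustration of per-tensor fp8 quantization and the descale factor that descale_q/descale_k/descale_v represent (the standard recipe, not necessarily the exact one used here; the helper name and the e4m3 dtype are assumptions):

import torch

def to_fp8_with_descale(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fn):
    # Per-tensor scale so values fill the fp8 range; descale undoes it inside the kernel.
    fp8_max = torch.finfo(fp8_dtype).max
    scale = fp8_max / x.abs().amax().clamp(min=1e-12)
    x_fp8 = (x * scale).to(fp8_dtype)
    descale = (1.0 / scale).reshape(1).float()  # passed to the kernel as descale_q/k/v
    return x_fp8, descale

q = torch.randn(8, 16, 1024, 128, device="cuda", dtype=torch.bfloat16)
q_fp8, descale_q = to_fp8_with_descale(q)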

[ghstack-poisoned]
howardzhang-cv added a commit that referenced this pull request Jan 14, 2026
Summary: Added support for flash attention v3 to SDPA
Only supports fp8 forward pass for now
Created new _fa3.py mirroring _fa4.py registration
modified torch/nn/attention/__init__.py and _registry.py to add FA3 support
Added C++ hook to expose FA3 activation (to allow for fp8 dtype)
Added new descale_q, descale_k, descale_v parameters to public API for SDPA

Test Plan:
Mirroring fa4:
Install FA3 library: https://github.com/Dao-AILab/flash-attention/tree/main
activate_flash_attention_impl("FA3")
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        _ = F.scaled_dot_product_attention(q, k, v)

ghstack-source-id: 8d109f6
Pull-Request: #172040
howardzhang-cv (Contributor, Author) commented:

I changed this to a different design using overloaded functions. I removed the descale params from the original public SDPA function and created a new scaled_dot_product_attention_fp8 function that includes them, overloading the op with a .low_p overload in native_functions.yaml. I still needed to add the new ops to fallback_ops.py, though. Another side effect is that the descale params are now required rather than optional (optional params don't work for overloads, since they would produce the same function signature).

I think this is a bit cleaner since any new params we need in the future for low-precision attention can just be added to the overloaded op. Otherwise, we would have to add each new argument everywhere (FA2, FA4, XPU, NestedTensor, etc.). A rough sketch of the call is below.
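
A rough sketch of how this design iteration might be called (purely illustrative: the entry point later gets a leading underscore and is renamed again in follow-up work, and the exact argument order is an assumption):

import torch
import torch.nn.functional as F

# fp8 inputs plus per-tensor descale factors (see the descale sketch in the earlier comment).
q_fp8, k_fp8, v_fp8 = (
    torch.randn(8, 16, 1024, 128, device="cuda").to(torch.float8_e4m3fn) for _ in range(3)
)
descale_q = descale_k = descale_v = torch.ones(1, device="cuda")

# Descale args are required on this overload, not optional.
out = F.scaled_dot_product_attention_fp8(
    q_fp8, k_fp8, v_fp8,
    descale_q=descale_q, descale_k=descale_k, descale_v=descale_v,
)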

Comment threads on: tools/autograd/derivatives.yaml (outdated), aten/src/ATen/native/transformers/cuda/sdp_utils.cpp, torch/nn/attention/_fa3.py (several, outdated)
Comment thread on torch/nn/functional.py (outdated), at:

def scaled_dot_product_attention_fp8(

Contributor:
Talked offline: we should keep this private or unexposed for now while we work on the API, and we can add it to ao.

Contributor (Author):
I added the "_" in the beginning and removed it from the functional.pyi.in so it doesn't get caught by wildcard imports and things like that, is this enough, or should we maybe move this into an experimental folder somewhere?

Contributor:
I think we should probably put this into a _prototype-like folder in the attention code and not in the main nn.functional just yet.

[ghstack-poisoned]
howardzhang-cv (Contributor, Author) commented:

@pytorchbot merge

pytorch-bot added the ciflow/trunk label Jan 22, 2026
pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot (Collaborator) commented:

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-jammy-rocm-py3.10 / test (distributed, 3, 3, linux.rocm.gpu.gfx942.4)

Details for Dev Infra team: raised by workflow job.

[ghstack-poisoned]
howardzhang-cv (Contributor, Author) commented:

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Jan 23, 2026
Summary:
Added meta registration for new scaled_dot_product_flash_attention.low_p overload
Added inductor lowering fallback for new overload
Directly call op overload in _scaled_dot_product_attention_fp8 instead of python builtin function call

Pull Request resolved: #172622
Approved by: https://github.com/drisspg
ghstack dependencies: #172040
suncapitalllc007-star pushed a commit to suncapitalllc007-star/pytorch that referenced this pull request Jan 25, 2026
Summary: Added support for flash attention v3 to SDPA
Only supports fp8 forward pass for now
Created new _fa3.py mirroring _fa4.py registration
modified torch/nn/attention/__init__.py and _registry.py to add FA3 support
Added C++ hook to expose FA3 activation (to allow for fp8 dtype)
Added new descale_q, descale_k, descale_v parameters to public API for SDPA

Test Plan:
Mirroring fa4:
Install FA3 library: https://github.com/Dao-AILab/flash-attention/tree/main
activate_flash_attention_impl("FA3")
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        _ = torch.nn.attention._scaled_dot_product_attention_quantized._scaled_dot_product_attention_quantized(q, k, v)

ghstack-source-id: 410506e
Pull-Request: pytorch/pytorch#172040
liangel-02 added a commit that referenced this pull request Feb 16, 2026
This PR enables KV caching and paging for `varlen_attn`. Users need to have Flash Attention 3 installed and activated using `activate_flash_attn_impl("FA3")` enabled by #172040.  Note that KV cache and paging functionality is not supported for the backward pass.

**Summary**

Minimal changes were made to the public Python API other than extending the function signature to take in additional positional args `k_cache`, `v_cache`, `cache_seqlens`, `cache_batch_idx`, `page_table`. The private op `_varlen_attn` still calls `torch.ops.aten._flash_attention_forward`.

KV caching/paging is only supported in FA3, but the PyTorch `_flash_attention_forward` uses FA2, so `activate_flash_attn_impl("FA3")` is used to override the dispatcher to redirect to the FA3 implementation. I extended the function signature of `_flash_attention_forward` but it will throw an error if the code ever reaches that path.

**Testing**

Two tests were added. The first one tests the basic functionality and ensures that an error is thrown if KV caching is used without FA3.

The second test checks correctness across various page sizes and sequence lengths. It calls `varlen_attn` twice, once with the `kv_cache` and once without (`combine_cache_and_new` is a helper that merges the cached tensors with the given "new" ones so that we can simulate `varlen_attn` being called without a KV cache). The outputs from these two calls should be equal.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo




[ghstack-poisoned]
github-actions (Bot) deleted the gh/howardzhang-cv/6/head branch February 22, 2026 02:22
howardzhang-cv added a commit that referenced this pull request May 4, 2026
## Summary
- Added FA4 fp8 implementation using the same FA3 pathway in #172040
- using the torch.ops.aten._scaled_dot_product_flash_attention.quantized overload
- The user interaction path is the same, using _scaled_dot_product_attention_quantized.py. This is meant to only be used with the corresponding torchao API.

## Results
Results for using this path are in pytorch/ao#3947 and pytorch/ao#3960. tl;dr: the speedup depends on sequence length and on how much of the model's compute attention accounts for, but we get about a **1.4x speed-up** at 128k sequence length in a single attention layer. In Llama 3, we see about a **1.23x speed-up** on the end-to-end model runtime at 128k sequence length, with a 0.06 increase in perplexity on the WikiText2 dataset.

## Important Note
This depends on the Dao-AILab/flash-attention#2109 PR in the flash attention library. That needs to be landed before this path will work.

[ghstack-poisoned]
howardzhang-cv added a commit that referenced this pull request May 4, 2026
## Summary
- Added FA4 fp8 implementation using the same FA3 pathway in #172040
- using the torch.ops.aten._scaled_dot_product_flash_attention.quantized overload
- The user interaction path is the same, using _scaled_dot_product_attention_quantized.py. This is meant to only be used with the corresponding torchao API.

## Results
Results for using this path are in pytorch/ao#3947 and pytorch/ao#3960. tl;dr: the speedup depends on sequence length and on how much of the model's compute attention accounts for, but we get about a **1.33x speed-up** at 128k sequence length in a single attention layer. In Llama 3, we see about a **1.20x speed-up** on the end-to-end model runtime at 128k sequence length, with a 0.06 increase in perplexity on the WikiText2 dataset.

[ghstack-poisoned]