
[varlen_attn for inference] add page_table#175924

Closed
liangel-02 wants to merge 21 commits into gh/liangel-02/15/base from gh/liangel-02/15/head

Conversation

Contributor

@liangel-02 liangel-02 commented Feb 26, 2026

`page_table` is an FA3 feature, so we need to modify the function definitions in `native_functions.yaml`. If it is used with FA2, we throw an error.

Stack from ghstack (oldest at bottom):

[ghstack-poisoned]
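A minimal sketch of the guard described above. The function and backend names here are assumptions for illustration, not the actual ATen internals (where the check lives on the C++ side behind the ops declared in `native_functions.yaml`):

```python
# Hypothetical guard (names assumed): page_table is only wired up for
# FlashAttention-3, so any other flash backend rejects it with an error.
def check_page_table_support(page_table, flash_impl: str) -> None:
    if page_table is not None and flash_impl != "FA3":
        raise RuntimeError(
            "page_table is only supported with FlashAttention-3; "
            f"got backend {flash_impl!r}"
        )
```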
@pytorch-bot

pytorch-bot bot commented Feb 26, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/175924

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit 0767407 with merge base 4bc9d7f:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

liangel-02 added a commit that referenced this pull request Feb 26, 2026
ghstack-source-id: b9de853
Pull Request resolved: #175924
@github-actions
Contributor

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.



@github-actions
Contributor

Attention! One of PyTorch's C-stable API files was changed

You MUST NOT change existing function declarations in this file, as this header defines a stable C ABI. If you need to change a function's signature, introduce a new v2 version of the function and modify the code generation to target the new version.



liangel-02 added a commit that referenced this pull request Feb 26, 2026
ghstack-source-id: 6e2d2ce
Pull Request resolved: #175924
liangel-02 added a commit that referenced this pull request Feb 26, 2026
ghstack-source-id: 19f1d54
Pull Request resolved: #175924
value: torch.Tensor,
cu_seq_q: torch.Tensor,
cu_seq_k: torch.Tensor,
cu_seq_k: torch.Tensor | None,
Contributor Author

If `page_table` is used, `cu_seq_k` needs to be `None`.

@liangel-02 liangel-02 requested a review from drisspg February 27, 2026 00:22
@unittest.skipIf("FA3" not in list_flash_attention_impls(), "FA3 not available")
@parametrize("dtype", [torch.bfloat16, torch.float16])
@parametrize("page_size", [32, 64, 128])
@parametrize("compile", [False, True])
Contributor Author

Adding a test for `torch.compile`.
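The `compile` parametrization runs each case both eagerly and under `torch.compile` and compares the results. A toy stand-in for that pattern (the real test calls the varlen attention op, which needs FA3-capable hardware; `toy_attn` here is purely illustrative):

```python
import torch

def toy_attn(q, k, v):
    # stand-in for the varlen attention call under test
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return scores.softmax(dim=-1) @ v

q, k, v = (torch.randn(2, 4, 8) for _ in range(3))
# backend="eager" keeps the sketch portable; the real test uses the default backend
compiled = torch.compile(toy_attn, backend="eager")
torch.testing.assert_close(toy_attn(q, k, v), compiled(q, k, v))
```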

const std::optional<Tensor>& _seqused_k,
const std::optional<Tensor>& _alibi_slopes
const std::optional<Tensor>& _alibi_slopes,
const std::optional<Tensor>& _page_table
Contributor

@drisspg drisspg Mar 4, 2026

OOC, FA2 does support this, right? https://github.com/Dao-AILab/flash-attention/blob/9a25eba569317708ae295e396aaac0050b28e52b/csrc/flash_attn/flash_api.cpp#L523

perhaps I just didn't end up wiring it or there was a semantic change?

page_table = torch.zeros(
    batch_size, max_pages_per_seq, device=device, dtype=torch.int32
)
for i in range(batch_size):
Contributor

@claude what is a more efficient way to write this using more native PyTorch ops?
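One possible vectorized alternative, assuming the loop assigns each sequence a contiguous block of physical pages (the loop body is truncated above, so this is only an illustrative sketch, not the PR's actual logic):

```python
import torch

batch_size, max_pages_per_seq = 4, 8
device = "cpu"

# Assumed layout: sequence i owns physical pages
# [i * max_pages_per_seq, (i + 1) * max_pages_per_seq), built without a Python loop
page_table = (
    torch.arange(batch_size, device=device, dtype=torch.int32).unsqueeze(1)
    * max_pages_per_seq
    + torch.arange(max_pages_per_seq, device=device, dtype=torch.int32)
)
```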

Comment on lines +208 to +209
is larger than the actual sequence. Inference-only (not supported in backward).
page_table (Tensor, optional): Page table mapping logical to physical pages for paged
Contributor

We should add more description of the semantics here.
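For instance, the semantics could be illustrated along these lines (shapes and names assumed for illustration): `page_table[b, j]` holds the index of the physical page backing logical page `j` of sequence `b`, so gathering rows of the paged cache reconstructs each sequence's contiguous KV:

```python
import torch

num_pages, page_size, head_dim = 6, 4, 8
k_cache = torch.randn(num_pages, page_size, head_dim)  # physical paged KV cache
page_table = torch.tensor([[3, 0], [5, 1]])            # 2 seqs x 2 logical pages

# page_table[b, j] -> physical page holding logical page j of sequence b
k_logical = k_cache[page_table]                # (2, 2, page_size, head_dim)
k_contig = k_logical.reshape(2, -1, head_dim)  # contiguous per-sequence keys
```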

@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #176723

pytorchmergebot pushed a commit that referenced this pull request Mar 8, 2026
`aten/src/ATen/native/transformers/cuda/attention.cu`

- renamed `_flash_attention_forward` to `_flash_attention_forward_impl`; this is now the core logic and takes an `optional<Tensor> out`.
- `_flash_attention_forward` is the non-out variant: a thin wrapper that calls `_flash_attention_forward_impl` with `out=std::nullopt`.
- `_flash_attention_forward_no_dropout_inplace` is the out variant and calls `_flash_attention_forward_impl` with a `Tensor& out`.

`aten/src/ATen/native/native_functions.yaml`

- registered a new op `_flash_attention_forward_no_dropout_inplace`

`torch/_meta_registrations.py`

- added meta registration that calls `meta__flash_attention_forward` but doesn't return out tensor

`torch/nn/attention/varlen.py`

- added public `varlen_attn_out` and private custom op `_varlen_attn_out` with `mutates_args={"out"}`

`test/test_varlen_attention.py`

- added out variant to existing tests

Pull Request resolved: #176015
Approved by: https://github.com/drisspg
ghstack dependencies: #175897, #175924, #175936
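The refactor described in that commit message follows a common out-variant pattern. A Python sketch of its shape (the names mirror the commit message, but the bodies are stand-ins, not the real ATen code):

```python
from typing import Optional
import torch

def _flash_attention_forward_impl(
    q: torch.Tensor, out: Optional[torch.Tensor] = None
) -> torch.Tensor:
    result = q * 2.0  # stand-in for the actual attention computation
    if out is not None:
        out.copy_(result)  # write into the caller-provided buffer
        return out
    return result

def flash_attention_forward(q: torch.Tensor) -> torch.Tensor:
    # non-out variant: thin wrapper passing out=None (std::nullopt in C++)
    return _flash_attention_forward_impl(q, out=None)

def flash_attention_forward_no_dropout_inplace(
    q: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
    # out variant: mutates and returns the provided buffer
    return _flash_attention_forward_impl(q, out=out)
```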
pytorchmergebot pushed a commit that referenced this pull request Mar 8, 2026
@huydhn
Contributor

huydhn commented Mar 10, 2026

@pytorchbot revert -m 'Sorry for reverting your change but a bunch of internal builds need to be updated to unblock this change D95758397' -c ghfirst

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Mar 10, 2026
#176723)"

This reverts commit 26dddb9.

Reverted #176723 on behalf of https://github.com/huydhn due to Sorry for reverting your change but a bunch of internal builds need to be updated to unblock this change D95758397 ([comment](#175924 (comment)))
pytorchmergebot added a commit that referenced this pull request Mar 10, 2026
This reverts commit 492c742.

Reverted #176015 on behalf of https://github.com/huydhn due to Sorry for reverting your change but a bunch of internal builds need to be updated to unblock this change D95758397 ([comment](#175924 (comment)))
pytorchmergebot added a commit that referenced this pull request Mar 10, 2026
This reverts commit 388d61e.

Reverted #175936 on behalf of https://github.com/huydhn due to Sorry for reverting your change but a bunch of internal builds need to be updated to unblock this change D95758397 ([comment](#175924 (comment)))
pytorchmergebot added a commit that referenced this pull request Mar 10, 2026
This reverts commit 9b53dac.

Reverted #175924 on behalf of https://github.com/huydhn due to Sorry for reverting your change but a bunch of internal builds need to be updated to unblock this change D95758397 ([comment](#175924 (comment)))
@pytorchmergebot
Collaborator

@liangel-02 your PR has been successfully reverted.

@liangel-02
Contributor Author

@liangel-02 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Differential Revision: [D95996400](https://our.internmc.facebook.com/intern/diff/D95996400)
@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #176723


pytorchmergebot pushed a commit that referenced this pull request Mar 11, 2026

Pull Request resolved: #176015
Approved by: https://github.com/drisspg
ghstack dependencies: #175924, #175936
pytorchmergebot pushed a commit that referenced this pull request Mar 11, 2026
sandy-gags pushed a commit to sandy-gags/pytorch that referenced this pull request Mar 12, 2026
ghstack-source-id: ca17f5c
Pull Request resolved: pytorch/pytorch#175924