[varlen_attn for inference] add page_table #175924
liangel-02 wants to merge 21 commits into gh/liangel-02/15/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/175924
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (3 unrelated failures) As of commit 0767407 with merge base 4bc9d7f:
FLAKY - The following job failed but was likely due to flakiness present on trunk.
BROKEN TRUNK - The following job failed but was present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Attention! `native_functions.yaml` was changed. If you are adding a new function or defaulted argument to `native_functions.yaml`, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs: one that adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.
Attention! One of PyTorch's C-stable API files was changed. You MUST NOT change existing function declarations in this header, as it defines a stable C ABI. If you need to change the signature of a function, introduce a new v2 version of the function and modify code generation to target the new version.
`page_table` is an FA3 feature, so we need to modify the function definitions in `native_functions.yaml`. If this is used with FA2, we throw an error. [ghstack-poisoned]
```diff
  value: torch.Tensor,
  cu_seq_q: torch.Tensor,
- cu_seq_k: torch.Tensor,
+ cu_seq_k: torch.Tensor | None,
```
If using `page_table`, `cu_seq_k` needs to be None.
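The constraint in this comment can be sketched as a small Python-side validation. This is illustrative only — `check_paged_args` is a hypothetical helper, not the actual ATen check, and the rationale comment reflects one plausible reading of the paged path:

```python
from typing import Optional

import torch


def check_paged_args(
    page_table: Optional[torch.Tensor], cu_seq_k: Optional[torch.Tensor]
) -> None:
    # Sketch of the review comment's rule: when a page_table is supplied,
    # an explicit cu_seq_k must not be passed alongside it.
    if page_table is not None and cu_seq_k is not None:
        raise ValueError("cu_seq_k must be None when page_table is provided")


# OK: paged path, no explicit cu_seq_k.
check_paged_args(torch.zeros(2, 4, dtype=torch.int32), None)
# OK: non-paged path, cu_seq_k allowed.
check_paged_args(None, torch.tensor([0, 4], dtype=torch.int32))
```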
```python
@unittest.skipIf("FA3" not in list_flash_attention_impls(), "FA3 not available")
@parametrize("dtype", [torch.bfloat16, torch.float16])
@parametrize("page_size", [32, 64, 128])
@parametrize("compile", [False, True])
```
Adding a test for torch.compile.
```diff
  const std::optional<Tensor>& _seqused_k,
- const std::optional<Tensor>& _alibi_slopes
+ const std::optional<Tensor>& _alibi_slopes,
+ const std::optional<Tensor>& _page_table
```
OOC, FA2 does support this, right? https://github.com/Dao-AILab/flash-attention/blob/9a25eba569317708ae295e396aaac0050b28e52b/csrc/flash_attn/flash_api.cpp#L523
Perhaps I just didn't end up wiring it, or there was a semantic change?
test/test_varlen_attention.py
```python
page_table = torch.zeros(
    batch_size, max_pages_per_seq, device=device, dtype=torch.int32
)
for i in range(batch_size):
```
@claude what is a more efficient way to write this using more native PyTorch ops?
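One possible answer, assuming the loop assigns contiguous physical page ids per sequence (an assumption — the loop body is elided in the snippet above): a single `arange` plus `reshape` replaces the Python loop entirely.

```python
import torch

batch_size, max_pages_per_seq = 4, 8
device = "cpu"  # the test runs on CUDA; CPU here just for illustration

# Loop version, filling contiguous page ids per sequence (assumed loop body).
page_table_loop = torch.zeros(
    batch_size, max_pages_per_seq, device=device, dtype=torch.int32
)
for i in range(batch_size):
    page_table_loop[i] = torch.arange(
        i * max_pages_per_seq,
        (i + 1) * max_pages_per_seq,
        device=device,
        dtype=torch.int32,
    )

# Vectorized version: one kernel launch, no Python loop.
page_table_vec = torch.arange(
    batch_size * max_pages_per_seq, device=device, dtype=torch.int32
).reshape(batch_size, max_pages_per_seq)

assert torch.equal(page_table_loop, page_table_vec)
```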
torch/nn/attention/varlen.py
```text
is larger than the actual sequence. Inference-only (not supported in backward).
page_table (Tensor, optional): Page table mapping logical to physical pages for paged
```
We should add more description of the semantics here.
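For illustration, the usual paged-KV semantics read as: token position `t` of batch `b` lives at offset `t % page_size` within physical page `page_table[b, t // page_size]`. A minimal sketch of that mapping (the table values are made up):

```python
import torch

page_size = 4
# Hypothetical page table: 2 sequences, 3 logical pages each.
# Entry [b, p] is the physical page backing logical page p of sequence b.
page_table = torch.tensor([[5, 2, 7], [0, 9, 1]], dtype=torch.int32)


def physical_slot(batch_idx: int, token_pos: int) -> int:
    """Map a logical token position to its slot index in the paged KV cache."""
    logical_page = token_pos // page_size
    offset = token_pos % page_size
    physical_page = int(page_table[batch_idx, logical_page])
    return physical_page * page_size + offset


# Token 6 of sequence 0: logical page 1 -> physical page 2, offset 2 -> slot 10.
assert physical_slot(0, 6) == 10
```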
Starting merge as part of PR stack under #176723
Pull Request resolved: #175936
Approved by: https://github.com/drisspg
ghstack dependencies: #175897, #175924
`aten/src/ATen/native/transformers/cuda/attention.cu`
- renamed `_flash_attention_forward` to `_flash_attention_forward_impl`; this is now the core logic and takes an `optional<Tensor> out`
- `_flash_attention_forward` is the non-out variant: a thin wrapper that calls `_flash_attention_forward_impl` with `out=std::nullopt`
- `_flash_attention_forward_no_dropout_inplace` is the out variant and calls `_flash_attention_forward_impl` with `Tensor& out`
`aten/src/ATen/native/native_functions.yaml`
- registered a new op, `_flash_attention_forward_no_dropout_inplace`
`torch/_meta_registrations.py`
- added a meta registration that calls `meta__flash_attention_forward` but doesn't return the out tensor
`torch/nn/attention/varlen.py`
- added public `varlen_attn_out` and private custom op `_varlen_attn_out` with `mutates_args={"out"}`
`test/test_varlen_attention.py`
- added the out variant to existing tests
Pull Request resolved: #176015
Approved by: https://github.com/drisspg
ghstack dependencies: #175897, #175924, #175936
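The refactor above is the standard impl-plus-thin-wrappers pattern. A rough Python analogue of the shape of the change (function names mirror the summary, but the bodies — including the stand-in attention math — are illustrative, not the actual ATen code):

```python
from typing import Optional

import torch


def _attention_forward_impl(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Core logic lives in one place; stand-in for the flash kernel.
    result = torch.softmax(q @ k.transpose(-2, -1), dim=-1) @ v
    if out is not None:
        out.copy_(result)  # out variant: write into the caller's buffer
        return out
    return result  # non-out variant: allocate a fresh tensor


def attention_forward(q, k, v):
    """Non-out variant: thin wrapper passing out=None."""
    return _attention_forward_impl(q, k, v, out=None)


def attention_forward_out(q, k, v, out):
    """Out variant: thin wrapper passing the caller-provided buffer."""
    return _attention_forward_impl(q, k, v, out=out)


q = k = v = torch.randn(2, 4, 8)
buf = torch.empty(2, 4, 8)
res = attention_forward(q, k, v)
res_out = attention_forward_out(q, k, v, buf)
assert res_out is buf and torch.allclose(res, buf)
```

The point of the pattern is that both public entry points share one code path; the out variant only mutates a caller-provided buffer, which is exactly what `mutates_args={"out"}` declares to the custom-op system.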
@pytorchbot revert -m 'Sorry for reverting your change but a bunch of internal builds need to be updated to unblock this change D95758397' -c ghfirst
@pytorchbot successfully started a revert job. Check the current status here.
This reverts commit 26dddb9. Reverted #176723 on behalf of https://github.com/huydhn due to: "Sorry for reverting your change but a bunch of internal builds need to be updated to unblock this change D95758397"
This reverts commit 492c742. Reverted #176015 on behalf of https://github.com/huydhn for the same reason.
This reverts commit 388d61e. Reverted #175936 on behalf of https://github.com/huydhn for the same reason.
This reverts commit 9b53dac. Reverted #175924 on behalf of https://github.com/huydhn for the same reason.
@liangel-02 your PR has been successfully reverted.
@liangel-02 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
`page_table` is an FA3 feature, so we need to modify the function definitions in `native_functions.yaml`. If this is used with FA2, we throw an error. Differential Revision: [D95996400](https://our.internmc.facebook.com/intern/diff/D95996400) [ghstack-poisoned]
Starting merge as part of PR stack under #176723
Pull Request resolved: #175936
Approved by: https://github.com/drisspg
ghstack dependencies: #175924
Pull Request resolved: #176015
Approved by: https://github.com/drisspg
ghstack dependencies: #175924, #175936
Pull Request resolved: #176723
Approved by: https://github.com/drisspg
ghstack dependencies: #175924, #175936, #176015
ghstack-source-id: ca17f5c Pull Request resolved: pytorch/pytorch#175924
`page_table` is an FA3 feature, so we need to modify the function definitions in `native_functions.yaml`. If this is used with FA2, we throw an error.
Stack from ghstack (oldest at bottom):