[varlen_attn for inference] add out variant #176015
liangel-02 wants to merge 22 commits into gh/liangel-02/17/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/176015
Note: links to docs will display an error until the docs builds have completed. ✅ You can merge normally! (2 unrelated failures.) As of commit 98609ba with merge base 4bc9d7f:
- FLAKY: the following job failed but was likely due to flakiness present on trunk.
- UNSTABLE: the following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Attention! `native_functions.yaml` was changed. If you are adding a new function or defaulted argument to `native_functions.yaml`, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs: one that adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.
Attention! One of PyTorch's C-stable API files was changed. You MUST NOT change existing function declarations in it, as this header defines a stable C ABI. If you need to change the signature of a function, introduce a new v2 version of the function and modify code generation to target the new version.
`aten/src/ATen/native/transformers/cuda/attention.cu`
- Renamed `_flash_attention_forward` to `_flash_attention_forward_impl`. This is now the core logic and takes an `optional<Tensor> out`.
- `_flash_attention_forward` is the non-out variant: a thin wrapper that calls `_flash_attention_forward_impl` with `out=std::nullopt`.
- `_flash_attention_forward_out` is the out variant and calls `_flash_attention_forward_impl` with a `Tensor& out`.

`aten/src/ATen/native/native_functions.yaml`
- Registered the out variant `_flash_attention_forward.out` with type `Tensor(a!)` to indicate that it is mutated. This dispatches to `_flash_attention_forward_out` defined in `attention.cu`.

`torch/_meta_registrations.py`
- Added a meta registration that calls `meta__flash_attention_forward` but doesn't return the out tensor.

`torch/nn/attention/varlen.py`
- Added a public `varlen_attn_out` and a private custom op `_varlen_attn_out` with `mutates_args={"out"}`.

`test/test_varlen_attention.py`
- Added the out variant to existing tests.
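The refactor above follows a common out-variant pattern: one core implementation takes an optional output buffer, and two thin entry points decide whether to allocate fresh storage or write into a caller-supplied buffer. A minimal plain-Python sketch of that structure (the names and the list-based stand-in for a tensor are illustrative, not the actual ATen code):

```python
from typing import Optional

def _attention_impl(scores: list[float], out: Optional[list[float]] = None) -> list[float]:
    """Core logic: write normalized scores into `out`, allocating if absent."""
    total = sum(scores)
    if out is None:
        out = [0.0] * len(scores)      # non-out path: allocate fresh storage
    assert len(out) == len(scores), "preallocated buffer must match shape"
    for i, s in enumerate(scores):
        out[i] = s / total             # write results in place
    return out

def attention_forward(scores: list[float]) -> list[float]:
    """Non-out variant: a thin wrapper, like `_flash_attention_forward`."""
    return _attention_impl(scores, out=None)

def attention_forward_out(scores: list[float], out: list[float]) -> list[float]:
    """Out variant: the caller supplies the buffer, which is mutated."""
    return _attention_impl(scores, out=out)
```

The point of the out variant for inference is that a decode loop can preallocate the output once and reuse it on every step, avoiding a fresh allocation per call.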
> CUDA: _flash_attention_forward
> tags: nondeterministic_seeded
>
> - func: _flash_attention_forward_out_variant(Tensor(a!) out, Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None, Tensor? page_table=None) -> (Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
Since this is not really an out op or a variant, let's actually rename it: this is basically a new op that is a pseudo in-place op. Also, we probably don't need `debug_attn_mask`, `unused`, or `rng_state` anymore, right? Since we don't plan to support dropout for this API and we aren't using `debug_attn_mask`. Can we just drop them from the impl?

Also, I know @albanD will hate this, but it does feel like the shortest path. As you can tell from the neighborhood, I have not been a very good steward of these ops and keeping them minimal. I want to make sure we don't get yelled at :)

Why is this not called `_flash_attention_forward_(Tensor(a!) output, ...)`?
changing name to _flash_attention_forward_no_dropout_inplace as discussed offline
Discussed with Alban; if he is cool, I'm cool.
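The "pseudo in-place op" shape under discussion: the primary output lands in a caller-supplied buffer that the op mutates, while the return value carries only auxiliary results (here, `softmax_logsumexp`). A toy sketch, with a plain list standing in for a tensor and a simple softmax standing in for the flash-attention kernel:

```python
import math

def pseudo_inplace_softmax(out: list[float], scores: list[float]) -> float:
    """Mutates `out` in place and returns only an auxiliary value,
    mirroring a schema like `(Tensor(a!) out, ...) -> Tensor softmax_logsumexp`."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    for i, e in enumerate(exps):
        out[i] = e / total         # primary result goes into the caller's buffer
    return m + math.log(total)     # logsumexp is the only returned value
```

Dropping `rng_state`, `unused`, and `debug_attn_mask` from the signature, as suggested above, would shrink the return tuple to just the auxiliary outputs the API actually uses.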
This reverts commit f1e413e. Reverted #176015 on behalf of https://github.com/zou3519 due to: sorry, I think this broke inductor ROCm ([comment](#175897 (comment))).

@liangel-02 your PR has been reverted as part of the stack under #175897.
`aten/src/ATen/native/transformers/cuda/attention.cu`
- Renamed `_flash_attention_forward` to `_flash_attention_forward_impl`. This is now the core logic and takes an `optional<Tensor> out`.
- `_flash_attention_forward` is the non-out variant: a thin wrapper that calls `_flash_attention_forward_impl` with `out=std::nullopt`.
- `_flash_attention_forward_no_dropout_inplace` is the out variant and calls `_flash_attention_forward_impl` with a `Tensor& out`.

`aten/src/ATen/native/native_functions.yaml`
- Registered a new op `_flash_attention_forward_no_dropout_inplace`.

`torch/_meta_registrations.py`
- Added a meta registration that calls `meta__flash_attention_forward` but doesn't return the out tensor.

`torch/nn/attention/varlen.py`
- Added a public `varlen_attn_out` and a private custom op `_varlen_attn_out` with `mutates_args={"out"}`.

`test/test_varlen_attention.py`
- Added the out variant to existing tests.
Starting merge as part of PR stack under #176723
This reverts commit 492c742. Reverted #176015 on behalf of https://github.com/huydhn due to: sorry for reverting your change, but a bunch of internal builds need to be updated to unblock this change. D95758397 ([comment](#175924 (comment))).

@liangel-02 your PR has been reverted as part of the stack under #175924.
@liangel-02 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Differential Revision: [D95996399](https://our.internmc.facebook.com/intern/diff/D95996399)
Starting merge as part of PR stack under #176723
Pull Request resolved: #176723. Approved by: https://github.com/drisspg. ghstack dependencies: #175924, #175936, #176015.
ghstack-source-id: dba65b3 Pull Request resolved: pytorch/pytorch#176015
`aten/src/ATen/native/transformers/cuda/attention.cu`
- renamed `_flash_attention_forward` to `_flash_attention_forward_impl`; this is now the core logic and takes `optional<Tensor> out`
- `_flash_attention_forward` is the non-out variant and is a thin wrapper that calls `_flash_attention_forward_impl` with `out=std::nullopt`
- `_flash_attention_forward_no_dropout_inplace` is the out variant and calls `_flash_attention_forward_impl` with `Tensor& out`

`aten/src/ATen/native/native_functions.yaml`
- registered a new op `_flash_attention_forward_no_dropout_inplace`

`torch/_meta_registrations.py`
- added a meta registration that calls `meta__flash_attention_forward` but doesn't return the out tensor

`torch/nn/attention/varlen.py`
- added public `varlen_attn_out` and private custom op `_varlen_attn_out` with `mutates_args={"out"}`

`test/test_varlen_attention.py`
- added out variant to existing tests

Stack from ghstack (oldest at bottom):