
[varlen_attn for inference] add out variant#176015

Closed
liangel-02 wants to merge 22 commits into gh/liangel-02/17/base from gh/liangel-02/17/head

Conversation

@liangel-02
Contributor

@liangel-02 liangel-02 commented Feb 27, 2026

`aten/src/ATen/native/transformers/cuda/attention.cu`

  • Renamed `_flash_attention_forward` to `_flash_attention_forward_impl`. This now holds the core logic and takes an `optional<Tensor> out`.
  • `_flash_attention_forward` is the non-out variant: a thin wrapper that calls `_flash_attention_forward_impl` with `out=std::nullopt`.
  • `_flash_attention_forward_no_dropout_inplace` is the out variant and calls `_flash_attention_forward_impl` with a `Tensor& out`.
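The impl/wrapper split above can be sketched in plain Python (a structural illustration only — the function names and the doubling "kernel" are stand-ins, not the actual ATen code):

```python
from typing import Optional

def _forward_impl(x: list, out: Optional[list] = None) -> list:
    result = [v * 2.0 for v in x]   # stand-in for the attention kernel
    if out is None:
        return result               # allocate-and-return path
    out[:] = result                 # write into the caller's buffer
    return out

def forward(x):
    # Non-out variant: thin wrapper that forwards with out=None.
    return _forward_impl(x, out=None)

def forward_inplace(x, out):
    # Out variant: mutates the caller-provided `out` buffer.
    return _forward_impl(x, out=out)
```

The payoff of this shape is that the kernel logic lives in exactly one place, and the two public entry points differ only in who owns the output allocation.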

`aten/src/ATen/native/native_functions.yaml`

  • Registered a new op, `_flash_attention_forward_no_dropout_inplace`.

`torch/_meta_registrations.py`

  • Added a meta registration that calls `meta__flash_attention_forward` but does not return the `out` tensor.

`torch/nn/attention/varlen.py`

  • Added a public `varlen_attn_out` and a private custom op `_varlen_attn_out` with `mutates_args={"out"}`.

`test/test_varlen_attention.py`

  • Added the out variant to the existing tests.
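A common way to fold an out variant into existing tests is to preallocate a buffer, call the out variant, and compare against the regular op. A plain-Python sketch of that pattern (the op bodies are stand-ins, not the real tests):

```python
def attn(x):
    # Stand-in for the regular op: allocates and returns its result.
    return [v + 1.0 for v in x]

def attn_out(x, out):
    # Stand-in for the out variant: writes into the caller's buffer.
    out[:] = attn(x)
    return out

def check_out_variant(x):
    expected = attn(x)
    buf = [0.0] * len(x)
    returned = attn_out(x, buf)
    # The out variant must return the same buffer it mutated, and the
    # contents must match the regular op's result.
    return returned is buf and buf == expected
```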

Stack from ghstack (oldest at bottom):

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Feb 27, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/176015

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 98609ba with merge base 4bc9d7f:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@github-actions
Contributor

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.



@github-actions
Contributor

github-actions bot commented Feb 27, 2026

Attention! One of PyTorch's C-stable API files was changed

You MUST NOT change existing function declarations in this file, as the header defines a stable C ABI. If you need to change the signature of a function, introduce a new v2 version of the function and modify code generation to target the new version.



liangel-02 added a commit that referenced this pull request Feb 27, 2026
ghstack-source-id: 5104e01
Pull Request resolved: #176015
@liangel-02 liangel-02 requested a review from drisspg February 27, 2026 21:44
@liangel-02 liangel-02 changed the title from "add out variant" to "wip: add out variant" Feb 27, 2026
liangel-02 added a commit that referenced this pull request Mar 2, 2026
ghstack-source-id: 8a6bd6c
Pull Request resolved: #176015
liangel-02 added a commit that referenced this pull request Mar 2, 2026
ghstack-source-id: 5c837e5
Pull Request resolved: #176015
liangel-02 added a commit that referenced this pull request Mar 3, 2026
ghstack-source-id: b85f6db
Pull Request resolved: #176015
CUDA: _flash_attention_forward
tags: nondeterministic_seeded

- func: _flash_attention_forward_out_variant(Tensor(a!) out, Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None, Tensor? page_table=None) -> (Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
Contributor


Since this is not really an out op or a variant, let's actually change this: it is basically a new op that is a pseudo in-place op. Also, we probably don't need the debug_attn_mask, unused, or rng_state anymore, right? Since we don't plan to support dropout for this API and we aren't using debug_attn_mask?

Can we just drop them from the impl?

Contributor


Also, I know @albanD will hate this, but it does feel like the shortest path. As you can tell from the neighborhood, I have not been a very good steward of these ops and keeping them minimal. I want to make sure we don't get yelled at :)

Collaborator

@albanD albanD Mar 4, 2026


Why is this not called `_flash_attention_forward_(Tensor(a!) output, ....)`?

Contributor Author


Changing the name to `_flash_attention_forward_no_dropout_inplace` as discussed offline.

Contributor


Discussed with the boss Alban, right? If he is cool, I'm cool.

liangel-02 added a commit that referenced this pull request Mar 4, 2026
ghstack-source-id: 91a5618
Pull Request resolved: #176015
liangel-02 added a commit that referenced this pull request Mar 4, 2026
ghstack-source-id: 0101572
Pull Request resolved: #176015
pytorchmergebot pushed a commit that referenced this pull request Mar 7, 2026
pytorchmergebot added a commit that referenced this pull request Mar 7, 2026
This reverts commit f1e413e.

Reverted #176015 on behalf of https://github.com/zou3519 due to sorry I think this broke inductor rocm ([comment](#175897 (comment)))
@pytorchmergebot
Collaborator

@liangel-02 your PR has been reverted as part of the stack under #175897.

@pytorchmergebot pytorchmergebot added the Reverted and ci-no-td (Do not run TD on this PR) labels Mar 7, 2026
@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #176723


pytorchmergebot pushed a commit that referenced this pull request Mar 8, 2026
pytorchmergebot added a commit that referenced this pull request Mar 10, 2026
This reverts commit 492c742.

Reverted #176015 on behalf of https://github.com/huydhn due to Sorry for reverting your change but a bunch of internal builds need to be updated to unblock this change D95758397 ([comment](#175924 (comment)))
@pytorchmergebot
Collaborator

@liangel-02 your PR has been reverted as part of the stack under #175924.

@liangel-02
Contributor Author

@liangel-02 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Differential Revision: [D95996399](https://our.internmc.facebook.com/intern/diff/D95996399)

@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #176723


pytorchmergebot pushed a commit that referenced this pull request Mar 11, 2026
sandy-gags pushed a commit to sandy-gags/pytorch that referenced this pull request Mar 12, 2026
ghstack-source-id: dba65b3
Pull Request resolved: pytorch/pytorch#176015
