Skip to content

[Submodule] Turning flash-attention integration into 3rd party submod (#144120)#146372

Closed
drisspg wants to merge 1 commit intopytorch:mainfrom
drisspg:export-D68502879
Closed

[Submodule] Turning flash-attention integration into 3rd party submod (#144120)#146372
drisspg wants to merge 1 commit intopytorch:mainfrom
drisspg:export-D68502879

Conversation

@drisspg
Copy link
Contributor

@drisspg drisspg commented Feb 4, 2025

Summary:

Summary

Sticky points

Cuda-graph rng handling has changed / deviated from original implementation. We will be left with a dangling 'offset' val and confusing naming due to BC

Dependencies

Other Points

  • The BC linter is complaining about losing generate.py and its functions which is not real BC surface
    cc albanD

imported-using-ghimport

Test Plan:
Imported from OSS

Building in dev
buck build @//mode/dev-nosan -c fbcode.nvcc_arch=h100a //caffe2:ATen-cu --show-full-output

I and Nming the .so I do see that the flash symbols are correctly named:

0000000001c3dfb0 t pytorch_flash::run_mha_bwd(pytorch_flash::Flash_bwd_params&, CUstream_st*)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
0000000001c36080 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()#6}::operator()() const
0000000001c360e0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
0000000001c35fc0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#6}::operator()() const
0000000001c36020 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const

Reviewed By: vkuzo

Differential Revision: D68502879

Pulled By: drisspg

cc @albanD

@pytorch-bot
Copy link

pytorch-bot bot commented Feb 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146372

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit 16e81a4 with merge base 651e6aa (image):

NEW FAILURE - The following job has failed:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D68502879

@github-actions
Copy link
Contributor

github-actions bot commented Feb 4, 2025

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


Caused by:

@github-actions
Copy link
Contributor

github-actions bot commented Feb 4, 2025

Attention! PyTorch one of the C-stable API file was changed

You MUST NOT change existing function declarations in this, as this header defines a stable C ABI. If you need to change the signature for a function, introduce a new v2 version of the function and modify code generation to target the new version of the function.


Caused by:

@drisspg drisspg added ciflow/trunk Trigger trunk jobs on your pull request topic: not user facing topic category module: sdpa All things related to torch.nn.functional.scaled_dot_product_attentiion labels Feb 4, 2025
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D68502879

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D68502879

drisspg added a commit to drisspg/pytorch that referenced this pull request Feb 5, 2025
…pytorch#146372)

Summary:
Pull Request resolved: pytorch#146372

Pull Request resolved: pytorch#144120

# Summary

### Sticky points

Cuda-graph rng handling has changed / deviated from original implementation. We will be left with a dangling 'offset' val and confusing naming due to BC

## Dependencies
- Flash PR: Dao-AILab/flash-attention#1419

### Other Points
- The BC linter is complaining about losing generate.py and its functions which is not real BC surface
cc albanD

imported-using-ghimport

Test Plan:
Imported from OSS

Building in dev
`buck build @//mode/dev-nosan -c fbcode.nvcc_arch=h100a  //caffe2:ATen-cu --show-full-output    `

I and Nming the .so I do see that the flash symbols are correctly named:
```
0000000001c3dfb0 t pytorch_flash::run_mha_bwd(pytorch_flash::Flash_bwd_params&, CUstream_st*)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()pytorch#7}::operator()() const
0000000001c36080 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()pytorch#6}::operator()() const
0000000001c360e0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()pytorch#7}::operator()() const
0000000001c35fc0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()pytorch#6}::operator()() const
0000000001c36020 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()pytorch#7}::operator()() const
```

Reviewed By: vkuzo

Differential Revision: D68502879

Pulled By: drisspg
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D68502879

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D68502879

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D68502879

drisspg added a commit to drisspg/pytorch that referenced this pull request Feb 8, 2025
…pytorch#146372)

Summary:
Pull Request resolved: pytorch#146372

Pull Request resolved: pytorch#144120

# Summary

### Sticky points

Cuda-graph rng handling has changed / deviated from original implementation. We will be left with a dangling 'offset' val and confusing naming due to BC

## Dependencies
- Flash PR: Dao-AILab/flash-attention#1419

### Other Points
- The BC linter is complaining about losing generate.py and its functions which is not real BC surface
cc albanD

imported-using-ghimport

Test Plan:
Imported from OSS

Building in dev
`buck build @//mode/dev-nosan -c fbcode.nvcc_arch=h100a  //caffe2:ATen-cu --show-full-output    `

I and Nming the .so I do see that the flash symbols are correctly named:
```
0000000001c3dfb0 t pytorch_flash::run_mha_bwd(pytorch_flash::Flash_bwd_params&, CUstream_st*)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()pytorch#7}::operator()() const
0000000001c36080 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()pytorch#6}::operator()() const
0000000001c360e0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()pytorch#7}::operator()() const
0000000001c35fc0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()pytorch#6}::operator()() const
0000000001c36020 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()pytorch#7}::operator()() const
```

Reviewed By: vkuzo

Differential Revision: D68502879

Pulled By: drisspg
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D68502879

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D68502879

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D68502879

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D68502879

drisspg added a commit to drisspg/pytorch that referenced this pull request Feb 23, 2025
…pytorch#146372)

Summary:
Pull Request resolved: pytorch#146372

Pull Request resolved: pytorch#144120

# Summary

### Sticky points

Cuda-graph rng handling has changed / deviated from original implementation. We will be left with a dangling 'offset' val and confusing naming due to BC

## Dependencies
- Flash PR: Dao-AILab/flash-attention#1419

### Other Points
- The BC linter is complaining about losing generate.py and its functions which is not real BC surface
cc albanD

imported-using-ghimport

Test Plan:
Imported from OSS

Building in dev
`buck build @//mode/dev-nosan -c fbcode.nvcc_arch=h100a  //caffe2:ATen-cu --show-full-output    `

I and Nming the .so I do see that the flash symbols are correctly named:
```
0000000001c3dfb0 t pytorch_flash::run_mha_bwd(pytorch_flash::Flash_bwd_params&, CUstream_st*)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()pytorch#7}::operator()() const
0000000001c36080 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()pytorch#6}::operator()() const
0000000001c360e0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()pytorch#7}::operator()() const
0000000001c35fc0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()pytorch#6}::operator()() const
0000000001c36020 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()pytorch#7}::operator()() const
```

Reviewed By: vkuzo

Differential Revision: D68502879

Pulled By: drisspg
@drisspg drisspg added the ci-no-td Do not run TD on this PR label Feb 23, 2025
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D68502879

…pytorch#146372)

Summary:
Pull Request resolved: pytorch#146372

Pull Request resolved: pytorch#144120

# Summary

### Sticky points

Cuda-graph rng handling has changed / deviated from original implementation. We will be left with a dangling 'offset' val and confusing naming due to BC

## Dependencies
- Flash PR: Dao-AILab/flash-attention#1419

### Other Points
- The BC linter is complaining about losing generate.py and its functions which is not real BC surface
cc albanD

imported-using-ghimport

Test Plan:
Imported from OSS

Building in dev
`buck build @//mode/dev-nosan -c fbcode.nvcc_arch=h100a  //caffe2:ATen-cu --show-full-output    `

I and Nming the .so I do see that the flash symbols are correctly named:
```
0000000001c3dfb0 t pytorch_flash::run_mha_bwd(pytorch_flash::Flash_bwd_params&, CUstream_st*)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()pytorch#7}::operator()() const
0000000001c36080 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()pytorch#6}::operator()() const
0000000001c360e0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()pytorch#7}::operator()() const
0000000001c35fc0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()pytorch#6}::operator()() const
0000000001c36020 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()pytorch#7}::operator()() const
```

Reviewed By: vkuzo

Differential Revision: D68502879

Pulled By: drisspg
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D68502879

@facebook-github-bot
Copy link
Contributor

@pytorchbot merge -i

(Initiating merge automatically since Phabricator Diff has merged, merging with -i because oss signals were bypassed internally)

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged while ignoring the following 2 checks: pull / unstable-linux-focal-cuda12.4-py3.10-gcc9-sm89-xfail / build, pull / linux-jammy-py3.9-gcc11 / test (backwards_compat, 1, 1, linux.2xlarge)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Feb 28, 2025
)

### Summary
#146372 changed the op signature of `_scaled_dot_product_flash_attention` and as a consequence DTensor needs to change its sharding defined at https://github.com/pytorch/pytorch/blob/40ad5e01dff05c7d64e070fb01683820e678f788/torch/distributed/tensor/_ops/_matrix_ops.py#L232

### Test
`pytest test/distributed/tensor/test_attention.py`

### Follow-up
It's still unclear why the CP unit tests were not run over the original PR which is BC-breaking.

Pull Request resolved: #148125
Approved by: https://github.com/tianyu-l, https://github.com/fegin
aditew01 pushed a commit that referenced this pull request Feb 28, 2025
…#144120) (#146372)

Summary:

# Summary

### Sticky points

Cuda-graph rng handling has changed / deviated from original implementation. We will be left with a dangling 'offset' val and confusing naming due to BC

## Dependencies
- Flash PR: Dao-AILab/flash-attention#1419

### Other Points
- The BC linter is complaining about losing generate.py and its functions which is not real BC surface
cc albanD

imported-using-ghimport

Test Plan:
Imported from OSS

Building in dev
`buck build @//mode/dev-nosan -c fbcode.nvcc_arch=h100a  //caffe2:ATen-cu --show-full-output    `

I and Nming the .so I do see that the flash symbols are correctly named:
```
0000000001c3dfb0 t pytorch_flash::run_mha_bwd(pytorch_flash::Flash_bwd_params&, CUstream_st*)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
0000000001c36080 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()#6}::operator()() const
0000000001c360e0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
0000000001c35fc0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#6}::operator()() const
0000000001c36020 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
```

Reviewed By: vkuzo

Differential Revision: D68502879

Pulled By: drisspg

Pull Request resolved: #146372
Approved by: https://github.com/jbschlosser
Priya2698 added a commit to NVIDIA/Fuser that referenced this pull request Mar 17, 2025
pytorch/pytorch#146372 changes the flash
attention API.

```
// [Note] BC breaking change to flash seed/offset
// Previously: Used separate tensors for philox_seed and philox_offset, sometimes on CPU, sometimes on CUDA
// FlashAttention change: Now uses a single uint64_t[2] tensor on device containing both seed and offset
// Implementation: Renamed "seed" → "rng_state" (contains both seed+offset) and "offset" → "_unused"
```

~In nvfuser API, `SdpaFwdOp` now returns a `rng_state` and `SdpaBwdOp`
expects a `rng_state`. `_unused` is ignored.~

In nvfuser API, based on torch version, `philox_seed` is now
`uint64_t[2]` and philox offset is a empty `uint64_t` tensor on device.
I chose to keep the same semantics as PyTorch since it avoids
packing/unpacking PyTorch outputs when testing if I were to split the
`rng_state` output of PyTorch into `philox_seed` and `philox_offset`.
The latter approach would keep the dimensions of both tensors the same,
but since we're changing `device` and `dtype`, they will need to be
created distinctly for the two cases. This also allows us to switch to
the single-parameter version in the future more easily if desired.

---------

Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-no-td Do not run TD on this PR ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request fb-exported Merged module: sdpa All things related to torch.nn.functional.scaled_dot_product_attentiion skip-pr-sanity-checks suppress-bc-linter Suppresses the failures of API backward-compatibility linter (Lint/bc_linter) topic: not user facing topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants