[Submodule] Turning flash-attention integration into 3rd party submod (#144120) #146372
drisspg wants to merge 1 commit into pytorch:main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146372
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Unrelated Failure
As of commit 16e81a4 with merge base 651e6aa:
NEW FAILURE - The following job has failed:
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D68502879
Attention! native_functions.yaml was changed. If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs: one that adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info. Caused by:
Attention! One of PyTorch's C-stable API files was changed. You MUST NOT change existing function declarations in this file, as this header defines a stable C ABI. If you need to change the signature of a function, introduce a new v2 version of the function and modify code generation to target the new version. Caused by:
Force-pushed from 9741ff1 to d98ddeb
Force-pushed from d98ddeb to 9d82be5
[Submodule] Turning flash-attention integration into 3rd party submod (pytorch#144120) (pytorch#146372)

Summary:
Pull Request resolved: pytorch#146372
Pull Request resolved: pytorch#144120

# Summary
### Sticky points
CUDA-graph RNG handling has changed / deviated from the original implementation. We are left with a dangling 'offset' value and confusing naming due to BC.
## Dependencies
- Flash PR: Dao-AILab/flash-attention#1419
### Other Points
- The BC linter complains about losing generate.py and its functions, which is not a real BC surface.

cc albanD

imported-using-ghimport

Test Plan: Imported from OSS

Building in dev:
`buck build @//mode/dev-nosan -c fbcode.nvcc_arch=h100a //caffe2:ATen-cu --show-full-output`
Running nm on the resulting .so, I do see that the flash symbols are correctly named:
```
0000000001c3dfb0 t pytorch_flash::run_mha_bwd(pytorch_flash::Flash_bwd_params&, CUstream_st*)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
0000000001c36080 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()#6}::operator()() const
0000000001c360e0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
0000000001c35fc0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#6}::operator()() const
0000000001c36020 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
```
Reviewed By: vkuzo
Differential Revision: D68502879
Pulled By: drisspg
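Beyond the symbol check above, a quick end-to-end smoke test can confirm that the flash kernels built from the submodule still dispatch. This is a minimal sketch, not part of the PR's test plan; it assumes a CUDA build with the FlashAttention backend compiled in, and the shapes and dtype are arbitrary:

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Assumes a CUDA build with the FlashAttention backend compiled in; restricting
# dispatch to the flash backend makes the call fail loudly if those kernels are
# missing or the inputs are unsupported.
q = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([2, 4, 128, 64])
```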
Force-pushed from 9d82be5 to 3e6fe14
Force-pushed from 3e6fe14 to f77f8a9
Force-pushed from f77f8a9 to 4753c31
Force-pushed from 4753c31 to 598dd57
Force-pushed from a94fd44 to 435cd14
Force-pushed from 435cd14 to 9a51d00
Force-pushed from 9a51d00 to e8dab90
Force-pushed from e8dab90 to 16e81a4
@pytorchbot merge -i (Initiating merge automatically since Phabricator Diff has merged, merging with -i because oss signals were bypassed internally)
Merge started. Your change will be merged while ignoring the following 2 checks: pull / unstable-linux-focal-cuda12.4-py3.10-gcc9-sm89-xfail / build, pull / linux-jammy-py3.9-gcc11 / test (backwards_compat, 1, 1, linux.2xlarge). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
### Summary
#146372 changed the op signature of `_scaled_dot_product_flash_attention`, and as a consequence DTensor needs to change its sharding defined at https://github.com/pytorch/pytorch/blob/40ad5e01dff05c7d64e070fb01683820e678f788/torch/distributed/tensor/_ops/_matrix_ops.py#L232
### Test
`pytest test/distributed/tensor/test_attention.py`
### Follow-up
It is still unclear why the CP unit tests were not run on the original PR, which is BC-breaking.
Pull Request resolved: #148125
Approved by: https://github.com/tianyu-l, https://github.com/fegin
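For reference, the changed ATen signature that the DTensor sharding rule now has to match can be inspected directly from Python. A minimal sketch, assuming a build that already includes #146372 (the exact output names depend on the installed PyTorch version):

```python
import torch

# Print the current schema of the flash-attention SDPA op. On builds that
# include #146372, the two trailing RNG outputs are a packed seed+offset
# tensor and a now-unused placeholder, instead of separate seed/offset tensors.
op = torch.ops.aten._scaled_dot_product_flash_attention.default
print(op._schema)
```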
[Submodule] Turning flash-attention integration into 3rd party submod (#144120) (#146372)
Reviewed By: vkuzo
Differential Revision: D68502879
Pulled By: drisspg
Pull Request resolved: #146372
Approved by: https://github.com/jbschlosser
pytorch/pytorch#146372 changes the flash attention API.
```
// [Note] BC breaking change to flash seed/offset
// Previously: Used separate tensors for philox_seed and philox_offset, sometimes on CPU, sometimes on CUDA
// FlashAttention change: Now uses a single uint64_t[2] tensor on device containing both seed and offset
// Implementation: Renamed "seed" → "rng_state" (contains both seed+offset) and "offset" → "_unused"
```
~~In the nvFuser API, `SdpaFwdOp` now returns a `rng_state` and `SdpaBwdOp` expects a `rng_state`; `_unused` is ignored.~~
In the nvFuser API, depending on the torch version, `philox_seed` is now a `uint64_t[2]` tensor and `philox_offset` is an empty `uint64_t` tensor on device. I chose to keep the same semantics as PyTorch, since that avoids packing/unpacking PyTorch outputs during testing, compared with splitting PyTorch's `rng_state` output into `philox_seed` and `philox_offset`. The latter approach would keep the dimensions of both tensors the same, but since the `device` and `dtype` change, they would still need to be created separately for the two cases. Keeping PyTorch's semantics also makes it easier to switch to the single-parameter version in the future if desired.

Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
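For illustration, a sketch of what the packed RNG state looks like from the PyTorch side. This assumes a CUDA build with FlashAttention available; the output positions (indices 6 and 7) and the expected shapes are taken from the note above, not guaranteed by any public API:

```python
import torch

# Assumes a CUDA build with the FlashAttention backend available.
def rand_qkv():
    return torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)

q, k, v = rand_qkv(), rand_qkv(), rand_qkv()

# dropout_p > 0 so the kernel actually consumes RNG state.
outs = torch.ops.aten._scaled_dot_product_flash_attention(
    q, k, v, dropout_p=0.1, is_causal=False, return_debug_mask=False
)

# Per the note above: the former philox_seed output (index 6 here, an
# assumption about output ordering) now packs seed and offset into a single
# device tensor, and the former philox_offset output (index 7) is unused.
rng_state, unused = outs[6], outs[7]
print(rng_state.shape, rng_state.device)  # expected: torch.Size([2]) on cuda
print(unused.numel())                     # expected: 0 (empty placeholder)
```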