
[Submodule] Turning flash-attention integration into 3rd party submod#144120

Closed
drisspg wants to merge 22 commits into gh/drisspg/111/base from gh/drisspg/111/head

Conversation

@drisspg
Contributor

@drisspg drisspg commented Jan 3, 2025

Stack from ghstack (oldest at bottom):

Summary

Sticky points

CUDA graph RNG handling has changed and deviated from the original implementation. We will be left with a dangling 'offset' value and confusing naming due to backward compatibility (BC).

Dependencies

Other Points

  • The BC linter is complaining about losing generate.py and its functions, which is not a real BC surface
    cc @albanD

Differential Revision: D68502879

[ghstack-poisoned]
@drisspg drisspg mentioned this pull request Jan 3, 2025
@pytorch-bot

pytorch-bot bot commented Jan 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/144120

Note: Links to docs will display an error until the docs builds have been completed.

❌ 14 New Failures, 8 Unrelated Failures

As of commit 3ab6395 with merge base 40e27fb:

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@drisspg drisspg marked this pull request as draft January 3, 2025 00:40
[ghstack-poisoned]
drisspg added a commit that referenced this pull request Jan 3, 2025
ghstack-source-id: 0494ca4
Pull Request resolved: #144120
@drisspg drisspg added the topic: not user facing (topic category) and module: sdpa (All things related to torch.nn.functional.scaled_dot_product_attention) labels Jan 3, 2025
@drisspg drisspg changed the title Trying to reduce flash-deps [Submodule] Turning flash-attention integration into 3rd party submod Jan 7, 2025
[ghstack-poisoned]
drisspg added a commit that referenced this pull request Jan 7, 2025
ghstack-source-id: ab6ce91
Pull Request resolved: #144120
[ghstack-poisoned]
drisspg added a commit that referenced this pull request Jan 7, 2025
ghstack-source-id: 4d9655d
Pull Request resolved: #144120
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
drisspg added a commit that referenced this pull request Jan 22, 2025
ghstack-source-id: 393f416
Pull Request resolved: #144120
@drisspg
Contributor Author

drisspg commented Jan 22, 2025

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

pytorchmergebot pushed a commit that referenced this pull request Jan 24, 2025
…145502)

# Context

Prototyped in #144120: we are going to make flash-attention a third-party submodule, taking its C++ sources and including them in our build of libtorch.so.

This requires various changes, both external and internal. Because internal changes are needed, we have to co-develop, and in the co-dev environment I haven't found a way to sync submodule changes together with internal-only changes.

This is unused for now

Pull Request resolved: #145502
Approved by: https://github.com/Skylion007
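For background, the general mechanics of wiring a project in as a third-party submodule look roughly like the sketch below. This uses a local stand-in repository instead of the real flash-attention remote, and the paths are illustrative, not the actual ones used in this PR:

```shell
# Sketch only: simulate adding a third-party submodule under third_party/,
# using a local stand-in repo in place of the real flash-attention remote.
demo=$(mktemp -d)

# Stand-in for the upstream flash-attention repository.
git init -q "$demo/flash-attention-upstream"
git -C "$demo/flash-attention-upstream" \
    -c user.email=demo@example.com -c user.name=demo \
    commit -q --allow-empty -m "initial commit"

# Stand-in for the superproject (e.g. pytorch).
git init -q "$demo/superproject"
cd "$demo/superproject"

# Newer git versions require explicitly allowing the file:// transport
# when cloning submodules from local paths.
git -c protocol.file.allow=always \
    submodule add "$demo/flash-attention-upstream" third_party/flash-attention

# Shows the pinned submodule commit and path.
git submodule status
```

From there, consumers build the submodule's sources as part of the superproject's own build, and updating the dependency means committing a new pinned submodule SHA.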
```cmake
# These features are disabled by default. We don't currently document this
# because we don't suspect users building from source will need it.
add_definitions(-DFLASHATTENTION_DISABLE_ALIBI)
add_definitions(-DFLASHATTENTION_DISABLE_SOFTCAP)
```
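If source builds ever do need these features, one way to expose them would be to gate the definitions behind CMake options. This is a hypothetical sketch; the option names here are not from this PR:

```cmake
# Hypothetical: gate the disabled flash-attention features behind options
# so users building from source could re-enable them at configure time.
option(FLASH_ATTENTION_ENABLE_ALIBI "Build flash-attention with ALiBi support" OFF)
option(FLASH_ATTENTION_ENABLE_SOFTCAP "Build flash-attention with softcap support" OFF)

if(NOT FLASH_ATTENTION_ENABLE_ALIBI)
  add_definitions(-DFLASHATTENTION_DISABLE_ALIBI)
endif()
if(NOT FLASH_ATTENTION_ENABLE_SOFTCAP)
  add_definitions(-DFLASHATTENTION_DISABLE_SOFTCAP)
endif()
```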
Contributor Author


move these here:

```cmake
target_compile_definitions(torch_cuda PRIVATE USE_MEM_EFF_ATTENTION)
```

pytorch-bot bot pushed a commit that referenced this pull request Feb 4, 2025
…#144120)

Summary:
Pull Request resolved: #144120

# Summary

### Sticky points

CUDA graph RNG handling has changed and deviated from the original implementation. We will be left with a dangling 'offset' value and confusing naming due to backward compatibility (BC).

## Dependencies
- Flash PR: Dao-AILab/flash-attention#1419

### Other Points
- The BC linter is complaining about losing generate.py and its functions, which is not a real BC surface
cc albanD

imported-using-ghimport

Test Plan:
Imported from OSS

Building in dev
`buck build @//mode/dev-nosan -c fbcode.nvcc_arch=h100a  //caffe2:ATen-cu --show-full-output    `

Running `nm` on the .so, I see that the flash symbols are correctly named:
```
0000000001c3dfb0 t pytorch_flash::run_mha_bwd(pytorch_flash::Flash_bwd_params&, CUstream_st*)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
0000000001c36080 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()#6}::operator()() const
0000000001c360e0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
0000000001c35fc0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#6}::operator()() const
0000000001c36020 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
```

Reviewed By: vkuzo

Differential Revision: D68502879

Pulled By: drisspg
drisspg added a commit to drisspg/pytorch that referenced this pull request Feb 5, 2025
…pytorch#146372)

Summary:
Pull Request resolved: pytorch#146372

Pull Request resolved: pytorch#144120


Labels

ci-no-td (Do not run TD on this PR), ciflow/inductor, ciflow/rocm (Trigger "default" config CI on ROCm), ciflow/trunk (Trigger trunk jobs on your pull request), module: sdpa (All things related to torch.nn.functional.scaled_dot_product_attention), skip-pr-sanity-checks, suppress-bc-linter (Suppresses the failures of API backward-compatibility linter (Lint/bc_linter)), topic: not user facing (topic category)


4 participants