Add Flash Attention support to FlexAttention#161118

Closed
drisspg wants to merge 34 commits into gh/drisspg/187/base from gh/drisspg/187/head

Conversation

@drisspg

@drisspg drisspg commented Aug 21, 2025

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Aug 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161118

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit a092ae4 with merge base 086dec3 (image):

BROKEN TRUNK - The following job failed but was also present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

drisspg added a commit that referenced this pull request Aug 21, 2025
@drisspg drisspg marked this pull request as draft August 21, 2025 00:08
[ghstack-poisoned]
drisspg added a commit that referenced this pull request Aug 21, 2025
[ghstack-poisoned]
drisspg added a commit that referenced this pull request Aug 21, 2025
[ghstack-poisoned]
drisspg added a commit that referenced this pull request Aug 21, 2025
[ghstack-poisoned]
drisspg added a commit that referenced this pull request Aug 21, 2025
drisspg added a commit that referenced this pull request Aug 21, 2025
drisspg added a commit that referenced this pull request Aug 21, 2025
drisspg added a commit that referenced this pull request Aug 21, 2025
drisspg added a commit that referenced this pull request Aug 21, 2025
drisspg added a commit that referenced this pull request Aug 21, 2025
drisspg added a commit that referenced this pull request Aug 21, 2025
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
@drisspg drisspg requested review from Chillee and removed request for albanD, jbschlosser and mikaylagawarecki September 20, 2025 16:32
[ghstack-poisoned]
[ghstack-poisoned]
@drisspg drisspg marked this pull request as ready for review October 7, 2025 03:28
[ghstack-poisoned]
@drisspg drisspg requested a review from v0i0 October 8, 2025 17:47
[ghstack-poisoned]
[ghstack-poisoned]
@pytorchmergebot

Starting merge as part of PR stack under #162031

[ghstack-poisoned]
[ghstack-poisoned]
@pytorchmergebot

Starting merge as part of PR stack under #162031

pytorchmergebot pushed a commit that referenced this pull request Oct 10, 2025
## TODO
Check on multi indices
```Python

    @cute.jit
    def score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, buffers):
        in_ptr4 = buffers[0]
        tmp0 = tSrS_ssa
        tmp1 = b_idx
        tmp2 = h_idx
        tmp3 = cute.make_fragment(1, cutlass.Int32)
        tmp4 = tmp3.store(32*tmp1 + tmp2)
        tmp5 = cute.make_fragment(1, cutlass.BFloat16)
        tmp6 = tmp3[0]
        tmp7 = tmp5[0] = (in_ptr4[tmp6])
        tmp8 = (tmp5.load()).to(cutlass.Float32)
        tmp9 = (tmp0 + tmp8)
        tSrS_ssa = tmp9

        return tSrS_ssa

 ```

I don't think that
```
        tmp4 = tmp3.store(32*tmp1 + tmp2)
        tmp5 = cute.make_fragment(1, cutlass.BFloat16)
        tmp6 = tmp3[0]
        tmp7 = tmp5[0] = (in_ptr4[tmp6])
```

is right, since this `tmp6` value will be larger than the actual index dim (in this case, B) -> see if it's possible to use a 1D index.
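
To make the concern concrete, here is a minimal pure-Python sketch, not the generated CuTe code. It assumes the buffer is a `(B, H)` table with `H = 32` (inferred from the `32*tmp1 + tmp2` term); `B = 4` and the helper names `flat_offset`, `load_flat`, and `load_leading_dim` are made up for illustration. The row-major offset is only valid against a flattened 1D view of length `B*H`; applied to the leading dimension of size B it runs out of bounds.

```python
# Hypothetical shapes: H = 32 is inferred from `32*tmp1 + tmp2`; B = 4 is made up.
B, H = 4, 32

# A (B, H) bias table stored row-major, plus its flattened 1-D view.
bias_2d = [[float(b * H + h) for h in range(H)] for b in range(B)]
bias_flat = [value for row in bias_2d for value in row]


def flat_offset(b_idx, h_idx, num_heads=H):
    """Row-major offset of element (b_idx, h_idx), mirroring `32*tmp1 + tmp2`."""
    return num_heads * b_idx + h_idx


def load_flat(b_idx, h_idx):
    """Valid: the offset indexes a 1-D view of length B*H."""
    return bias_flat[flat_offset(b_idx, h_idx)]


def load_leading_dim(b_idx, h_idx):
    """Suspected failure mode: the same offset applied to the leading dim (size B)."""
    return bias_2d[flat_offset(b_idx, h_idx)]


offset = flat_offset(B - 1, H - 1)
print(offset, len(bias_2d), len(bias_flat))  # 127 4 128
```

With these shapes, `load_flat(3, 31)` reads the last element, while `load_leading_dim(3, 31)` raises `IndexError` because 127 exceeds B. Whether the generated kernel hits this depends on how the buffer tensor is laid out when it is passed in, which the snippet above does not show.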

Pull Request resolved: #162031
Approved by: https://github.com/v0i0
ghstack dependencies: #161118
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
@github-actions github-actions bot deleted the gh/drisspg/187/head branch November 9, 2025 02:18