[FlexAttention] Enable different qk and v head-dims#134043

Closed
drisspg wants to merge 14 commits into gh/drisspg/36/base from gh/drisspg/36/head

Conversation

@drisspg
Contributor

@drisspg drisspg commented Aug 20, 2024

Stack from ghstack (oldest at bottom):

Summary

Adds the option for the head dims to be different between QK and V tensors.

Fixes issue: #133674

V_DIM > QK_DIM is blocked on landing triton-lang/triton#4138 / triton-lang/triton#4540 into PyTorch's Triton branch.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang
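The shape contract behind this change can be illustrated with a small NumPy reference sketch (hypothetical sizes; this does not call flex_attention itself): the attention scores depend only on the QK head dim, while the output inherits the V head dim, which is why the two are free to differ.

```python
import numpy as np

def sdpa_ref(q, k, v):
    """Reference scaled-dot-product attention with decoupled head dims.

    q, k: (seq, d_qk); v: (seq, d_v) -- d_qk and d_v may differ.
    """
    scores = q @ k.T / np.sqrt(q.shape[-1])        # (seq, seq): uses d_qk only
    probs = np.exp(scores - scores.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)          # softmax over keys
    return probs @ v                               # (seq, d_v): output follows V

# Hypothetical sizes chosen for illustration.
seq, d_qk, d_v = 8, 192, 128
rng = np.random.default_rng(0)
q = rng.standard_normal((seq, d_qk))
k = rng.standard_normal((seq, d_qk))
v = rng.standard_normal((seq, d_v))
out = sdpa_ref(q, k, v)
print(out.shape)  # (8, 128)
```

The forward pass is shape-agnostic in exactly this way; the backward pass couples dQ/dK to d_qk and dV to d_v, which is where the Triton limitation noted above bites.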

@pytorch-bot

pytorch-bot bot commented Aug 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134043

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 067c895 with merge base 5f3d22a (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

# Summary
Adds the option for the head dims to be different between QK and V tensors.

Local testing shows this works well for the forward pass when QK_HEAD_DIM > V_HEAD_DIM, but not yet when V_HEAD_DIM > QK_HEAD_DIM; still debugging.

[ghstack-poisoned]
# Summary
Adds the option for the head dims to be different between QK and V tensors.

Fixes issue: #133674

[ghstack-poisoned]

q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]

Q_block_ptr = tl.make_block_ptr(
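For context, the `q_range` expression above builds a 3-D grid of element offsets by broadcasting per-axis index vectors against their strides (group, row, and head-dim axes). The same arithmetic in NumPy, with hypothetical tile sizes and strides:

```python
import numpy as np

# Hypothetical tile sizes and strides, for illustration only.
G, BLOCK_M, BLOCK_D = 2, 4, 8
stride_qg, stride_qm, stride_qk = 1024, 64, 1

off_g = np.arange(G)        # group indices
off_m = np.arange(BLOCK_M)  # query-row indices
offs_d = np.arange(BLOCK_D) # head-dim indices

# Mirrors the Triton expression: each axis is lifted into its own
# broadcast dimension, then scaled by that axis's stride and summed.
q_range = (stride_qg * off_g[:, None, None]
           + stride_qm * off_m[None, :, None]
           + stride_qk * offs_d[None, None, :])
print(q_range.shape)  # (2, 4, 8)
print(q_range[1, 2, 3])  # 1*1024 + 2*64 + 3*1 = 1155
```

Each entry is the flat offset of element (g, m, d) in the Q tensor, which is why the explicit `tl.make_block_ptr` construction below was redundant once this offset grid was used directly.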
Contributor Author


wasn't being used

@drisspg drisspg requested review from Chillee and yanboliang and removed request for albanD, jbschlosser and mikaylagawarecki August 20, 2024 23:21
@drisspg
Contributor Author

drisspg commented Aug 20, 2024

@pytorchbot merge

@yanboliang
Contributor

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@yanboliang
Contributor

@pytorchbot merge -f "stucked ROCM jobs, flex attention unit tests only on CUDA"

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@yanboliang
Contributor

@pytorchbot merge -f "stucked ROCM jobs, flex attention unit tests only on CUDA"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@jeanschmidt
Contributor

@pytorchbot revert -m "Need to revert, in order to be able to revert #133373, feel free to reland this after solving conflicts" -c ghfirst

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Aug 22, 2024
This reverts commit e847b6b.

Reverted #134043 on behalf of https://github.com/jeanschmidt due to Need to revert, in order to be able to revert #133373, feel free to reland this after solving conflicts ([comment](#134043 (comment)))
@pytorchmergebot
Collaborator

@drisspg your PR has been successfully reverted.

@yanboliang
Contributor

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pytorchmergebot pushed a commit to mayank31398/pytorch that referenced this pull request Aug 23, 2024
# Summary
Adds the option for the head dims to be different between QK and V tensors.

Fixes issue: pytorch#133674

V_DIM > QK_DIM is blocked on landing triton-lang/triton#4138 / triton-lang/triton#4540 into PyTorch's Triton branch.

Pull Request resolved: pytorch#134043
Approved by: https://github.com/Chillee
pytorchmergebot added a commit to mayank31398/pytorch that referenced this pull request Aug 23, 2024
…134043)"

This reverts commit e847b6b.

Reverted pytorch#134043 on behalf of https://github.com/jeanschmidt due to Need to revert, in order to be able to revert pytorch#133373, feel free to reland this after solving conflicts ([comment](pytorch#134043 (comment)))
@github-actions github-actions bot deleted the gh/drisspg/36/head branch October 1, 2024 02:13

6 participants