
[dtensor][fix] fix _scaled_dot_product_flash_attention sharding#148125

Closed
XilunWu wants to merge 2 commits intogh/XilunWu/120/basefrom
gh/XilunWu/120/head

Conversation

Contributor

@XilunWu XilunWu commented Feb 27, 2025

Stack from ghstack (oldest at bottom):

Summary

#146372 changed the op signature of _scaled_dot_product_flash_attention; as a consequence, DTensor needs to update the sharding strategy defined in

def scaled_dot_product_flash_attention_strategy(
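The fix amounts to keeping the strategy's output placements in sync with the op's changed signature. A minimal, self-contained sketch of the idea, using hypothetical schemas and toy placement classes rather than the actual PyTorch DTensor internals:

```python
# Illustrative sketch only (not the PyTorch implementation): a DTensor-style
# op strategy emits one placement per op output, so when an op's signature
# gains or loses outputs, the strategy must be updated to match the new
# schema. All names below (OLD_SCHEMA, NEW_SCHEMA, field names) are made up.
from dataclasses import dataclass

@dataclass(frozen=True)
class Shard:
    dim: int  # tensor dimension this output is sharded on

@dataclass(frozen=True)
class Replicate:
    pass  # output is replicated across ranks

# Hypothetical old output schema of the op.
OLD_SCHEMA = ["output", "logsumexp", "rng_state"]
# Hypothetical new schema after the signature change: extra outputs appear.
NEW_SCHEMA = ["output", "logsumexp", "cum_seq_q", "cum_seq_k", "rng_state"]

def flash_attention_strategy(schema):
    """Return one placement per output: large tensor outputs shard on the
    batch dim (0); bookkeeping outputs are replicated."""
    placements = []
    for name in schema:
        if name in ("output", "logsumexp"):
            placements.append(Shard(0))
        else:
            placements.append(Replicate())
    return placements

# The failure mode this kind of fix addresses: a strategy written against
# OLD_SCHEMA yields the wrong number of placements for NEW_SCHEMA.
assert len(flash_attention_strategy(OLD_SCHEMA)) == len(OLD_SCHEMA)
assert len(flash_attention_strategy(NEW_SCHEMA)) == len(NEW_SCHEMA)
```

Because the strategy is derived from the schema rather than hard-coded to a fixed arity, a signature change updates the placement list automatically.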

Test

pytest test/distributed/tensor/test_attention.py

Follow-up

It's still unclear why the CP unit tests were not run on the original PR, which is BC-breaking.

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @tianyu-l

@pytorch-bot

pytorch-bot bot commented Feb 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148125

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5a8391c with merge base 2978771:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor oncall: distributed Add this issue/PR to distributed oncall triage queue labels Feb 27, 2025
@XilunWu XilunWu marked this pull request as draft February 28, 2025 00:00
cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
XilunWu added a commit that referenced this pull request Feb 28, 2025
@XilunWu XilunWu added better-engineering Relatively self-contained tasks for better engineering contributors module: dtensor distributed tensor tag module: context parallel PyTorch Context Parallel labels Feb 28, 2025
@XilunWu XilunWu changed the title [dtensor] fix scaled dot product flash attention sharding [dtensor][fix] fix _scaled_dot_product_flash_attention sharding Feb 28, 2025
@XilunWu XilunWu marked this pull request as ready for review February 28, 2025 00:46
Contributor

@tianyu-l tianyu-l left a comment


lgtm

@XilunWu XilunWu added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 28, 2025
@XilunWu
Contributor Author

XilunWu commented Feb 28, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


Contributor

@fegin fegin left a comment


Thanks for the fix!

fegin pushed a commit to pytorch/torchtitan that referenced this pull request Mar 3, 2025
…as been fixed (#912)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* __->__ #912

### Summary
This PR undoes #898 and
re-enables CP tests in CI, as
pytorch/pytorch#148125 fixed the DTensor sdp
flash attention op.

### Test
CI
fegin added a commit to pytorch/torchtitan that referenced this pull request Mar 3, 2025
@XilunWu XilunWu mentioned this pull request Mar 3, 2025
@github-actions github-actions bot deleted the gh/XilunWu/120/head branch March 31, 2025 02:14
MaxiBoether pushed a commit to eth-easl/torchtitan-mixtera that referenced this pull request Apr 17, 2025
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 13, 2026
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 25, 2026

Labels

better-engineering (Relatively self-contained tasks for better engineering contributors)
ciflow/inductor
ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
module: context parallel (PyTorch Context Parallel)
module: dtensor (distributed tensor tag)
oncall: distributed (Add this issue/PR to distributed oncall triage queue)
topic: not user facing (topic category)
