
[dtensor] refactor sharding prop to handle cross mesh computation#147869

Closed
wanchaol wants to merge 6 commits into main from mesh_check

Conversation

@wanchaol
Collaborator

@wanchaol wanchaol commented Feb 25, 2025

As titled, this PR moves the same-mesh check from the sharding propagation level to each individual operator level.

This allows more flexibility for each individual operator to check whether it can be run on the same mesh or not. For example, before this PR, if a user had two DTensor params that live on different DeviceMeshes and wanted to run a `for_each` operator on them individually, it would error out with a cross-mesh error. But for foreach computation there can be DTensors that live on different meshes, as long as the meshes are the same in a "zipped" way (i.e. the DTensors at the same position in each list argument share the same mesh).

This should also fix #134212


cc @H-Huang @awgu @kwen2501 @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o
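
To illustrate the "zipped" relaxation, here is a minimal sketch of what a per-operator mesh check for a foreach op could look like. This is not the actual DTensor implementation, and the helper name `_check_zipped_meshes` is invented for the example; it only requires that DTensors at the same list position share a mesh:

```python
# Hypothetical sketch only -- the real check lives in DTensor's operator-level
# sharding logic; `_check_zipped_meshes` is an invented name for illustration.
from torch.distributed.tensor import DTensor


def _check_zipped_meshes(*tensor_lists):
    """For a foreach op, the i-th elements of all list arguments take part in the
    same elementwise computation, so only meshes at the same ("zipped") position
    must match; different positions may live on different DeviceMeshes."""
    for zipped in zip(*tensor_lists):
        meshes = {t.device_mesh for t in zipped if isinstance(t, DTensor)}
        if len(meshes) > 1:
            raise RuntimeError(f"cross-mesh foreach computation across {meshes}")
```

With a check at this level, something like `torch._foreach_add_(params, grads)` can accept params/grads pairs that each live on their own mesh, as long as the two lists line up element by element.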

@wanchaol wanchaol requested review from awgu, tianyu-l and wz337 February 25, 2025 18:13
@pytorch-bot

pytorch-bot bot commented Feb 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147869

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit cb76160 with merge base 4995e05:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor oncall: distributed Add this issue/PR to distributed oncall triage queue labels Feb 25, 2025
@wanchaol wanchaol added the release notes: distributed (dtensor) release notes category label Feb 25, 2025
@wanchaol wanchaol added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 28, 2025
Contributor

@tianyu-l tianyu-l left a comment


Looks great. Left some comments.
I didn't check whether each particular op needs validation or not. Let me know if you need me to.

@janeyx99 janeyx99 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Mar 3, 2025
@wanchaol wanchaol requested a review from tianyu-l March 3, 2025 21:41
Contributor

@tianyu-l tianyu-l left a comment


lgtm!

@wanchaol
Collaborator Author

wanchaol commented Mar 4, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Jul 8, 2025
…iple meshes (#157682)

We are seeing more and more use cases where parameters in a model (under the same optimizer group) are put on different meshes. E.g.
- when FSDP and TP are both applied, some parameters are sharded only on the FSDP mesh but not TP mesh (see #153268).
- in [dp2ep Expert Parallel](pytorch/torchtitan#1324), the routed experts are sharded on the (global FSDP \ EP) mesh for smaller FSDP and on the EP mesh for EP, whereas other params are sharded on the global FSDP mesh for FSDP.

This PR is, in some sense, a continuation of #147869 to tackle the problem when fused optimizers are used. In such cases, [`fused_adam`](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml#L15786) / `fused_adamw` has a scalar tensor arg `state_steps` which gets automatically cast to DTensor on the default [`compute_mesh`](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_dispatch.py#L350) (one of the multiple meshes), even though it could correspond to different meshes.

To avoid hitting the cross-mesh propagation exception in `common_pointwise_strategy` and follow-up redistribute problems, we manually set the target mesh and placements to be the same as the input mesh and placements, so that no redistribute will be triggered. This also helps bypass the situation where [`generate_redistribute_costs`](https://github.com/pytorch/pytorch/pull/157682/files#diff-eea32a36dd2d4e58307bc5229402e48048b2ecaef64a7c085495fba1ee10ac89R597) returns infinite cost due to cross-mesh redistribute.

Moreover, this PR has minimal scope (restricted to the `fused_ops`) and doesn't need to modify other files such as `_sharding_prop.py`.
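
As a rough illustration of the approach described above (stand-in types only; the real logic uses `DTensorSpec` inside DTensor's fused-op sharding strategies), the target spec for each fused-op input simply mirrors that input's own mesh and placements, so no redistribute is planned and the infinite cross-mesh redistribute cost never comes up:

```python
# Stand-in sketch, not the actual PyTorch code: `Spec` approximates DTensorSpec
# as a (mesh, placements) pair, and `fused_op_target_specs` mirrors each input
# spec unchanged so no redistribute is required for any input.
from dataclasses import dataclass


@dataclass(frozen=True)
class Spec:
    mesh: str          # stand-in for a DeviceMesh
    placements: tuple  # stand-in for a tuple of Placement objects


def fused_op_target_specs(input_specs):
    # A redistribute is only generated when the chosen target spec differs from
    # the input spec; mirroring every input keeps each tensor on its own mesh.
    return list(input_specs)


params_spec = Spec("fsdp_mesh", ("Shard(0)",))
experts_spec = Spec("ep_mesh", ("Shard(0)",))
assert fused_op_target_specs([params_spec, experts_spec]) == [params_spec, experts_spec]
```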

Pull Request resolved: #157682
Approved by: https://github.com/wanchaol

Development

Successfully merging this pull request may close these issues.

[DSD] Test could fail in test_fsdp_dsd.py
