
fix dtensor and tensor inconsistent compute mesh #153268

Closed

zhe-inflection wants to merge 1 commit into pytorch:main from zhe-inflection:fix_dtensor_tensor_inconsistent_compute_mesh

Conversation


@zhe-inflection zhe-inflection commented May 9, 2025

Issue:
In the torch/distributed/tensor/_dispatch.py dispatch function, DTensor and Tensor args can end up with different compute meshes in some scenarios, which results in a crash like the following:

[rank8]: ValueError: Could not run pointwise computation across different mesh: Found DeviceMesh('cuda', [0, 8, 16, 24, 32, 40], mesh_dim_names=('dp_shard',)) and DeviceMesh('cuda', [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15], [16, 17, 18, 19, 20, 21, 22, 23], [24, 25, 26, 27, 28, 29, 30, 31], [32, 33, 34, 35, 36, 37, 38, 39], [40, 41, 42, 43, 44, 45, 46, 47]], mesh_dim_names=('dp_shard', 'tp'))!

More detail on the root cause:
When using the fused_adam version of the Adam optimizer, in the _dispatch.py dispatch function the args are {weights, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, step}, 6 groups in total.

If we finetune Llama 4 Scout with only one transformer block, there are 16 params in weights, with the following shapes and compute meshes (data parallel = 4, tensor parallel = 2, 1 node with 8 GPUs):

torch.Size([202048, 5120]), DeviceMesh('cuda', [0, 2, 4, 6], mesh_dim_names=('dp_shard',))
torch.Size([5120, 5120]), DeviceMesh('cuda', [[0, 1], [2, 3], [4, 5], [6, 7]], mesh_dim_names=('dp_shard', 'tp'))
torch.Size([1024, 5120]), DeviceMesh('cuda', [[0, 1], [2, 3], [4, 5], [6, 7]], mesh_dim_names=('dp_shard', 'tp'))
torch.Size([1024, 5120]), DeviceMesh('cuda', [[0, 1], [2, 3], [4, 5], [6, 7]], mesh_dim_names=('dp_shard', 'tp'))
torch.Size([5120, 5120]), DeviceMesh('cuda', [[0, 1], [2, 3], [4, 5], [6, 7]], mesh_dim_names=('dp_shard', 'tp'))
torch.Size([16, 5120, 8192]), DeviceMesh('cuda', [[0, 1], [2, 3], [4, 5], [6, 7]], mesh_dim_names=('dp_shard', 'tp'))
torch.Size([16, 8192, 5120]), DeviceMesh('cuda', [[0, 1], [2, 3], [4, 5], [6, 7]], mesh_dim_names=('dp_shard', 'tp'))
torch.Size([16, 5120, 8192]), DeviceMesh('cuda', [[0, 1], [2, 3], [4, 5], [6, 7]], mesh_dim_names=('dp_shard', 'tp'))
torch.Size([16, 5120]), DeviceMesh('cuda', [[0, 1], [2, 3], [4, 5], [6, 7]], mesh_dim_names=('dp_shard', 'tp'))
torch.Size([8192, 5120]), DeviceMesh('cuda', [0, 2, 4, 6], mesh_dim_names=('dp_shard',))
torch.Size([5120, 8192]), DeviceMesh('cuda', [0, 2, 4, 6], mesh_dim_names=('dp_shard',))
torch.Size([8192, 5120]), DeviceMesh('cuda', [0, 2, 4, 6], mesh_dim_names=('dp_shard',))
torch.Size([5120]), DeviceMesh('cuda', [[0, 1], [2, 3], [4, 5], [6, 7]], mesh_dim_names=('dp_shard', 'tp'))
torch.Size([5120]), DeviceMesh('cuda', [[0, 1], [2, 3], [4, 5], [6, 7]], mesh_dim_names=('dp_shard', 'tp'))
torch.Size([5120]), DeviceMesh('cuda', [[0, 1], [2, 3], [4, 5], [6, 7]], mesh_dim_names=('dp_shard', 'tp'))
torch.Size([202048, 5120]), DeviceMesh('cuda', [[0, 1], [2, 3], [4, 5], [6, 7]], mesh_dim_names=('dp_shard', 'tp'))

Inside the _dispatch() function, unwrap_to_op_info() tries to assign the proper compute_mesh to each plain tensor (step) from the DTensors (weights). However, all 16 tensors in the step group end up with the same compute_mesh, DeviceMesh('cuda', [0, 2, 4, 6], mesh_dim_names=('dp_shard',)), taken from the very first DTensor. This is the root cause.

In the new version, we resolve this issue by:

  1. recording the compute_mesh of every DTensor in the first group;
  2. assigning the proper compute_mesh to each plain tensor.
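To make the mis-assignment and the fix concrete, here is a schematic, self-contained sketch in plain Python. It uses simplified stand-in classes, not the actual PyTorch dispatch code; `DTensorArg`, `TensorArg`, `assign_mesh_first`, and `assign_mesh_per_arg` are all hypothetical names.

```python
# Simplified stand-ins illustrating the bug and the fix described above.
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class DTensorArg:
    name: str
    mesh: Tuple[str, ...]  # stand-in for a DeviceMesh's dim names

@dataclass
class TensorArg:
    name: str
    mesh: Optional[Tuple[str, ...]] = None  # filled in at dispatch time

def assign_mesh_first(dtensors, tensors):
    """Old behavior: every plain tensor inherits the FIRST DTensor's mesh."""
    for t in tensors:
        t.mesh = dtensors[0].mesh

def assign_mesh_per_arg(dtensors, tensors):
    """Fixed behavior: record each DTensor's mesh, assign positionally."""
    for t, d in zip(tensors, dtensors):
        t.mesh = d.mesh

weights = [
    DTensorArg("w0", ("dp_shard",)),       # FSDP-only param, 1-D mesh
    DTensorArg("w1", ("dp_shard", "tp")),  # FSDP+TP param, 2-D mesh
]

steps = [TensorArg("step0"), TensorArg("step1")]
assign_mesh_first(weights, steps)
print(steps[1].mesh == weights[1].mesh)  # False: the crash scenario

steps = [TensorArg("step0"), TensorArg("step1")]
assign_mesh_per_arg(weights, steps)
print(steps[1].mesh == weights[1].mesh)  # True: meshes now match
```

With the positional assignment, each step tensor carries the same mesh as its corresponding weight, so the pointwise same-mesh check no longer fires.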

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim


pytorch-bot bot commented May 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153268

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label May 9, 2025

linux-foundation-easycla bot commented May 9, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: zhe-inflection (f34a894)

@colesbury colesbury requested a review from kwen2501 May 13, 2025 00:48
@colesbury colesbury added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label May 13, 2025
@zhe-inflection
Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label May 13, 2025
@zhe-inflection
Author

Hi @colesbury, I have signed the CLA, but it is still showing as not signed and blocking the PR. Could you please check on your side? Thanks a lot.

@colesbury
Member

@zhe-inflection - you'll have to submit an Easy CLA support ticket (see the link in #153268 (comment)). I don't have any control over or access to the Easy CLA bot; I think it's managed by the Linux Foundation.

@zhe-inflection
Author

Hi @kwen2501 any feedback on this PR? Thanks a lot!

@kwen2501
Collaborator

I am not sure what the program looks like. Is it trying to run a computation between a TP-sharded tensor and a non-TP-sharded tensor?
@XilunWu

@kwen2501 kwen2501 requested review from XilunWu, wanchaol and wz337 May 22, 2025 20:50
@kwen2501
Collaborator

@wanchaol @wz337 do you mind taking a look?

@kwen2501
Collaborator

kwen2501 commented May 22, 2025

In the new version, we resolve this issue by
recording the compute_mesh for all Dtensor in the first group.
assigning proper compute_mesh for tensor.

I don't know what a "group" would mean in general cases.
I am also not sure if a DTensor can be "re-assigned" to a different mesh.

Perhaps you can teach us more by commenting the code change?

@XilunWu
Contributor

XilunWu commented May 22, 2025

I think the issue happens when calling a pointwise op on 6 args (as OP said):
{weights, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, step}

Since step is a scalar value, it will be cast to a Replicate DTensor before being passed to that op's sharding propagation, at: https://github.com/pytorch/pytorch/blob/6dee4820d010b2af9de1f8c4fe71a89862624ecd/torch/distributed/tensor/_dispatch.py#L361-L365

In sharding propagation, it requires all DTensor args have the same device_mesh at: https://github.com/pytorch/pytorch/blob/6dee4820d010b2af9de1f8c4fe71a89862624ecd/torch/distributed/tensor/_ops/_pointwise_ops.py#L495

Here the followed_strategy corresponds to the DTensor with the most shards (i.e., here the 2-D sharded ones). In this case its mesh differs from compute_mesh, which is the mesh of the first arg in arg_list, so changing the assignment to compute_mesh in this PR (https://github.com/pytorch/pytorch/blob/6dee4820d010b2af9de1f8c4fe71a89862624ecd/torch/distributed/tensor/_dispatch.py#L360) somehow resolves this case.
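The same-mesh requirement described above can be sketched roughly like this (a simplified stand-in, not the actual _pointwise_ops.py code; `require_single_mesh` is a hypothetical name and meshes are plain tuples):

```python
# Simplified version of the same-mesh requirement in pointwise sharding
# propagation: if any arg's mesh differs from the first one, raise.
def require_single_mesh(arg_meshes):
    first = arg_meshes[0]
    for m in arg_meshes[1:]:
        if m != first:
            raise ValueError(
                "Could not run pointwise computation across different mesh: "
                f"Found {first} and {m}!"
            )
    return first

# All args on one mesh: fine.
require_single_mesh([("dp_shard", "tp"), ("dp_shard", "tp")])

# A Replicate-cast scalar placed on the wrong mesh trips the check:
try:
    require_single_mesh([("dp_shard",), ("dp_shard", "tp")])
except ValueError as e:
    print("raised:", e)
```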

I feel the solution here is too hacky. We can discuss how to come up with a better one. cc @kwen2501 @wanchaol @wz337 @tianyu-l

@zhe-inflection
Author

Hi @kwen2501, @XilunWu is right. In the AdamW algorithm, each of {weights, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, step} is a group. I'd definitely like a better solution. The challenging part is that at the dispatch level we don't know the arg structure in general, yet we still have to assign the proper compute_mesh to each plain tensor (step in the above case). Thanks.

@tianyu-l
Contributor

@zhe-inflection @XilunWu
I agree that requiring the mesh from different parameters to be exactly the same for the fused optimizer is a constraint we should consider lifting.

Meanwhile, I have the following workaround for FSDP vs. FSDP+TP mismatch.
https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/llama4/infra/parallelize_llama.py#L169
https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/llama4/infra/expert_parallel.py#L98

The idea is to put everything on the 2D mesh so that Adam won't complain, which incurs some additional DTensor overhead but should be OK since most of your parameters are already 2D DTensors.
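Abstractly, the workaround can be sketched with simplified placement tuples (this is not the torchtitan code; `lift_to_2d` is a hypothetical helper):

```python
# Sketch of the workaround: a param sharded only on the 1-D dp_shard mesh
# is re-expressed on the full 2-D (dp_shard, tp) mesh by appending a
# Replicate placement for the tp dim. Every param then shares one mesh,
# so fused Adam's same-mesh check passes.
def lift_to_2d(placements_1d):
    return tuple(placements_1d) + ("Replicate",)

params = {
    "w0": ("Shard(0)",),             # was 1-D: FSDP-only
    "w1": ("Shard(0)", "Shard(1)"),  # already 2-D: FSDP + TP
}
lifted = {
    name: (lift_to_2d(p) if len(p) == 1 else p) for name, p in params.items()
}
print(lifted["w0"])  # ('Shard(0)', 'Replicate'): sharded on dp, replicated on tp
```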

Collaborator

@wanchaol wanchaol left a comment


I don't think we should do it this way; it's too hacky and might not work for every case.

The underlying problem is that the fused_adam op has a group of steps, each of which corresponds to a different device mesh. I think this kind of problem should be solved at the fused adam op level, not the general dispatch level. @zhe-inflection lmk if you want to work on a correct fix (we can discuss more on Slack); otherwise I'll try to find some time to fix this in the next couple of days.

@zhe-inflection
Author

@wanchaol @tianyu-l thank you both for replying. For the Adam case, I agree this kind of problem should be solved at the fused adam op level, not the general dispatch level. But for other cases (not sure what those might be), each tensor's compute_mesh must exactly match the corresponding DTensor's compute mesh. More concretely, for step: {step_1, step_2, ..., step_16} and weights: {weights_1, weights_2, ..., weights_16}, after conversion step_1's compute_mesh should be the same as weights_1's, step_2's the same as weights_2's, and so on.

Will be happy to work on a correct fix. Let's chat more on slack.

@tianyu-l
Contributor

tianyu-l commented Jul 7, 2025

proposed another solution in #157682

pytorchmergebot pushed a commit that referenced this pull request Jul 8, 2025
…iple meshes (#157682)

We are seeing more and more use cases where parameters in a model (under the same optimizer group) are put on different meshes. E.g.
- when FSDP and TP are both applied, some parameters are sharded only on the FSDP mesh but not TP mesh (see #153268).
- in [dp2ep Expert Parallel](pytorch/torchtitan#1324), the routed experts are sharded on the (global FSDP \ EP) mesh for smaller FSDP and on the EP mesh for EP, whereas other params are sharded on the global FSDP mesh for FSDP.

This PR is, in some sense, a continuation of #147869 to tackle the problem when fused optimizers are used. In such cases, the [`fused_adam`](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml#L15786) / `fused_adamw` op has a scalar tensor arg `state_steps` which gets automatically cast to DTensor on the default [`compute_mesh`](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_dispatch.py#L350) (one of the multiple meshes), even though it could correspond to different meshes.

To avoid hitting the cross-mesh propagation exception in `common_pointwise_strategy` and followup redistribute problems, we manually set the target mesh and placements to be the same as input mesh and placements, so that no redistribute will be triggered. This also helps bypass the situation where [`generate_redistribute_costs`](https://github.com/pytorch/pytorch/pull/157682/files#diff-eea32a36dd2d4e58307bc5229402e48048b2ecaef64a7c085495fba1ee10ac89R597) returns infinite cost due to cross mesh redistribute.

Moreover, this PR has minimal scope (restricted to the `fused_ops`) and doesn't need to modify other files such as `_sharding_prop.py`.

Pull Request resolved: #157682
Approved by: https://github.com/wanchaol
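The approach in that commit can be sketched schematically (simplified spec tuples, not the actual sharding-propagation code; `redistribute_cost` here just mirrors the infinite-cost behavior for cross-mesh redistributes):

```python
# Sketch of the #157682 idea: for fused-op args like state_steps, set the
# target spec equal to the input spec (same mesh, same placements), so no
# redistribute is generated and no infinite cross-mesh cost appears.
def redistribute_cost(src, dst):
    src_mesh, src_placements = src
    dst_mesh, dst_placements = dst
    if src_mesh != dst_mesh:
        return float("inf")  # cross-mesh redistribute is unsupported
    return 0 if src_placements == dst_placements else 1

input_specs = [
    (("dp_shard",), ("Replicate",)),                   # step for an FSDP-only param
    (("dp_shard", "tp"), ("Replicate", "Replicate")),  # step for a 2-D param
]
target_specs = list(input_specs)  # target == input, per arg: zero cost
print(all(redistribute_cost(s, t) == 0
          for s, t in zip(input_specs, target_specs)))  # True
```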
@github-actions
Contributor

github-actions bot commented Sep 5, 2025

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Sep 5, 2025
@ezyang ezyang closed this Sep 6, 2025