fix dtensor and tensor inconsistent compute mesh#153268
fix dtensor and tensor inconsistent compute mesh#153268zhe-inflection wants to merge 1 commit intopytorch:mainfrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153268
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
|
|
@pytorchbot label "topic: not user facing" |
|
Hi @colesbury, I have signed CLA authorization, it is still showing not signed CLA and blocking, Could you please check on your side? Thanks a lot. |
|
@zhe-inflection - you'll have submit an Easy CLA support ticket (see the link in #153268 (comment)). I don't have any control or access over the Easy CLA bot. I think that's managed by the Linux Foundation. |
|
Hi @kwen2501 any feedback on this PR? Thanks a lot! |
|
I am not sure how the program looks like. Is it trying to have a computation between a TP-sharded tensor with a non-TP-sharded tensor? |
I don't know what a "group" would mean in general cases. Perhaps you can teach us more by commenting the code change? |
|
I think the issue happens when calling a pointwise op on 6 args (as OP said): Since In sharding propagation, it requires all DTensor args have the same device_mesh at: https://github.com/pytorch/pytorch/blob/6dee4820d010b2af9de1f8c4fe71a89862624ecd/torch/distributed/tensor/_ops/_pointwise_ops.py#L495 Here the I feel the solution here is too hacky. We can discuss how to come up with a better one. cc @kwen2501 @wanchaol @wz337 @tianyu-l |
|
Hi @kwen2501 , @XilunWu was right. {weights, grads, exp_avgs, exp_avgs_sls, max_exp_avg_sqs, step} in Adamw algorithm, each one is a group. Definitely 'd like to have a better solution. The challenge part is that not sure args and have to assign proper compute_mesh to tensor (step in the above case) properly. Thanks. |
|
@zhe-inflection @XilunWu Meanwhile, I have the following workaround for FSDP vs. FSDP+TP mismatch. The idea is to put everything on the 2D mesh so that Adam won't complain, which incurs some additional DTensor overhead but should be OK since most of your parameters are already 2D DTensors. |
wanchaol
left a comment
There was a problem hiding this comment.
I don't think we should do it in this way, it's too hacky and might not work for every case.
The underlying problem is that fused_adam op have a group of steps that each of the them corresponds to a different device mesh. I think these kind of problems should be solved in the fused adam op level, not the general dispatch level. @zhe-inflection lmk if you want to work on a correct fix (can discuss more on slack), otherwise I'll try to find sometime to fix this in the next couple of days
|
@wanchaol @tianyu-l thank you guys for replying. For Adam case, I agree these kind of problems should be solved in the fused adam op level, not the general dispatch level. But for other cases (not sure what other cases might be), tensor compute_mesh must be as exactly same as dtensor's compute mesh. More concretely for "step": {step_1, step_2, ..., step_16}, weights: {weights_1, weights_2, ..., weights_16}, after conversion, step_1's compute_mesh should be as same as weights_1's compute_mesh, and step_2's compute_mesh should be as same as weights_2's compute_mesh, .... Will be happy to work on a correct fix. Let's chat more on slack. |
6dee482 to
f34a894
Compare
|
proposed another solution in #157682 |
…iple meshes (#157682) We are seeing more and more use cases where parameters in a model (under the same optimizer group) are put on different meshes. E.g. - when FSDP and TP are both applied, some parameters are sharded only on the FSDP mesh but not TP mesh (see #153268). - in [dp2ep Expert Parallel](pytorch/torchtitan#1324), the routed experts are sharded on the (global FSDP \ EP) mesh for smaller FSDP and on the EP mesh for EP, whereas other params are sharded on the global FSDP mesh for FSDP. This PR is, in some sense, a continuation of #147869 to tackle the problem when fused optimizers are used. In such cases, the [`fused_adam`](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml#L15786) / `fused_adamw` has a scalar tensor arg `state_steps` which gets automatically cast to DTensor on the default [`compute_mesh`](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_dispatch.py#L350) (one of the multiple meshes), even though the it could correspond to different meshes. To avoid hitting the cross-mesh propagation exception in `common_pointwise_strategy` and followup redistribute problems, we manually set the target mesh and placements to be the same as input mesh and placements, so that no redistribute will be triggered. This also helps bypass the situation where [`generate_redistribute_costs`](https://github.com/pytorch/pytorch/pull/157682/files#diff-eea32a36dd2d4e58307bc5229402e48048b2ecaef64a7c085495fba1ee10ac89R597) returns infinite cost due to cross mesh redistribute. Moreover, this PR has minimal scope (restricted to the `fused_ops`) and doesn't need to modify other files such as `_sharding_prop.py`. Pull Request resolved: #157682 Approved by: https://github.com/wanchaol
|
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as |
Issues:
in torch/distributed/tensor/_dispatch.py dispatch function, Dtensor and Tensor in some scenarios have different compute mesh, which end up the crash as the following:
[rank8]: ValueError: Could not run pointwise computation across different mesh: Found DeviceMesh('cuda', [0, 8, 16, 24, 32, 40], mesh_dim_names=('dp_shard',)) and DeviceMesh('cuda', [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15], [16, 17, 18, 19, 20, 21, 22, 23], [24, 25, 26, 27, 28, 29, 30, 31], [32, 33, 34, 35, 36, 37, 38, 39], [40, 41, 42, 43, 44, 45, 46, 47]], mesh_dim_names=('dp_shard', 'tp'))!More detail version on the root cause and the issue:
when using Adam optimizer fused_adam version, in _dispatch.py _dispatch function()
args: {weights, grads, exp_avgs, exp_avgs_sls, max_exp_avg_sqs, step}, total 6 groups.
If we finetune Llama 4 Scout with only one transformer block, there are 16 params within weights and their compute_mesh (Data parallel = 4, and tensor parallel = 2, 1 node with 8 GPU)
inside _dispatch() function, unwrap_to_op_info() function tries to assign the proper compute_mesh to tensor (step) from dtensor (weights). However, 16 tensor in step group ends up having same compute_mesh as DeviceMesh('cuda', [0, 2, 4, 6], mesh_dim_names=('dp_shard',)), from very first dtensor. This is root cause.
In the new version, we resolve this issue by
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim