[DTensor] single dim fix inplace op expansion #172477
wconstab wants to merge 5 commits into gh/wconstab/499/base
Conversation
This enables the inplace filtering logic that skips strategies with incompatible input placements. Previously, inplace ops were able to generate expanded strategies incompatible with the current input placements, which aren't allowed to be modified by redistribution. [ghstack-poisoned]
> **Left-to-Right Evaluation Order**
>
> DTensor evaluates Partial placements in **left-to-right order** (i.e., mesh dimension 0 first, …
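For illustration, a minimal sketch of what this left-to-right claim would mean in practice; the mesh shape and placements here are made up for the example, and it assumes 4 ranks:

```python
# Illustrative only: under strict left-to-right evaluation, converting two
# Partials to Replicate would reduce the Partial on mesh dim 0 before the
# one on mesh dim 1.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate

mesh = init_device_mesh("cpu", (2, 2))  # assumes 4 ranks
local = torch.randn(4, 4)
dt = DTensor.from_local(local, mesh, (Partial("sum"), Partial("sum")))
# Left-to-right order would reduce mesh dim 0 first, then mesh dim 1:
out = dt.redistribute(mesh, (Replicate(), Replicate()))
```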
I checked the redistribution code; left-to-right order is not completely correct. The Partial -> Replicate conversion can run in right-to-left order (pytorch/torch/distributed/tensor/_redistribute.py, lines 749 to 757 at ee562d9) or in left-to-right order (pytorch/torch/distributed/tensor/_redistribute.py, lines 765 to 773 at ee562d9). In general, it can be either one.
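A hedged sketch of the two traversal orders being pointed at here; this is illustrative pseudologic, not the actual `_redistribute.py` code:

```python
# Illustrative only: the planning loop can walk mesh dims in either
# direction, which fixes the order in which Partial placements are
# reduced to Replicate.
from torch.distributed.tensor import Placement


def plan_partial_to_replicate(placements: tuple[Placement, ...], reverse: bool = False):
    dims = list(range(len(placements)))
    if reverse:
        dims.reverse()  # right-to-left: highest mesh dim first
    return [("reduce", d) for d in dims if placements[d].is_partial()]
```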
As long as we have a Shard in the src placement, we can trigger the right-to-left order:
```python
import re

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate, Shard
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.utils._debug_mode import DebugMode


class DistributeWithPartialTest(DTensorTestBase):
    @property
    def world_size(self) -> int:
        return 8

    def _extract_redistribute_trace_from_debug_mode(self, s: str) -> str:
        match = re.search(r"trace:\s*(.*)\)", s)
        return match.group(1) if match else ""

    @with_comms
    def test_tmp(self):
        mesh = init_device_mesh(self.device_type, (2, 2, 2))
        input_data = torch.randn((8, 8, 8), device=self.device_type)
        dt = DTensor.from_local(
            input_data, mesh, (Partial("sum"), Shard(0), Partial("max"))
        )
        with DebugMode(record_torchfunction=False) as debug_mode:
            dt2 = dt.redistribute(mesh, (Replicate(), Replicate(), Replicate()))
        trace_str = self._extract_redistribute_trace_from_debug_mode(
            debug_mode.debug_string()
        )
        print(trace_str)
```

The redistribution path is: `P(sum)S(0)P(max) -> P(sum)S(0)R -> P(sum)RR -> RRR`. This made me think that there is a bug in the greedy redistribution algorithm when handling Partial variants: we can't arbitrarily switch between left-to-right and right-to-left order. I think right-to-left should be the correct order to handle Partial -> Replicate.
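To see why the order matters for mixed Partial variants, here is a small worked example with made-up values (not from the PR): sum and max do not commute, so reducing the two mesh dims in different orders yields different results.

```python
# Made-up 2x2 mesh of local shard values; rows are mesh dim 0 with
# Partial("sum"), columns are mesh dim 1 with Partial("max").
vals = [[1, 2],
        [3, 0]]

# Reduce mesh dim 0 (sum) first, then mesh dim 1 (max):
sum_then_max = max(vals[0][j] + vals[1][j] for j in range(2))  # max(4, 2) = 4

# Reduce mesh dim 1 (max) first, then mesh dim 0 (sum):
max_then_sum = sum(max(vals[i][j] for j in range(2)) for i in range(2))  # 2 + 3 = 5

assert sum_then_max != max_then_sum  # 4 != 5: the reduction order changes the answer
```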
That sounds good to me. However, don't move too fast on it; we are discussing just banning mixed partial types in the same placement, and then we don't need to define an order.
Oh, this PR was not even supposed to include the changes to partial ordering; it was a rebasing mistake. I will remove that code. I have another PR up now anyway that partially bans mixed partials. If people are OK with that direction, I'll extend it and ensure we don't allow mixed partials at all parts of the DTensor stack; then this ordering question becomes unnecessary.
Yes, "banning mixed Partial" sounds good to me!
```python
if isinstance(placement, Partial):
    src_partial_ops[mesh_dim] = placement.reduce_op
...
if len(src_partial_ops) > 1:
```
Aha, I see. I left the comment in the README file too early. This PR enforces the left-to-right order when there are multiple Partial variants.
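A hedged sketch of what the quoted hunk implies; the surrounding names (e.g. `src_placements`) are assumptions, not the PR's actual diff:

```python
# Illustrative sketch: collect the reduce op for each Partial mesh dim,
# then, when more than one Partial is present, process them in ascending
# mesh-dim order (left-to-right) so mixed variants reduce deterministically.
src_partial_ops: dict[int, str] = {}
for mesh_dim, placement in enumerate(src_placements):
    if isinstance(placement, Partial):
        src_partial_ops[mesh_dim] = placement.reduce_op

if len(src_partial_ops) > 1:
    ordered_partial_dims = sorted(src_partial_ops)  # mesh dim 0 first
```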
```python
# If some partials are being reduced while others are kept,
# we need to reduce ALL partials first, then re-partition
if dst_partial_dims and dst_partial_dims != src_partial_dims:
```
If the condition `dst_partial_dims and dst_partial_dims != src_partial_dims` doesn't hold, then we fall back to the default greedy code to handle Partial, which can run either left-to-right or in reverse. We also need to enforce the order there.
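For concreteness, a sketch of the control flow this comment describes; the helper names are hypothetical, not the actual diff:

```python
# Illustrative only: the special-cased branch reduces all partials first,
# while the fallback greedy path may visit mesh dims in either direction,
# which is the order that also needs to be pinned down.
if dst_partial_dims and dst_partial_dims != src_partial_dims:
    for mesh_dim in sorted(src_partial_dims):  # reduce ALL partials, left-to-right
        reduce_partial(mesh_dim)               # hypothetical helper
    repartition(dst_partial_dims)              # hypothetical helper
else:
    greedy_redistribute()  # order here is currently unspecified
```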
@pytorchbot merge

Merge started: Your change will be merged once all checks pass (ETA 0-4 hours).

Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
@pytorchbot merge -i

Merge started: Your change will be merged while ignoring the following 8 checks: trunk / linux-jammy-rocm-py3.10 / test (default, 3, 6, linux.rocm.gpu.gfx942.1), trunk / linux-jammy-rocm-py3.10 / test (default, 2, 6, linux.rocm.gpu.gfx942.1), trunk / linux-jammy-rocm-py3.10 / test (default, 5, 6, linux.rocm.gpu.gfx942.1), trunk / linux-jammy-rocm-py3.10 / test (default, 6, 6, linux.rocm.gpu.gfx942.1), trunk / linux-jammy-rocm-py3.10 / test (default, 1, 6, linux.rocm.gpu.gfx942.1), trunk / linux-jammy-rocm-py3.10 / test (distributed, 1, 3, linux.rocm.gpu.gfx942.4), trunk / linux-jammy-rocm-py3.10 / test (distributed, 3, 3, linux.rocm.gpu.gfx942.4), trunk / linux-jammy-rocm-py3.10 / test (distributed, 2, 3, linux.rocm.gpu.gfx942.4).

Merge failed. Reason: 1 job has failed; the first few are: trunk / linux-jammy-rocm-py3.10 / test (default, 4, 6, linux.rocm.gpu.gfx942.1).
@pytorchbot merge -i
Stack from ghstack (oldest at bottom):
This enables the inplace filtering logic that skips strategies with incompatible input placements.

Previously, inplace ops were able to generate expanded strategies incompatible with the current input placements, which aren't allowed to be modified by redistribution.
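As a rough illustration of the behavior described above (a minimal sketch with hypothetical names, not the PR's actual code): for an inplace op, the mutated input cannot be redistributed, so any candidate strategy whose required input placement differs from the tensor's current placement must be skipped rather than satisfied via redistribution.

```python
# Hypothetical sketch, not the PR diff: filter expanded op strategies for an
# inplace op so that only those matching the current (unmodifiable) input
# placements survive.
def filter_inplace_strategies(candidate_strategies, current_input_placements):
    kept = []
    for strategy in candidate_strategies:
        required = strategy.input_specs[0].placements  # placement the strategy needs
        if required == current_input_placements:
            kept.append(strategy)  # compatible: no redistribution of the inplace arg
        # else: skip; redistributing would mutate the inplace input's sharding
    return kept
```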