[DTensor] enable single-dim strategy for addmm and baddbmm #172387
weifengpy wants to merge 12 commits into gh/weifengpy/51/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/172387
Note: Links to docs will display an error until the docs builds have been completed.
❌ 9 New Failures as of commit c1c550d with merge base 011e373.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
from torch.distributed.tensor._ops.utils import infer_broadcast_dims_map
...
mm_strategies = gen_single_dim_einsum_strategies(mm_equation)
self_meta = cast(TensorMeta, args_schema[0])  # bias
Why not do an assert isinstance here? We should use cast sparingly
good catch! I updated the PR to use assert
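For illustration, a minimal sketch of the assert-based pattern the thread settled on. The `bias_meta` wrapper and the `TensorMeta` import path are assumptions for the sake of a self-contained example, not the PR's actual code:

```python
from torch.distributed.tensor._dtensor_spec import TensorMeta  # assumed import path

def bias_meta(args_schema: tuple[object, ...]) -> TensorMeta:
    # cast(TensorMeta, ...) is a no-op at runtime; asserting the type fails
    # loudly on a schema mismatch and also narrows the type for checkers.
    self_meta = args_schema[0]  # bias
    assert isinstance(self_meta, TensorMeta), f"expected TensorMeta, got {type(self_meta)}"
    return self_meta
```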
broadcast_dims_map = infer_broadcast_dims_map(mm_out_shape, self_meta.shape)
...
# Add bias placement to each strategy
addmm_like_strategies: list[list[Placement | _ShardingPlaceholder]] = []
Brainstorming: would it be cleaner to add an option to gen_single_dim_einsum_strategies to insert an extra bias placement, rather than having to iteratively update the einsum strategies in a separate helper? (Maybe not, but wdyt?)
Totally reasonable! The logic is tighter now in gen_single_dim_einsum_strategies. I updated the PR for another review.
Hah, I was going to ask why we modify gen_single_dim_einsum_strategies instead of adding a new strategy that accepts bias on top of gen_single_dim_einsum_strategies. Looks like this solution has been challenged.
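For context, a rough sketch of the bookkeeping under discussion, assuming infer_broadcast_dims_map maps each output dim to the matching bias dim (or -1 where the bias was broadcast); `bias_placement_for` is a hypothetical helper for illustration, not the PR's actual code:

```python
from torch.distributed.tensor.placement_types import Placement, Replicate, Shard

def bias_placement_for(shard_dim: int, broadcast_dims_map: list[int]) -> Placement:
    # broadcast_dims_map[out_dim] holds the bias dim that lines up with that
    # output dim, or -1 when the bias dim was broadcast (missing or size 1).
    bias_dim = broadcast_dims_map[shard_dim]
    return Shard(bias_dim) if bias_dim >= 0 else Replicate()
```

For an addmm output of shape (M, N) with a 1-d bias of shape (N,), the map would be [-1, 0]: sharding the output on dim 0 replicates the bias, while sharding on dim 1 shards the bias on its only dim.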
return output_placement
...
if isinstance(output_placement, Partial):
    return Partial()
This isn't good. We should actually return output_placement so we also inherit its reduce op.
Good catch! I am cloning the placement now and added a test to catch Partial(avg).
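The reduce-op pitfall is easy to reproduce in isolation: Partial() defaults to reduce_op="sum", so rebuilding the placement drops an "avg" reduction. A minimal sketch:

```python
from torch.distributed.tensor.placement_types import Partial

output_placement = Partial("avg")

# Rebuilding the placement silently resets the reduce op to "sum" ...
assert Partial().reduce_op == "sum"
# ... while returning the incoming placement keeps "avg" intact.
assert output_placement.reduce_op == "avg"
```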
    return Partial()
elif isinstance(output_placement, Replicate):
    return Replicate()
elif isinstance(output_placement, _ShardingPlaceholder):
Seems like we only need the _ShardingPlaceholder case as an if, and then we can have an else that covers Replicate/Partial and returns output_placement.
It might be better practice to do something like output_placement.clone() so we aren't sharing references, but if it's a dataclass we can't mutate, then it's OK this way.
Good suggestion! I updated the PR to have a simpler if-else for _ShardingPlaceholder.
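A minimal sketch of that simplified shape; the `_ShardingPlaceholder` stand-in and the `resolve` helper are assumptions about the PR's internals, since the actual types aren't shown in this thread:

```python
from dataclasses import dataclass

from torch.distributed.tensor.placement_types import Placement, Shard

@dataclass(frozen=True)
class _ShardingPlaceholder:  # stand-in for the PR's internal placeholder type
    dim: int

def resolve(output_placement: Placement | _ShardingPlaceholder) -> Placement:
    # Only the placeholder needs rewriting into a concrete placement; the
    # Replicate/Partial branches collapse into the fallthrough, which returns
    # the placement itself so Partial keeps its reduce_op. DTensor placements
    # are frozen dataclasses, so sharing the reference is safe here.
    if isinstance(output_placement, _ShardingPlaceholder):
        return Shard(output_placement.dim)
    return output_placement
```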
@pytorchmergebot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed: trunk / linux-jammy-rocm-py3.10 / test (distributed, 3, 3, linux.rocm.gpu.gfx942.4). Raised by workflow job.
@pytorchmergebot merge -f "unrelated error"
Merge started: your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):