[DTensor] Error on illegal view op during sharding prop #149764
wconstab wants to merge 12 commits into gh/wconstab/399/base from
Conversation
Following the definition of legal sharding in the previous PR, this enforces that view ops raise a clear error message if an invalid uneven sharding is encountered during view op sharding propagation. The mechanism is to compute the current local shapes based on the existing input sharding, and the new local shapes based on the proposed output sharding and new global shape. The new global shape is pre-existing, but the computation of the new local output shape is new and requires additional infrastructure: normally we'd just run the aten operator to get the new shape, but in this case we can't even do that, since we'd risk performing an impossible view (e.g. a tensor of shape (512,) reshaped locally to (384,) in the particular example being addressed here). Fixes #143372 [ghstack-poisoned]
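To make the failure mode concrete, here is a minimal sketch (plain Python, no DTensor dependency; the helper name is hypothetical) of how chunk-style `Shard(0)` placement assigns local sizes, and why an uneven global size means a single locally-computed view shape cannot be valid on every rank:

```python
def shard_local_sizes(global_size: int, num_shards: int) -> list[int]:
    """Hypothetical helper: local sizes for a dim sharded torch.chunk-style.

    Each rank takes a chunk of ceil(global_size / num_shards) elements until
    the tensor is exhausted; trailing ranks may get smaller (or empty) shards.
    """
    chunk = -(-global_size // num_shards)  # ceil division
    sizes = []
    remaining = global_size
    for _ in range(num_shards):
        take = min(chunk, max(remaining, 0))
        sizes.append(take)
        remaining -= take
    return sizes

# Even sharding: every rank holds the same element count, so the same
# local view shape works everywhere.
print(shard_local_sizes(1024, 2))  # [512, 512]

# Uneven sharding: ranks hold different element counts, so no single
# local view shape can be correct on all ranks at once.
print(shard_local_sizes(1000, 3))  # [334, 334, 332]
```

Under this model, a view rule that computes one local output shape and applies it on every rank is exactly the kind of operation that must error out when the input dim is unevenly sharded.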
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149764
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit a050eae with merge base 56e67ba. This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: 26c48eb Pull Request resolved: #149764
new_local_shape, _ = compute_local_shape_and_global_offset(
    new_global_shape, mesh, output_placements
)
if new_local_shape != expected_input_local_shape:
@XilunWu and I were discussing: this constraint is not accurate, it's too picky. I think it's valid to change shape, as long as nelem is the same. Any other opinions on what the fully correct thing to assert here should be?
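The relaxed constraint being discussed here (same local element count rather than identical local shape) could be sketched roughly like this — a hypothetical helper for illustration, not the PR's actual code:

```python
from math import prod

def local_view_is_consistent(cur_local_shape, new_local_shape) -> bool:
    """Hypothetical check: a local view is plausible as long as it preserves
    the number of local elements, even if the shape itself changes."""
    return prod(cur_local_shape) == prod(new_local_shape)

# A shape change with equal nelem passes under the relaxed rule...
assert local_view_is_consistent((512,), (256, 2))
# ...but the failing case from the issue (512 elements viewed as 384) does not.
assert not local_view_is_consistent((512,), (384,))
```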
wanchaol left a comment
Hmmm, I don't quite understand this PR. Why do we need those global/local shape calculations/checks in order to error out on illegal views (and only error out in certain cases)?
IMO the changes to error out should be very simple:
- add the `strict_view` option, as this PR did.
- inside `reshape_strategy`, just check if `strict_view` and `input_src_spec != input_target_spec`, then error out.

We should not rely on any shape checks to determine whether we want to error out; it should simply be that view ops should not do any redistribute.
In the concrete example I'm trying to fix, `input_src_spec == input_target_spec == (Shard(0))`. Both specs are exactly the same; the problem is the local tensor shapes before the view. I could not think of a way to raise the error inside the view sharding prop rule unless I compute the concrete local shapes. Do you have a suggestion of another way to do this?
Oh I see. I think for this case we could check in the Flatten operation around here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_ops/_view_ops.py#L515. Basically, if we want to make sure the flattened dimensions are "shardable", the first dim size of the flattened group needs to be evenly divisible by the mesh size. If not, we should error out with an illegal-view message. The Split operation has some similar checks.
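The divisibility condition described above could be sketched as follows — a simplification assuming chunk-style even splitting, where `group_shape` stands in for the (hypothetical) shape information attached to the Flatten op:

```python
def flatten_group_shardable(group_shape, mesh_size) -> bool:
    """A flatten of dims with sizes `group_shape`, sharded on the first dim,
    stays evenly sharded only if that first dim divides by the mesh size."""
    return group_shape[0] % mesh_size == 0

assert flatten_group_shardable([6, 4], 3)       # 6 rows over 3 ranks: even
assert not flatten_group_shardable([5, 4], 3)   # 5 rows over 3 ranks: uneven
```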
This makes sense. But how can I compute the size of the flattened dimension to make sure it is divisible by the mesh size? IIUC, the reason this can be computed in the 'split' case is because when you're doing a split you have to pass the explicit new split sizes as part of the command, but for flatten you can use … What about modifying this line so it passes more information into the new 'Flatten', such as a 'to_shape' field? (which is computed by infer_shape above)
I think the DimSpec transformation is already able to get rid of …
Thinking about this more, I see a few places I could do this error checking, but they have tradeoffs.
For (2) and (3), I think the only missing piece is adding some shape information to the Flatten object. IIUC that's what you are proposing here.
I'll update the PR to try this; lmk if you have a preference between 1, 2, 3. Thanks for the suggestions!
Edit: It's easy to add the 'group shape' to the Flatten. However, I'm not sure it's good to raise the error locally when looking at the 'flatten' op. Would it be possible to have a complex view that first creates an 'incompatible' flatten, but then splits, so that the final result is 'compatible'? If so, then I think the safest place to error is at (2), based on the original/final shape or original/final nelem.
IMO the right way to do this is to error when seeing a local Flatten DimSpec (i.e. add a group_shape to Flatten). The whole DimSpec logic is there to figure out viable transformation steps for view-like operations; this means the dims need to actually go through those transformations conceptually in order to become the "output shape". So for the case of a complex view that has an "incompatible" flatten in the middle, if that flatten needs to error, then the whole transformation would not be compatible and should error too. On the other side, handling it directly with the old concrete/global shapes might not cover some complex view transformations. So I think handling Flatten inside the DimSpec logic is better.
@wanchaol, after thinking more about this, I think that using 'shape' (or group_shape) might not be the right approach; or at least, I am not sure what criteria to apply to 'shape' to determine if flattening is illegal. Initially, I planned to assert that the new local shape matches the existing local shard shape, but I now believe this is not correct. Example: I could flatten and split some sharded dims, which is a violation, but then I could create a local shape that looks the same as the original shape and not catch this error by shape-checking. Instead, I think the best approach is to propagate another piece of information about input dim sharding. The rules would be recursive. Valid to flatten if … Invalid if … And the way to implement these rules could be recursive propagation on the different helper classes. Does this idea make sense? Might be easier to discuss over chat or VC if not.
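The sharding-based rule proposed in this comment — legality determined by which input dims of the flatten group are sharded, rather than by comparing shapes — might look roughly like this (a hypothetical predicate matching the rules the PR description later settles on, not the PR's actual code):

```python
def flatten_is_legal(dims_sharded: list[bool]) -> bool:
    """dims_sharded[i] says whether input dim i of the flatten group is sharded.

    Legal iff nothing is sharded (all replicated), or only the left-most dim
    is sharded; any sharding on an inner dim would require redistribution.
    """
    if not any(dims_sharded):
        return True  # fully replicated: always legal
    return dims_sharded[0] and not any(dims_sharded[1:])

assert flatten_is_legal([False, False])      # fully replicated
assert flatten_is_legal([True, False])       # only left-most dim sharded
assert not flatten_is_legal([False, True])   # inner dim sharded: illegal
assert not flatten_is_legal([True, True])    # multiple dims sharded: illegal
```

Note this deliberately ignores shapes entirely, which sidesteps the false-negative problem described above where flatten-then-split reproduces the original local shape.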
Adds explicit error checking during sharding propagation for view ops rather than relying on runtime errors during local op execution.
Before: an error is thrown by the aten.view op called by DTensor dispatch, because the local shard size is incompatible with the (incorrectly calculated) args to the view op. `RuntimeError: shape '[384]' is invalid for input of size 512`
After: we raise more specific errors for incompatible view operations during sharding propagation, before getting to runtime dispatch. `RuntimeError: Attempted to flatten an unevenly sharded dimension, which would require resharding the input. Please explicitly redistribute the tensor instead.`
Fixes #143372 [ghstack-poisoned]
Adds explicit error checking during sharding propagation for view ops rather than relying on runtime errors during local op execution. The motivation is twofold: 1. to provide clearer errors in currently failing cases; 2. to clarify that _because_ view ops are expected to be lightweight, it would be bad to implicitly perform the communication needed to support these failing cases, and thus the right thing is to error clearly.
Before: an error is thrown by the aten.view op called by DTensor dispatch, because the local shard size is incompatible with the (incorrectly calculated) args to the view op. `RuntimeError: shape '[384]' is invalid for input of size 512`
After: we raise more specific errors for incompatible view operations during sharding propagation, before getting to runtime dispatch. `RuntimeError: Attempted to flatten an unevenly sharded dimension, which would require resharding the input. Please explicitly redistribute the tensor instead.`
Change Summary:
* add a 'strict_view' kwarg to the helper methods that implement view/reshape op shard prop rules, so it can be decided op-by-op whether to raise these new errors
* enable the errors just for the 'view' op in this PR
* add two specific checks/errors that can occur during view ops
Details about the new errors added for 'strict_view' ops: (1) a Flatten is only allowed if only the left-most input dim is sharded, or all input dims are replicated; (2) when Flatten is applied to a left-most input dim that is unevenly sharded, the Flatten is also illegal, since it would likely change the unevenness and require resharding.
Follow ups:
* we should decide what other view-like ops' behavior should be (this PR only addresses the 'view' op)
Fixes #143372 cc H-Huang awgu wanchaol fegin fduwjj wz337 d4l3k [ghstack-poisoned]
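Putting the two strict_view checks from the change summary together, a minimal sketch (hypothetical function with messages paraphrased from the PR; the real implementation lives in the Flatten/DimSpec handling of `_view_ops.py`):

```python
def check_strict_flatten(group_shape, dims_sharded, mesh_size):
    """Raise if flattening the group would require an implicit redistribute.

    Check 1: only the left-most dim of the flatten group may be sharded.
    Check 2: if it is sharded, it must be evenly sharded over the mesh.
    """
    if any(dims_sharded[1:]):
        raise RuntimeError(
            "Attempted to flatten multiple dimensions with an inner "
            "dimension sharded; please redistribute the tensor first."
        )
    if dims_sharded[0] and group_shape[0] % mesh_size != 0:
        raise RuntimeError(
            "Attempted to flatten an unevenly sharded dimension, which "
            "would require resharding the input. Please explicitly "
            "redistribute the tensor instead."
        )

# Legal: left-most dim of size 8 sharded evenly over a mesh of 2.
check_strict_flatten([8, 4], [True, False], 2)
```

Both checks fire during sharding propagation, i.e. before any local aten op runs, which is what turns the opaque `shape '[384]' is invalid` runtime error into an actionable one.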
ghstack-source-id: a190bd9 Pull Request resolved: #149764
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command
@pytorchbot merge -f |
❌ 🤖 pytorchbot command failed: Try …
@pytorchbot merge -f "infra issue? regular merge timed out" |
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
For view/reshape ops, validate evenly or unevenly sharded DTensors and throw a more specific error before getting to runtime dispatch, as implemented in pytorch#149764
…ble (#161950) This PR is a follow-up to #149764. That PR only forbids illegal views due to `Flatten`; this PR also forbids illegal views caused by `Split`. It also updates the error message to be less about internal implementation details, which users may find confusing. Pull Request resolved: #161950 Approved by: https://github.com/ezyang
…Tensor)"
nn.Linear(DTensor) got decomposed into a view on DTensor, with error:
`[rank1]: RuntimeError: ('Attempted to flatten multiple dimensions, with dimension 1 being sharded. ', 'It cannot be performed without redistribution, which is disallowed by the current operator.')`
still learning from a few PRs
* #149764
* #161950
* #161161
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci
[ghstack-poisoned]
Stack from ghstack (oldest at bottom):
Adds explicit error checking during sharding propagation for view ops
rather than relying on runtime errors during local op execution.
Before:
An error is thrown by the aten.view op called by DTensor dispatch, because the local shard size is incompatible with the (incorrectly calculated) args to the view op.
`RuntimeError: shape '[384]' is invalid for input of size 512`
After:
We raise more specific errors for cases of incompatible view operations during sharding propagation, before getting to runtime dispatch.
`RuntimeError: Attempted to flatten an unevenly sharded dimension, which would require resharding the input. Please explicitly redistribute the tensor instead.`
Change Summary:
* add 'strict_view' kwarg to the helper methods that implement view/reshape op shard prop rules, so it can be decided op-by-op whether to raise these new errors
* enabled errors just for the 'view' op in this PR
* added two specific checks/errors that can occur during view ops
Details:
* … sharded, since that would likely change the size/content of the local_tensor and require redistribute
* … dim is a Shard() placement, because it would cause contiguity errors without redistribution
Notes:
* … decompose to an illegal view that only works by performing a redistribution: cartesian_prod, flatten, ravel, reshape, reshape_as, view, view_as, take_along_dim, kron
Follow Ups:
* … redistribution (ban them all, and document this)
Fixes #143372
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @d4l3k