
[DTensor] Error on illegal view op during sharding prop#149764

Closed
wconstab wants to merge 12 commits into gh/wconstab/399/base from gh/wconstab/399/head

Conversation

@wconstab
Contributor

@wconstab wconstab commented Mar 21, 2025

Stack from ghstack (oldest at bottom):

Adds explicit error checking during sharding propagation for view ops
rather than relying on runtime errors during local op execution.

Before:
An error is thrown by aten.view op called by DTensor dispatch, because
the local shard size is incompatible with the (incorrectly calculated)
args to the view op.

`RuntimeError: shape '[384]' is invalid for input of size 512`

After:
We raise more specific errors for cases of incompatible view operations
during sharding propagation, before getting to runtime dispatch.

`RuntimeError: Attempted to flatten an unevenly sharded dimension, which would require resharding the input. Please explicitly redistribute the tensor instead.`

Change Summary:

* add a 'strict_view' kwarg to the helper methods that implement view/reshape op sharding-prop rules, so whether to raise these new errors can be decided op by op
* enable the errors just for the 'view' op in this PR
* add two specific checks/errors that can occur during view ops

Details:

  • View ops are never allowed to flatten a dimension that is unevenly
    sharded, since that would likely change the size/content of the
    local_tensor and require a redistribute
  • View ops are also never allowed to flatten two dims if the rightmost
    dim has a Shard() placement, because it would cause contiguity errors
    without redistribution
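The second rule can be illustrated with a small pure-Python sketch (a hypothetical 2x4 tensor sharded on dim 1 over 2 ranks; not DTensor code): concatenating locally flattened shards does not reproduce the row-major order of the global flatten, so the view cannot be performed without redistribution.

```python
# Hypothetical layout: global 2x4 tensor with a Shard(1) placement on a 2-rank mesh.
global_t = [[0, 1, 2, 3],
            [4, 5, 6, 7]]
rank0 = [[0, 1], [4, 5]]   # columns 0-1
rank1 = [[2, 3], [6, 7]]   # columns 2-3

# Flattening the global tensor yields row-major order.
global_flat = [x for row in global_t for x in row]
# Flattening each local shard and concatenating yields a different order.
local_concat = [x for shard in (rank0, rank1) for row in shard for x in row]

print(global_flat)   # [0, 1, 2, 3, 4, 5, 6, 7]
print(local_concat)  # [0, 1, 4, 5, 2, 3, 6, 7]
```

Since the two orders differ, a purely local flatten would silently produce wrong data, which is why the strict check errors instead.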

Notes:

  • Disables support for several ops in test_dtensor_ops.py test, which
    decompose to an illegal view that only works by performing a
    redistribution: cartesian_prod, flatten, ravel, reshape, reshape_as, view, view_as, take_along_dim, kron

Follow Ups:

  • triage other view-like ops (besides aten::view) for using strict_view
  • look for other gaps where view-like ops could still perform
    redistribution (ban them all, and document this)

Fixes #143372

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @d4l3k

Following the definition of legal sharding in the previous PR, this
enforces that view ops raise a clear error message if an invalid
uneven sharding is encountered during view op sharding propagation.

The mechanism is to compute the current local shapes based on the
existing input sharding, and the new local shapes based on the proposed
output sharding and new global shape. The new global shape is
pre-existing, but the computation of the new local output shape is new and
requires additional infrastructure: normally we'd just run the
aten operator to get the new shape, but here we can't even do
that, since we'd risk performing an impossible view (e.g. a tensor of shape
(512,) reshaped locally to (384,) in the particular example being
addressed here).
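As a rough, pure-Python illustration of this local-shape comparison (the `shard_sizes` helper is hypothetical, mimicking torch.chunk-style ceil-division chunking; it is not the PR's actual code):

```python
import math

def shard_sizes(dim_size, num_ranks):
    # Mimic Shard(dim) chunking, which follows torch.chunk: ceil-division
    # chunks, so trailing ranks may receive smaller (or empty) shards.
    chunk = math.ceil(dim_size / num_ranks)
    return [max(0, min(chunk, dim_size - r * chunk)) for r in range(num_ranks)]

# Input: global (6, 256) with Shard(0) on a 4-rank mesh.
rows = shard_sizes(6, 4)                 # [2, 2, 2, 0]
actual_numels = [r * 256 for r in rows]  # [512, 512, 512, 0]

# Proposed output: global (1536,) with Shard(0) -> every rank expects 384.
expected = shard_sizes(6 * 256, 4)       # [384, 384, 384, 384]

# The mismatch between actual local data and the expected local shapes is
# exactly the illegal view this PR detects eagerly.
assert actual_numels != expected
```

The same comparison done with real local tensors would require dispatching the op; doing it arithmetically during sharding propagation is what lets the error fire before dispatch.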

Fixes #143372

[ghstack-poisoned]
@pytorch-bot pytorch-bot bot added ciflow/inductor oncall: distributed Add this issue/PR to distributed oncall triage queue labels Mar 21, 2025
@pytorch-bot

pytorch-bot bot commented Mar 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149764

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a050eae with merge base 56e67ba:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

wconstab added a commit that referenced this pull request Mar 21, 2025
ghstack-source-id: 26c48eb
Pull Request resolved: #149764
new_local_shape, _ = compute_local_shape_and_global_offset(
    new_global_shape, mesh, output_placements
)
if new_local_shape != expected_input_local_shape:
@wconstab
Contributor Author

@XilunWu and I were discussing: this constraint is not accurate; it's too picky. I think it's valid to change the shape, as long as nelem is the same. Any other opinions on what the fully correct thing to assert here should be?

@wconstab wconstab changed the title [DTensor] Error on illegal view op sharding prop [WIP][DTensor] Error on illegal view op sharding prop Mar 21, 2025
@wconstab wconstab requested review from XilunWu, tianyu-l and wanchaol and removed request for wanchaol March 21, 2025 21:01
Collaborator

@wanchaol wanchaol left a comment


Hmmm, I don't quite understand this PR: why do we need those global/local shape calculations/checks in order to error out on illegal views (and only error out in certain cases)?

IMO the changes to error out should be very simple:

  1. add the strict_view option as this PR did.
  2. inside reshape_strategy, just check if strict_view and input_src_spec != input_target_spec, then error out.

We should not rely on any shape checks to determine whether we want to error out; it should simply be that view ops must not perform any redistribute.

@wconstab
Contributor Author

inside reshape_strategy, just check if strict_view and input_src_spec != input_target_spec, then error out.

In the concrete example I'm trying to fix, the input_src_spec == input_target_spec == (Shard(0)).

Both specs are exactly the same:

DTensorSpec(mesh=DeviceMesh('cuda', [0, 1, 2, 3], mesh_dim_names=('tp',)), placements=(Shard(dim=0),), tensor_meta=TensorMeta(shape=torch.Size([6, 256]), stride=(256, 1), dtype=torch.float32))

The problem is: before the view, the local tensors would be
Rank0-2: (2, 256)
Rank3: (0, 256)
After the view, the local tensors 'should' be
Rank0-3: (384,)
But there is nothing keeping track of this inside the spec, and in reality the local data is more like
Rank0-2: (512,)
Rank3: (0,)

I could not think of a way to raise the error inside the view sharding prop rule, unless I compute the concrete local shapes. Do you have a suggestion of another way to do this?

@wanchaol
Collaborator

wanchaol commented Mar 24, 2025

The problem is: before the view, the local tensors would be Rank0-2: (2, 256), Rank3: (0, 256). After the view, the local tensors 'should' be Rank0-3: (384,). But there is nothing keeping track of this inside the spec, and in reality the local data is more like Rank0-2: (512,), Rank3: (0,).

I could not think of a way to raise the error inside the view sharding prop rule, unless I compute the concrete local shapes. Do you have a suggestion of another way to do this?

Oh I see. I think for this case we could check in the Flatten operation, around here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_ops/_view_ops.py#L515. Basically, if we make sure the size of the first dim of the flattened dimensions is evenly divisible by the mesh size, then it's "shardable". If not, we should error out with an illegal-view error. The Split operation has some similar checks.

@wconstab
Contributor Author

wconstab commented Mar 25, 2025

Oh I see. I think for this case we could check in the Flatten operation, around here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_ops/_view_ops.py#L515. Basically, if we make sure the size of the first dim of the flattened dimensions is evenly divisible by the mesh size, then it's "shardable". If not, we should error out with an illegal-view error. The Split operation has some similar checks.

This makes sense. But how can I compute the size of the flattened dimension to make sure it is divisible by the mesh size? IIUC the reason this can be computed in the 'split' case is because when you're doing a split you have to pass the explicit new split sizes as part of the command, but for flatten you can use -1. I did put logic into DimSpec for computing a 'concrete_shape' to solve this problem, but it sounds like that's not the way you had in mind?

What about modifying this line so it passes more information into the new 'Flatten', such as the 'to_shape' field (which is computed by infer_shape above)?
https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_ops/_view_ops.py#L367

@wanchaol
Collaborator

This makes sense. But how can I compute the size of the flattened dimension to make sure it is divisible by the mesh size? IIUC the reason this can be computed in the 'split' case is because when you're doing a split you have to pass the explicit new split sizes as part of the command, but for flatten you can use -1. I did put logic into DimSpec for computing 'concrete_shape' to solve this problem, but it sounds like that's not the way you had in mind?

I think the DimSpec transformation is already able to get rid of -1 and give concrete dims; i.e. whenever you are in the Flatten cmd, cmd.input_dims should all be InputDims that have concrete input dimensions already. As for the shape/size of the flattened dimension: because view is a complicated transformation, computing the full shape would not always be correct. I think what we should do is add the group_shape to Flatten, similar to Split here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_ops/_view_ops.py#L368. Then, when dealing with Flatten, we know: 1. the from_shape generated in the process of decomposing views, and 2. the first flattened dim from the Flatten DimSpec. Then we can throw the error accordingly.

@wconstab
Contributor Author

wconstab commented Mar 25, 2025

Thinking about this more, I see a few places I could do this error checking, but they have tradeoffs.

  1. view_groups function. This is where Flatten is constructed. In this function, I already know the old global shape (6,256) and the new global shape (starts as -1, gets inferred to 6*256), but I do not know the mesh dim shape or the sharding spec. For example, I do not know that the tensor spec says 'shard(0)' or that the size of the first mesh dim is '4', which is critical for determining whether the sharding will be uneven.

  2. register_op_strategy_map:reshape_strategy: This fn knows the op type (useful for asserting on specific op types, like the 'strict_view' field). It also knows the mesh shape, so it solves the problem in (1). But it does not know the concrete shapes at all. It has access to the 'rules', which are DimSpecs.

  3. propagate_shape_and_sharding: This fn knows the mesh_shape, and also has access to the rules just like (2). It also has logic specific to each rule type (e.g. Flatten), so there is a logical place to handle the case of flattening. A flag like 'strict_view' could also be passed into this function if needed.

For (2) and (3), I think the only missing piece is adding some shape information to the Flatten object. IIUC that's what you are proposing here:

I think what we should do is to add the group_shape to Flatten similar to Split

I'll update the PR to try this; lmk if you have a preference between 1, 2, 3. Thanks for the suggestions!

Edit: It's easy to add the 'group shape' to the Flatten. However, I'm not sure it's good to raise the error locally when looking at the 'flatten' op. Would it be possible to have a complex view that first creates an 'incompatible' flatten, but then splits so that the final result is 'compatible'? If so, then I think the safest place to error is at (2), based on the original/final shape or original/final nelem.

@wanchaol
Collaborator

Edit: It's easy to add the 'group shape' to the Flatten. However, I'm not sure it's good to raise the error locally when looking at the 'flatten' op. Would it be possible to have a complex view that first creates an 'incompatible' flatten, but then splits so that the final result is 'compatible'? If so, then I think the safest place to error is at (2), based on the original/final shape or original/final nelem.

IMO the right way to do this is to error when seeing a local Flatten DimSpec (i.e. add a group_shape to Flatten). The whole DimSpec logic is there to figure out viable transformation steps for view-like operations; this means the dims need to actually go through those transformations, conceptually, in order to become the "output shape". So, for the case of a complex view that has an "incompatible" flatten in the middle, if that flatten needs to error, then the whole transformation would not be compatible and should error too. On the other hand, handling it directly with the old concrete/global shape might not cover some complex view transformations.

So I think handling Flatten inside 3. propagate_shape_and_sharding should give the correct behaviors.

@wconstab
Contributor Author

wconstab commented Apr 3, 2025

IMO the right way to do this is to error when seeing a local Flatten DimSpec (i.e. add a group_shape to Flatten).

@wanchaol, after thinking more about this, I think that using 'shape' (or group_shape) might not be the right approach. Or at least, I am not sure what criteria to apply to 'shape' to determine if flattening is illegal. Initially, I planned to assert that the new local shape matches the existing local shard shape, but I now believe this is not correct. Example: I could flatten and split some sharded dims, which is a violation, but then I could create a local shape that looks the same as the original shape and not catch this error by shape-checking.

Instead, I think the best approach is to propagate another piece of information about input dim sharding. The rules would be recursive:

Valid to flatten if:
- all dims are replicated
- a shard on the leftmost dim, with replicated dims to the right

Invalid if:
- two shard dims (except if they were created by a split on a shard dim); I showed a couple of examples where this is definitely invalid, but I'm not sure if it's guaranteed to be invalid
- one or more replicated dims on the left followed by a shard on the right

And the way to implement these rules could be recursive propagation on the different helper classes:
- Split: Shard -> (Shard, Shard), Replicate -> (Replicate, Replicate)
- Flatten: (Shard, Replicate, ...) -> Shard, (Replicate, ...) -> Replicate

Does this idea make sense? Might be easier to discuss over chat or VC if not.
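For concreteness, the proposed rules above can be sketched in a few lines of pure Python ('S'/'R' tags stand in for Shard/Replicate placements; these helpers are illustrative, not DTensor code):

```python
def split_rule(tag):
    # Proposed rule: Split maps Shard -> (Shard, Shard) and
    # Replicate -> (Replicate, Replicate).
    return (tag, tag)

def flatten_rule(tags):
    # Flatten is valid only if all dims are replicated, or the leftmost
    # dim is sharded and everything to its right is replicated.
    if all(t == "R" for t in tags):
        return "R"
    if tags[0] == "S" and all(t == "R" for t in tags[1:]):
        return "S"
    raise RuntimeError("illegal flatten: would require redistribution")

print(flatten_rule(["S", "R", "R"]))  # S
print(split_rule("S"))                # ('S', 'S')
```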

@wconstab wconstab changed the title [WIP][DTensor] Error on illegal view op sharding prop [DTensor] Error on illegal view op during sharding prop Apr 11, 2025
@wconstab wconstab added the release notes: distributed (dtensor) release notes category label Apr 11, 2025
Adds explicit error checking during sharding propagation for view ops
rather than relying on runtime errors during local op execution.  

The motivation is twofold:
1. to provide clearer errors in currently failing cases
2. to clarify that, _because_ view ops are expected to be lightweight, it would be bad to implicitly perform the communication needed to support these failing cases, and thus the right thing is to error clearly.

Before:
An error is thrown by aten.view op called by DTensor dispatch, because
the local shard size is incompatible with the (incorrectly calculated)
args to the view op.

`RuntimeError: shape '[384]' is invalid for input of size 512`

After:
We raise more specific errors for cases of incompatible view operations
during sharding propagation, before getting to runtime dispatch.

`RuntimeError: Attempted to flatten an unevenly sharded dimension, which would require resharding the input. Please explicitly redistribute the tensor instead.`

Change Summary:

* add a 'strict_view' kwarg to the helper methods that implement view/reshape op sharding-prop rules, so whether to raise these new errors can be decided op by op
* enable the errors just for the 'view' op in this PR
* add two specific checks/errors that can occur during view ops

Details about the new errors added for 'strict_view' ops:
(1) a Flatten is only allowed if either only the left-most input dim is sharded or all input dims are replicated
(2) a Flatten applied to a sharded left-most input dim that is unevenly sharded is also illegal, since it would likely change the unevenness and require resharding.
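Taken together, the two checks could be sketched as follows (a hypothetical helper; the names and messages are illustrative, not the exact implementation):

```python
def check_flatten(dim_sizes, dim_sharded, mesh_dim_size):
    """Validate flattening a group of input dims under a strict_view policy.

    dim_sizes:   global size of each input dim being flattened
    dim_sharded: True for each input dim that carries a Shard placement
    """
    # (1) only the left-most flattened dim may be sharded
    if any(dim_sharded[1:]):
        raise RuntimeError(
            "Attempted to flatten multiple dimensions with a non-leftmost "
            "dimension sharded; this would require redistribution."
        )
    # (2) a sharded left-most dim must divide evenly across the mesh dim
    if dim_sharded[0] and dim_sizes[0] % mesh_dim_size != 0:
        raise RuntimeError(
            "Attempted to flatten an unevenly sharded dimension, which "
            "would require resharding the input."
        )

check_flatten([8, 256], [True, False], 4)    # ok: 8 divides evenly by 4
# check_flatten([6, 256], [True, False], 4)  # raises: 6 is not divisible by 4
```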


Follow ups
* we should decide what the behavior of other view-like ops should be (this PR only addresses the 'view' op)

Fixes #143372

cc H-Huang awgu wanchaol fegin fduwjj wz337 d4l3k

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Apr 11, 2025

ghstack-source-id: a190bd9
Pull Request resolved: #149764
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@wconstab
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@wconstab
Contributor Author

@pytorchbot merge -f

@pytorch-bot

pytorch-bot bot commented Apr 28, 2025

❌ 🤖 pytorchbot command failed:

@pytorchbot merge: error: argument -f/--force: expected one argument

usage: @pytorchbot merge [-f MESSAGE | -i] [-ic] [-r [{viable/strict,main}]]

Try @pytorchbot --help for more info.

@wconstab
Contributor Author

@pytorchbot merge -f "infra issue? regular merge timed out"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions github-actions bot deleted the gh/wconstab/399/head branch June 12, 2025 02:23
dayanandav added a commit to dayanandav/pytorch that referenced this pull request Aug 21, 2025
For view/reshape ops, validate evenly or unevenly sharded
DTensors during sharding propagation, throwing a more
specific error before getting to runtime dispatch, as
implemented in pytorch#149764
dayanandav added a commit to dayanandav/pytorch that referenced this pull request Aug 22, 2025
pytorchmergebot pushed a commit that referenced this pull request Sep 3, 2025
…ble (#161950)

This PR is a followup to #149764.

That PR only forbids illegal views due to `Flatten`; this PR also forbids illegal views caused by `Split`.

This PR also updates the error message to be less about internal implementation details, which users may find confusing.

Pull Request resolved: #161950
Approved by: https://github.com/ezyang
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
weifengpy added a commit that referenced this pull request Oct 3, 2025
…Tensor)"

nn.Linear(DTensor) got decomposed into a view on DTensor, with error:
`RuntimeError: ('Attempted to flatten multiple dimensions, with dimension 1 being sharded. ', 'It cannot be performed without redistribution, which is disallowed by the current operator.')`

still learning from a few PRs
* #149764 
* #161950 
* #161161 



cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
weifengpy added a commit that referenced this pull request Oct 3, 2025
weifengpy added a commit that referenced this pull request Oct 3, 2025

Labels

ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (dtensor) release notes category
