
Fix compatibility with latest PyTorch nightly#399

Merged
fmassa merged 5 commits into main from fmassa/fix_upstream_breakage on Mar 27, 2026

Conversation

Contributor

@fmassa fmassa commented Mar 26, 2026

Three recent upstream PyTorch changes broke autoparallel. This PR addresses all of them:

  1. Removed pointwise_strategy ([dtensor] Remove dead code from _pointwise_ops.py pytorch/pytorch#177208): Our custom native_layer_norm and native_layer_norm_backward rules imported this function. Replaced with a local _pointwise_strategy built on PyTorch's new single-dim strategy primitives (_common_pointwise_single_dim_strategy, _fill_single_dim_strategy_placeholders, expand_to_full_mesh_op_strategy).
  2. Activated view op assertion ([DTensor] Fix double-shard validation in propagate_shape_and_sharding pytorch/pytorch#177973): The upstream fix turned the previously dead double-shard validation in propagate_shape_and_sharding into live code, so it now raises when a sharded dim is not divisible by the mesh size. That is too strict for our use case: we enumerate all strategies with strict_view=False and expect incompatible ones to be skipped, so we now catch the AssertionError and skip the iteration.
  3. Added _maybe_unpad_tensor assertion ([Bugfix] Fix by copying sym shapes as needed pytorch/pytorch#178210): The upstream change added torch._check(orig_size >= logical_dim_size) for symbolic shape binding, which exposed a pre-existing bug in example_dcp.py: serialized sharding specs carried stale tensor_meta (from bs=256 on a (32,8) mesh) that was never updated when reapplying on a (2,2) mesh with bs=16. Fixed by pulling the correct tensor_meta from the new AutoParallel instance's computed strategies.

Review order: propagation_rules.py first (the _pointwise_strategy function, then the view rule change), then example_dcp.py.

Authored with Claude.

PyTorch PR pytorch/pytorch#177973 fixed a latent bug where a comparison shard.dim == in_dim was always False (comparing int to InputDim), making the double-shard submesh_size validation in propagate_shape_and_sharding dead code. The fix activated that validation, which now raises an AssertionError when a sharded dimension is not divisible by the mesh size (e.g. unflatten nheads=48 on mesh dim size 32).

This is overly strict for our use case — we call propagate_shape_and_sharding with strict_view=False to enumerate all possible sharding strategies, and incompatible ones should be silently skipped rather than raising. We now catch the AssertionError and skip the iteration, since the replicated variant is already covered by another strategy.
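
As a minimal sketch of how that skip can look, assuming a simple enumeration loop: the helper name enumerate_view_shardings, the variable names, and the exact argument list of propagate_shape_and_sharding are illustrative, not the code in this PR.

```python
# Import path follows the current private DTensor layout; it may differ across nightlies.
from torch.distributed.tensor._ops._view_ops import propagate_shape_and_sharding


def enumerate_view_shardings(candidate_placements, global_in_shape, view_rule, mesh_shape):
    """Keep only the input shardings that the view op can actually support."""
    valid = []
    for placements in candidate_placements:
        try:
            out = propagate_shape_and_sharding(
                placements, global_in_shape, view_rule, mesh_shape, strict_view=False
            )
        except AssertionError:
            # Recent nightlies assert when a sharded dim is not divisible by
            # the mesh dim size (e.g. unflatten nheads=48 on a mesh dim of
            # size 32). Skip this candidate; the replicated variant is
            # already produced by another strategy.
            continue
        valid.append((placements, out))
    return valid
```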

Authored with Claude.
@fmassa fmassa requested a review from wconstab March 26, 2026 14:39
@meta-cla meta-cla bot added the CLA Signed label Mar 26, 2026
…intwise_strategy

PyTorch removed pointwise_strategy from torch.distributed.tensor._ops._pointwise_ops in pytorch/pytorch#177208 as part of a dead code cleanup. Our custom native_layer_norm and native_layer_norm_backward rules imported this function, causing an ImportError on PyTorch nightlies after 2026-03-23.

This replaces the import with a local _pointwise_strategy that recomposes the same behavior from PyTorch's new single-dim strategy primitives (_common_pointwise_single_dim_strategy, _fill_single_dim_strategy_placeholders, expand_to_full_mesh_op_strategy).
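
A rough sketch of what that composition might look like; the import paths and the argument lists of the two single-dim primitives are assumptions inferred from their names, not their real signatures.

```python
# Import paths are assumed; the primitives live in PyTorch's private
# DTensor op-strategy modules on recent nightlies.
from torch.distributed.tensor._ops._pointwise_ops import (
    _common_pointwise_single_dim_strategy,
    _fill_single_dim_strategy_placeholders,
)
from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy


def _pointwise_strategy(mesh, op_schema):
    # Build the generic per-mesh-dim placement rows for a pointwise op,
    # then resolve any placeholder entries against the concrete inputs.
    rows = _common_pointwise_single_dim_strategy(op_schema)
    rows = _fill_single_dim_strategy_placeholders(rows, op_schema)
    # Expand the single-dim rows across the full mesh. input_index=1 keeps a
    # single output spec per row, matching the removed pointwise_strategy.
    return expand_to_full_mesh_op_strategy(mesh, op_schema, rows, input_index=1)
```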

Authored with Claude.
@fmassa fmassa changed the title "Work around upstream PyTorch view op assertion in propagation rules" to "Fix compatibility with latest PyTorch nightly" Mar 26, 2026
The issue was that _get_unique_placements from PyTorch's single-dim strategy path asserts each OpStrategy has exactly 1 strategy — but in autoparallel, input OpStrategy objects have multiple strategies (all possible placements). Replaced it with an inline loop that collects unique placements from all strategies across all inputs.
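
A sketch of that inline loop, assuming the usual OpStrategy/OpSpec layout; which spec field is inspected (here each strategy's output spec) is an assumption, and the helper name is illustrative.

```python
def _collect_unique_placements(input_strategies):
    """Gather the distinct placements across all strategies of all inputs."""
    seen = set()
    unique = []
    for op_strategy in input_strategies:          # one OpStrategy per input
        for strategy in op_strategy.strategies:   # autoparallel: many per input
            placements = strategy.output_spec.placements
            if placements not in seen:
                seen.add(placements)
                unique.append(placements)
    return unique
```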
)


def _pointwise_strategy(mesh, op_schema):
Contributor


can you remind me what the autoparallel preferred API is in general? i am wondering if there is a reason you need to do the work to chain together the single-dim strategy with the expander in autop.

Contributor Author


There is no strong preference on my side, actually. When I first wrote this rule, I found it easier to use the pointwise strategy to handle the cases I cared about, but I'm happy to change it.
The implementation of layer_norm in PyTorch is in fact missing support for input / weight having a different number of placements, so if we could just fix that upstream in PyTorch and then remove the implementation in autoparallel, I'd be happy as well.


# _get_unique_placements assumes each OpStrategy has exactly one strategy
# (the single-dim path). In autoparallel, inputs have multiple strategies,
# so we collect unique placements from all of them.
Contributor


hmm, is this just a bug upstream? should we also be collecting from all strategies there, and in the normal eager path it just happens that there's only 1? cc @anshul-si wdyt?



In the old pointwise_op code, it seemed like there were some features already implemented that would eventually let us support AutoParallel like this. I agree that this should probably be implemented upstream after we meet our DTensor op coverage goals.

except AssertionError:
# PyTorch may raise when a sharded dim is not divisible by the
# mesh size (e.g. unflatten nheads=48 on mesh dim size=32).
# With strict_view=False this should demote to Replicate, but
# the assertion fires first, so we skip this strategy; the
# replicated variant is already covered by another strategy.
Contributor


remind me- did we ever decouple views from reshapes and allow reshapes to do implicit redistribution, or did we just talk about that and in fact keep them the same?

does autoparallel actually want 'strict_view=False'? It should not be allowed to do redistributions on view ops either, should it? or does autop separately ensure that the root of a view has the same placement as the view itself?

Contributor Author


The issue is that AutoParallel inputs have many possible shardings, and some of them are invalid for certain types of views. We would like those shardings to be filtered out instead of raising an error during sharding propagation.

fmassa added 2 commits March 26, 2026 18:37
The fix: _common_pointwise_single_dim_strategy produces num_outputs output entries per strategy row (e.g., 3 for native_layer_norm's out/mean/rstd). Since all pointwise outputs share the same placement, we collapse to a single output entry before passing to expand_to_full_mesh_op_strategy with input_index=1. This way strategy.output_specs is a single DTensorSpec, matching the old pointwise_strategy behavior — the callers then construct the per-output specs (out, mean, rstd) themselves.
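
A sketch of that collapse step, assuming each single-dim strategy row is a flat placement list with the output entries first; the helper name is illustrative.

```python
def _collapse_output_entries(single_dim_rows, num_outputs):
    """Keep a single output entry per row, since pointwise outputs share it."""
    collapsed = []
    for row in single_dim_rows:
        # e.g. for native_layer_norm: [out, mean, rstd, *inputs] -> [out, *inputs]
        collapsed.append(row[:1] + row[num_outputs:])
    return collapsed
```

The collapsed rows are then handed to expand_to_full_mesh_op_strategy with input_index=1, so each resulting strategy carries a single DTensorSpec as its output_specs.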
…nfigurations in example_dcp

When reconstructing sharding placements from the serialized map for a different mesh, we were only updating .mesh on each DTensorSpec but leaving tensor_meta unchanged. Since the batch size scales with the mesh (bs = 8 * mesh.shape[0]), the tensor shapes differ between the two phases (256 vs 16 on dim 0), so the specs carried stale shapes from the original configuration.

This was previously harmless — PyTorch's _maybe_unpad_tensor only used logical_dim_size (derived from tensor_meta.shape) to compute the unpad amount, and with even sharding that was a no-op. PyTorch PR pytorch/pytorch#178210 added a torch._check(orig_size >= logical_dim_size) assertion that now catches the mismatch.

Fix: pull correct tensor_meta from the new AutoParallel instance's computed strategies, which reflect the actual shapes for the current mesh and batch size.
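
A minimal sketch of that fix, assuming the serialized map is keyed by parameter FQN and that the new AutoParallel instance's chosen specs are available per parameter; the names refresh_specs, sharding_map, and chosen_specs are illustrative, not the real example_dcp.py API.

```python
def refresh_specs(sharding_map, chosen_specs, new_mesh):
    """Rebind serialized DTensorSpecs to the current mesh and batch size."""
    for fqn, spec in sharding_map.items():
        spec.mesh = new_mesh  # previously the only field we updated
        # Also refresh tensor_meta so logical shapes match the current config;
        # stale shapes now trip torch._check(orig_size >= logical_dim_size)
        # in _maybe_unpad_tensor.
        spec.tensor_meta = chosen_specs[fqn].tensor_meta
    return sharding_map
```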

Authored with Claude.
@fmassa fmassa merged commit 157ab51 into main Mar 27, 2026
11 checks passed
@fmassa fmassa deleted the fmassa/fix_upstream_breakage branch March 27, 2026 10:46
pianpwk added a commit to pytorch/pytorch that referenced this pull request Apr 8, 2026
…m, RMSNorm FW/BW"

Removes op strategies for layernorm, RMS norm FWD/BWD, since they don't compose well with AutoParallel, in favor of single-dim strategies

I think this should fix meta-pytorch/autoparallel#142, and maybe allow us to delete the overrides in meta-pytorch/autoparallel#399, meta-pytorch/autoparallel#373
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Apr 9, 2026
Removes op strategies for layernorm, RMS norm FWD/BWD, since they don't compose well with AutoParallel, in favor of single-dim strategies

I think this should fix meta-pytorch/autoparallel#142, and maybe allow us to delete the overrides in meta-pytorch/autoparallel#399, meta-pytorch/autoparallel#373

Pull Request resolved: #179173
Approved by: https://github.com/zpcore