Fix compatibility with latest PyTorch nightly#399
Conversation
PyTorch PR pytorch/pytorch#177973 fixed a latent bug where the comparison `shard.dim == in_dim` was always `False` (it compared an `int` to an `InputDim`), making the double-shard submesh-size validation in `propagate_shape_and_sharding` dead code. The fix activated that validation, which now raises an `AssertionError` when a sharded dimension is not divisible by the mesh size (e.g. unflatten `nheads=48` on a mesh dim of size 32). This is overly strict for our use case: we call `propagate_shape_and_sharding` with `strict_view=False` to enumerate all possible sharding strategies, and incompatible ones should be silently skipped rather than raise. We now catch the `AssertionError` and skip the iteration, since the replicated variant is already covered by another strategy. Authored with Claude.
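The skip-on-failure pattern described above can be sketched as follows. This is an illustrative toy, not the actual autoparallel code: `enumerate_view_strategies` and `toy_propagate` are hypothetical stand-ins for the strategy-enumeration loop and for `propagate_shape_and_sharding`, which may raise `AssertionError` on an uneven shard.

```python
def enumerate_view_strategies(candidate_shardings, propagate):
    """Collect only the shardings that survive shape/sharding propagation.

    `propagate` stands in for propagate_shape_and_sharding, which may raise
    AssertionError when a sharded dim is not divisible by the mesh size
    (e.g. unflatten nheads=48 on a mesh dim of size 32).
    """
    valid = []
    for sharding in candidate_shardings:
        try:
            valid.append(propagate(sharding))
        except AssertionError:
            # Incompatible sharding: skip it. The replicated variant is
            # already covered by another strategy, so nothing is lost.
            continue
    return valid


# Toy propagation: a sharding passes only if the dim divides evenly.
def toy_propagate(sharding):
    dim_size, mesh_size = sharding
    assert dim_size % mesh_size == 0, "uneven shard"
    return sharding


print(enumerate_view_strategies([(48, 32), (64, 32)], toy_propagate))
# [(64, 32)]
```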
…intwise_strategy

PyTorch removed `pointwise_strategy` from `torch.distributed.tensor._ops._pointwise_ops` in pytorch/pytorch#177208 as part of a dead-code cleanup. Our custom `native_layer_norm` and `native_layer_norm_backward` rules imported this function, causing an `ImportError` on PyTorch nightlies after 2025-03-23. This replaces the import with a local `_pointwise_strategy` that recomposes the same behavior from PyTorch's new single-dim strategy primitives (`_common_pointwise_single_dim_strategy`, `_fill_single_dim_strategy_placeholders`, `expand_to_full_mesh_op_strategy`). Authored with Claude.
The issue was that `_get_unique_placements` from PyTorch's single-dim strategy path asserts each `OpStrategy` has exactly one strategy; in autoparallel, input `OpStrategy` objects have multiple strategies (all possible placements). Replaced it with an inline loop that collects unique placements from all strategies across all inputs.
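The replacement loop can be sketched like this. The nested-list shapes and the function name `collect_unique_placements` are hypothetical simplifications of `OpStrategy`/`PlacementStrategy`, used only to show the all-strategies union versus the one-strategy assumption:

```python
def collect_unique_placements(input_strategies):
    """Gather the union of placements seen across *all* strategies of all
    inputs. PyTorch's _get_unique_placements instead asserts each OpStrategy
    holds exactly one strategy, which happens to be true on the eager
    single-dim path but not in autoparallel.
    """
    unique = set()
    for op_strategy in input_strategies:   # one OpStrategy per input
        for strategy in op_strategy:       # many candidate strategies
            unique.update(strategy)        # this candidate's placements
    return unique


# Two inputs: the first has two candidate placements, the second one.
inputs = [
    [["Shard(0)"], ["Replicate"]],
    [["Replicate"]],
]
print(sorted(collect_unique_placements(inputs)))
# ['Replicate', 'Shard(0)']
```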
```python
def _pointwise_strategy(mesh, op_schema):
```
can you remind me what the autoparallel preferred API is in general? i am wondering if there is a reason you need to do the work to chain together the singledim strategy with the expander in autop.
There is no preference in here on my side actually. When I first wrote this rule, I found it easier to use the pointwise strategy to handle the cases I cared about, but I'm happy to change it.
The implementation of layer_norm in PyTorch is in fact missing support for input / weight to have different numbers of placements, so if we could just fix that upstream in PyTorch and then remove the implementation in autoparallel I'd be happy as well
```python
# _get_unique_placements assumes each OpStrategy has exactly one strategy
# (the single-dim path). In autoparallel, inputs have multiple strategies,
# so we collect unique placements from all of them.
```
hmm, is this just a bug in the upstream, should we also be collecting from all, and then in the normal eager path all just happens to be 1? cc @anshul-si wdyt?
In the previous old pointwise_op code, it seemed like there were some implemented features that would eventually enable us to support auto parallel like this. I agree that this should be implemented upstream after we meet our dtensor op coverage goals?
```python
except AssertionError:
    # PyTorch may raise when a sharded dim is not divisible by the
    # mesh size (e.g. unflatten nheads=48 on mesh dim size=32).
    # With strict_view=False this should demote to Replicate, but
```
remind me- did we ever decouple views from reshapes and allow reshapes to do implicit redistribution, or did we just talk about that and in fact keep them the same?
does autoparallel actually want 'strict_view=False'? It should not be allowed to do redistributions on view ops either, should it? or does autop separately ensure that the root of a view has the same placement as the view itself?
The issue is that AutoParallel inputs have many possible shardings, and some of them are invalid for some types of views. We would like those shardings to be filtered out instead of raising an error during sharding propagation
The fix: `_common_pointwise_single_dim_strategy` produces `num_outputs` output entries per strategy row (e.g., 3 for `native_layer_norm`'s out/mean/rstd). Since all pointwise outputs share the same placement, we collapse to a single output entry before passing to `expand_to_full_mesh_op_strategy` with `input_index=1`. This way `strategy.output_specs` is a single `DTensorSpec`, matching the old `pointwise_strategy` behavior; the callers then construct the per-output specs (out, mean, rstd) themselves.
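The collapse step can be sketched as below. The row layout and the helper name `collapse_output_entries` are hypothetical; real strategy rows hold placement tuples per `DTensorSpec`, but the shape of the transformation is the same:

```python
def collapse_output_entries(strategy_rows, num_outputs):
    """Each row is [out_0, ..., out_{n-1}, in_0, in_1, ...]. Since all
    pointwise outputs share the same placement, keep only out_0 so the
    expanded strategy carries a single output spec (input_index=1).
    """
    collapsed = []
    for row in strategy_rows:
        outputs, inputs = row[:num_outputs], row[num_outputs:]
        assert len(set(outputs)) == 1, "pointwise outputs must agree"
        collapsed.append([outputs[0], *inputs])
    return collapsed


# native_layer_norm: 3 outputs (out, mean, rstd) with identical placement,
# followed by 2 input placements.
rows = [["Shard(0)", "Shard(0)", "Shard(0)", "Shard(0)", "Replicate"]]
print(collapse_output_entries(rows, num_outputs=3))
# [['Shard(0)', 'Shard(0)', 'Replicate']]
```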
…nfigurations in example_dcp

When reconstructing sharding placements from the serialized map for a different mesh, we were only updating `.mesh` on each `DTensorSpec` but leaving `tensor_meta` unchanged. Since the batch size scales with the mesh (`bs = 8 * mesh.shape[0]`), the tensor shapes differ between the two phases (256 vs 16 on dim 0), so the specs carried stale shapes from the original configuration. This was previously harmless: PyTorch's `_maybe_unpad_tensor` only used `logical_dim_size` (derived from `tensor_meta.shape`) to compute the unpad amount, and with even sharding that was a no-op. PyTorch PR pytorch/pytorch#178210 added a `torch._check(orig_size >= logical_dim_size)` assertion that now catches the mismatch. Fix: pull the correct `tensor_meta` from the new `AutoParallel` instance's computed strategies, which reflect the actual shapes for the current mesh and batch size. Authored with Claude.
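The shape-staleness fix can be sketched with plain dicts. `reapply_specs` and the dict layout are hypothetical stand-ins for the serialized placement map and the new instance's computed strategies; the point is that both `.mesh` and `tensor_meta` must be refreshed, not just the mesh:

```python
def reapply_specs(serialized_specs, new_mesh, new_strategies):
    """Update both the mesh and the tensor_meta on each spec. Copying only
    the mesh leaves shapes from the original run (e.g. bs=256) attached to
    specs used with the new run's shapes (e.g. bs=16).
    """
    fixed = {}
    for name, spec in serialized_specs.items():
        spec = dict(spec)
        spec["mesh"] = new_mesh
        # Pull the tensor_meta computed for the *current* mesh/batch size.
        spec["tensor_meta"] = new_strategies[name]["tensor_meta"]
        fixed[name] = spec
    return fixed


old = {"x": {"mesh": (32, 8), "tensor_meta": {"shape": (256, 64)}}}
new_strats = {"x": {"tensor_meta": {"shape": (16, 64)}}}
print(reapply_specs(old, (2, 2), new_strats))
# {'x': {'mesh': (2, 2), 'tensor_meta': {'shape': (16, 64)}}}
```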
…m, RMSNorm FW/BW"

Removes op strategies for layernorm and RMS norm FWD/BWD, since they don't compose well with AutoParallel, in favor of single-dim strategies. I think this should fix meta-pytorch/autoparallel#142, and maybe allow us to delete the overrides in meta-pytorch/autoparallel#399 and meta-pytorch/autoparallel#373.
Pull Request resolved: #179173
Approved by: https://github.com/zpcore
Three recent upstream PyTorch changes broke autoparallel. This PR addresses all of them:
1. **Removed `pointwise_strategy`** (pytorch/pytorch#177208): our `native_layer_norm` and `native_layer_norm_backward` rules imported this function. Replaced with a local `_pointwise_strategy` built on PyTorch's new single-dim strategy primitives (`_common_pointwise_single_dim_strategy`, `_fill_single_dim_strategy_placeholders`, `expand_to_full_mesh_op_strategy`).
2. **Activated view validation** (pytorch/pytorch#177973): a comparison bug previously made the double-shard validation in `propagate_shape_and_sharding` dead code. The fix now raises when a sharded dim is not divisible by the mesh size, which is too strict for our use case: we enumerate all strategies with `strict_view=False` and skip incompatible ones. We now catch the `AssertionError` and skip the iteration.
3. **`_maybe_unpad_tensor` assertion** ([Bugfix] Fix by copying sym shapes as needed, pytorch/pytorch#178210): added `torch._check(orig_size >= logical_dim_size)` for symbolic shape binding. This exposed a pre-existing bug in `example_dcp.py` where serialized sharding specs carried stale `tensor_meta` (from `bs=256` on a `(32,8)` mesh) that was never updated when reapplying on a `(2,2)` mesh with `bs=16`. Fixed by pulling the correct `tensor_meta` from the new `AutoParallel` instance's computed strategies.

Review order: `propagation_rules.py` first (the `_pointwise_strategy` function, then the view rule change), then `example_dcp.py`.

Authored with Claude.