shard.dim == in_dim compares an int to an InputDim dataclass instance, so it is always False. This makes the [Shard(0), Shard(0)] submesh check dead code, and invalid configs are silently accepted.
https://github.com/pytorch/pytorch/blob/d428a3f9c9e/torch/distributed/tensor/_ops/_view_ops.py#L676
# Silently returns an invalid sharding instead of erroring
import torch
from torch.distributed.tensor._ops._view_ops import propagate_shape_and_sharding, dim_maps
from torch.distributed.tensor.placement_types import Shard

result = propagate_shape_and_sharding(
    [Shard(0), Shard(0)], (12,),
    dim_maps[torch.Tensor.view](torch.empty(12), [3, 4]),
    (2, 3),  # split dim of size 3 is not divisible by the submesh size 2*3 = 6
)
print(result)  # no error is raised
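For reference, a minimal self-contained sketch of the type mismatch. The InputDim below is a local stand-in for the dataclass in _view_ops.py (field name assumed), and the second comparison is only a guess at what the check intended:

from dataclasses import dataclass

# Local stand-in for the InputDim DimSpec in _view_ops.py (input_dim field name assumed)
@dataclass
class InputDim:
    input_dim: int

shard_dim = 0                   # Shard(0).dim is a plain int
in_dim = InputDim(input_dim=0)  # the dim-map side of the comparison is an InputDim object

print(shard_dim == in_dim)            # False for every input: int vs. dataclass, so the guard never fires
print(shard_dim == in_dim.input_dim)  # True; presumably the comparison the guard intended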
cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @xmfan @tianyu-l @XilunWu @SherlockNoMad @ppwwyyxx