
[DTensor] Dead double-shard validation in propagate_shape_and_sharding #177972

@stmcgovern

Description


`shard.dim == in_dim` compares an `int` to an `InputDim` dataclass instance, so the comparison is always `False`. That makes the `[Shard(0), Shard(0)]` submesh divisibility check dead code: invalid sharding configurations are silently accepted instead of raising.

https://github.com/pytorch/pytorch/blob/d428a3f9c9e/torch/distributed/tensor/_ops/_view_ops.py#L676
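For context, a paraphrased sketch of the check at the linked line (not verbatim; the names `mesh_sizes`, `input_src_placements`, and `out_size` are assumed from the surrounding function, and `InputDim` is the dataclass defined elsewhere in the same file with an `input_dim: int` field):

```python
# Paraphrased from _view_ops.py near the linked line. `in_dim` is an
# InputDim dataclass instance, while `shard.dim` is an int.
submesh_size = 1
for size, shard in zip(mesh_sizes, input_src_placements):
    if isinstance(shard, Shard) and shard.dim == in_dim:  # int == InputDim: always False
        submesh_size *= size
# submesh_size therefore stays 1, so this divisibility assert can never fire:
assert out_size % submesh_size == 0

# Likely fix: compare against the dataclass's integer field instead, e.g.
#     shard.dim == in_dim.input_dim
```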

```python
# Repro: silently returns an invalid sharding instead of erroring.
import torch
from torch.distributed.tensor._ops._view_ops import dim_maps, propagate_shape_and_sharding
from torch.distributed.tensor.placement_types import Shard

propagate_shape_and_sharding(
    [Shard(0), Shard(0)],  # both mesh dims shard input dim 0
    (12,),                 # input shape
    dim_maps[torch.Tensor.view](torch.empty(12), [3, 4]),  # view (12,) -> (3, 4)
    (2, 3),                # split dim of size 3 is not divisible by submesh 2*3=6
)
```
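With the comparison fixed, this call should trip the divisibility assert (output dim 0 has size 3 but is sharded across a submesh of size 2*3=6) rather than returning a sharding.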

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @xmfan @tianyu-l @XilunWu @SherlockNoMad @ppwwyyxx
