Skip to content

Commit b8d53c6

Browse files
aorenste authored and pytorchmergebot committed
Fix type annotation for _sym_get_coordinate (#177446)
Pull Request resolved: #177446. Approved by: https://github.com/Skylion007. ghstack dependencies: #172795.
1 parent 42244bc commit b8d53c6

File tree

5 files changed

+44
-30
lines changed

5 files changed

+44
-30
lines changed

torch/distributed/device_mesh.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from torch.distributed import is_available
1313
from torch.distributed._mesh_layout import _MeshLayout
1414
from torch.distributed._pycute import IntTuple, is_int, suffix_product
15+
from torch.types import IntLikeType
1516
from torch.utils._typing_utils import not_none
1617

1718

@@ -1220,7 +1221,7 @@ def get_coordinate(self) -> tuple[int, ...] | None:
12201221
"""
12211222
return self._coordinate_on_dim
12221223

1223-
def _sym_get_coordinate(self, index: int) -> int:
1224+
def _sym_get_coordinate(self, index: int) -> IntLikeType:
12241225
import torch.distributed.config as config
12251226
from torch._guards import detect_fake_mode
12261227

torch/distributed/tensor/_collective_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
scatter,
2525
Work,
2626
)
27+
from torch.types import IntLikeType
2728

2829

2930
logger = logging.getLogger(__name__)
@@ -191,7 +192,9 @@ def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tenso
191192

192193

193194
@maybe_run_for_local_tensor
194-
def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
195+
def unpad_tensor(
196+
tensor: torch.Tensor, pad_dim: int, pad_size: IntLikeType
197+
) -> torch.Tensor:
195198
from torch.fx.experimental.symbolic_shapes import guard_or_false
196199

197200
if guard_or_false(pad_size == 0):

torch/distributed/tensor/_random.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from torch.distributed.device_mesh import _get_device_handle, DeviceMesh
1212
from torch.distributed.tensor._dtensor_spec import DTensorSpec
1313
from torch.distributed.tensor.placement_types import _StridedShard, Shard
14+
from torch.types import IntLikeType
1415

1516

1617
logger = getLogger(__name__)
@@ -391,8 +392,8 @@ def _compute_rng_offsets(self, spec: DTensorSpec) -> tuple[int, int]:
391392
return start_offset_incr, end_offset_incr
392393

393394
def _calc_shard_linear_idx(
394-
self, shard_coord: list[int], shard_size: list[int]
395-
) -> int:
395+
self, shard_coord: Sequence[IntLikeType], shard_size: Sequence[IntLikeType]
396+
) -> IntLikeType:
396397
return _calc_shard_linear_idx(shard_coord, shard_size)
397398

398399

@@ -411,8 +412,8 @@ def _calc_first_shard_size(spec: DTensorSpec) -> list[int]:
411412

412413

413414
def _calc_shard_info(
414-
mesh_coordinate: Sequence[int], spec: DTensorSpec
415-
) -> tuple[list[int], list[int]]:
415+
mesh_coordinate: Sequence[IntLikeType], spec: DTensorSpec
416+
) -> tuple[list[IntLikeType], list[IntLikeType]]:
416417
mesh = spec.mesh
417418
# note: dim_map does not allow double sharding which is the FSDP(fully_shard)+TP
418419
# case. Replace the custom logic with dim_map once we support it.
@@ -436,10 +437,12 @@ def _calc_shard_info(
436437
raise AssertionError
437438
mesh_size = mesh.shape
438439
shard_idx_by_dim = []
439-
total_num_shards_by_dim = [] # total number of shards on each tensor dim
440+
total_num_shards_by_dim: list[
441+
IntLikeType
442+
] = [] # total number of shards on each tensor dim
440443
for mesh_dim in dim_map:
441-
shard_idx = 0
442-
total_num_shards = 1
444+
shard_idx: IntLikeType = 0
445+
total_num_shards: IntLikeType = 1
443446
# the tensor dim is sharded on more than 1 mesh dim
444447
if isinstance(mesh_dim, list):
445448
rank_coord = [mesh_coordinate[d] for d in mesh_dim]
@@ -454,10 +457,12 @@ def _calc_shard_info(
454457
return shard_idx_by_dim, total_num_shards_by_dim
455458

456459

457-
def _calc_shard_linear_idx(shard_coord: list[int], shard_size: list[int]) -> int:
460+
def _calc_shard_linear_idx(
461+
shard_coord: Sequence[IntLikeType], shard_size: Sequence[IntLikeType]
462+
) -> IntLikeType:
458463
# compute shard linear index
459-
shard_linear_idx = 0
460-
shard_coord_stride = 1
464+
shard_linear_idx: IntLikeType = 0
465+
shard_coord_stride: IntLikeType = 1
461466
for idx, size in zip(reversed(shard_coord), reversed(shard_size)):
462467
shard_linear_idx += idx * shard_coord_stride
463468
shard_coord_stride *= size

torch/distributed/tensor/_redistribute.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
Replicate,
3232
Shard,
3333
)
34+
from torch.types import IntLikeType
3435
from torch.utils._debug_mode import get_active_debug_mode
3536

3637

@@ -144,7 +145,7 @@ class _TransformInfo:
144145
mesh_dim: int
145146
src_dst_placements: tuple[Placement, Placement]
146147
# logical_shape on this mesh dimension
147-
logical_shape: list[int]
148+
logical_shape: Sequence[IntLikeType]
148149

149150
def __post_init__(self):
150151
if self.mesh_dim < 0:
@@ -1176,8 +1177,8 @@ def get_logical_shape(
11761177
src_state: "DTensorRedistributePlanner.DistState",
11771178
mesh_dim: int,
11781179
full_tensor_shape: tuple[int, ...],
1179-
) -> list[int]:
1180-
new_logical_shape = list(full_tensor_shape)
1180+
) -> list[IntLikeType]:
1181+
new_logical_shape: list[IntLikeType] = list(full_tensor_shape)
11811182
for entry in src_state.tensor_dim_to_mesh_dim:
11821183
tensor_dim = entry.tensor_dim
11831184
mesh_dims = entry.mesh_dims

torch/distributed/tensor/placement_types.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Copyright (c) Meta Platforms, Inc. and affiliates
33

44
import functools
5+
from collections.abc import Sequence
56
from dataclasses import dataclass, field
67
from typing import cast, TypeVar
78

@@ -21,6 +22,7 @@
2122
unpad_tensor,
2223
)
2324
from torch.distributed.tensor._ops._mask_buffer import MaskBuffer
25+
from torch.types import IntLikeType
2426

2527

2628
__all__ = ["Placement", "Shard", "Replicate", "Partial"]
@@ -211,7 +213,7 @@ def _custom_chunk(
211213
@staticmethod
212214
@maybe_run_for_local_tensor
213215
def local_shard_size_and_offset(
214-
curr_local_size: int,
216+
curr_local_size: IntLikeType,
215217
num_chunks: int,
216218
rank: _RankTypeT,
217219
) -> tuple[_RankTypeT, _RankTypeT]:
@@ -392,7 +394,7 @@ def _reduce_shard_tensor(
392394
def _maybe_pad_tensor(
393395
self,
394396
local_tensor: torch.Tensor,
395-
logical_dim_size: int,
397+
logical_dim_size: IntLikeType,
396398
num_chunks: int,
397399
) -> torch.Tensor:
398400
from torch.fx.experimental.symbolic_shapes import guard_or_true
@@ -414,7 +416,7 @@ def _maybe_pad_tensor(
414416
def _maybe_unpad_tensor(
415417
self,
416418
local_tensor: torch.Tensor,
417-
logical_dim_size: int,
419+
logical_dim_size: IntLikeType,
418420
num_chunks: int,
419421
) -> torch.Tensor:
420422
from torch.fx.experimental.symbolic_shapes import guard_or_true
@@ -434,7 +436,7 @@ def _to_replicate_tensor(
434436
local_tensor: torch.Tensor,
435437
mesh: DeviceMesh,
436438
mesh_dim: int,
437-
current_logical_shape: list[int],
439+
current_logical_shape: Sequence[IntLikeType],
438440
) -> torch.Tensor:
439441
"""
440442
This function all_gather all shards and return a tensor that
@@ -462,7 +464,7 @@ def _replicate_to_shard(
462464
local_tensor: torch.Tensor,
463465
mesh: DeviceMesh,
464466
mesh_dim: int,
465-
shard_index: int,
467+
shard_index: IntLikeType,
466468
) -> torch.Tensor:
467469
"""
468470
transform from replicated tensor to a sharded tensor on
@@ -489,11 +491,11 @@ def _get_shard_pad_size(
489491

490492
@staticmethod
491493
def _compute_padding_info(
492-
current_logical_shape: list[int],
494+
current_logical_shape: Sequence[IntLikeType],
493495
num_chunks: int,
494496
old_shard_dim: int,
495497
new_shard_dim: int,
496-
) -> tuple[bool, int, int, bool, int, int]:
498+
) -> tuple[bool, IntLikeType, int, bool, IntLikeType, int]:
497499
from torch.fx.experimental.symbolic_shapes import guard_or_true
498500

499501
results = []
@@ -508,7 +510,7 @@ def _compute_padding_info(
508510
@staticmethod
509511
@maybe_run_for_local_tensor
510512
def _pad_for_new_shard_dim(
511-
current_logical_shape: list[int],
513+
current_logical_shape: Sequence[IntLikeType],
512514
local_tensor: torch.Tensor,
513515
num_chunks: int,
514516
old_shard_dim: int,
@@ -543,7 +545,7 @@ def _pad_for_new_shard_dim(
543545
@staticmethod
544546
@maybe_run_for_local_tensor
545547
def _unpad_for_new_shard_dim(
546-
current_logical_shape: list[int],
548+
current_logical_shape: Sequence[IntLikeType],
547549
local_tensor: torch.Tensor,
548550
num_chunks: int,
549551
old_shard_dim: int,
@@ -582,7 +584,7 @@ def _to_new_shard_dim(
582584
local_tensor: torch.Tensor,
583585
mesh: DeviceMesh,
584586
mesh_dim: int,
585-
current_logical_shape: list[int],
587+
current_logical_shape: Sequence[IntLikeType],
586588
new_shard_dim: int,
587589
) -> torch.Tensor:
588590
"""
@@ -857,7 +859,7 @@ def _select_split_tensor(
857859
self,
858860
tensor: torch.Tensor,
859861
num_chunks: int,
860-
index: int,
862+
index: IntLikeType,
861863
*,
862864
with_padding: bool = True,
863865
contiguous: bool = True,
@@ -891,7 +893,7 @@ def _to_replicate_tensor(
891893
local_tensor: torch.Tensor,
892894
mesh: DeviceMesh,
893895
mesh_dim: int,
894-
current_logical_shape: list[int],
896+
current_logical_shape: Sequence[IntLikeType],
895897
) -> torch.Tensor:
896898
"""
897899
Replay the replicate-to-shard process to understand how to stitch shards back.
@@ -1050,7 +1052,7 @@ def _replicate_to_strided_shard(
10501052
local_tensor: torch.Tensor,
10511053
mesh: DeviceMesh,
10521054
mesh_dim: int,
1053-
shard_index: int,
1055+
shard_index: IntLikeType,
10541056
) -> torch.Tensor:
10551057
"""
10561058
Transform from replicated tensor to a strided-sharded tensor on the current rank.
@@ -1097,7 +1099,7 @@ def _local_shard_size_and_offset(
10971099
@maybe_run_for_local_tensor
10981100
def local_shard_size_and_offset(
10991101
self,
1100-
curr_local_size: int,
1102+
curr_local_size: IntLikeType,
11011103
num_chunks: int,
11021104
rank: RankType,
11031105
return_first_offset: bool = True,
@@ -1384,7 +1386,9 @@ def __init__(
13841386
@staticmethod
13851387
@maybe_run_for_local_tensor
13861388
def _mask_tensor(
1387-
tensor: torch.Tensor, local_offset_on_dim: int, local_shard_size: int
1389+
tensor: torch.Tensor,
1390+
local_offset_on_dim: IntLikeType,
1391+
local_shard_size: IntLikeType,
13881392
) -> tuple[torch.Tensor, torch.Tensor]:
13891393
# Build the input mask and save it for the current partial placement
13901394
# this is so that the output of embedding op can reuse the same partial

0 commit comments

Comments (0)