22# Copyright (c) Meta Platforms, Inc. and affiliates
33
44import functools
5+ from collections .abc import Sequence
56from dataclasses import dataclass , field
67from typing import cast , TypeVar
78
2122 unpad_tensor ,
2223)
2324from torch .distributed .tensor ._ops ._mask_buffer import MaskBuffer
25+ from torch .types import IntLikeType
2426
2527
2628__all__ = ["Placement" , "Shard" , "Replicate" , "Partial" ]
@@ -211,7 +213,7 @@ def _custom_chunk(
211213 @staticmethod
212214 @maybe_run_for_local_tensor
213215 def local_shard_size_and_offset (
214- curr_local_size : int ,
216+ curr_local_size : IntLikeType ,
215217 num_chunks : int ,
216218 rank : _RankTypeT ,
217219 ) -> tuple [_RankTypeT , _RankTypeT ]:
@@ -392,7 +394,7 @@ def _reduce_shard_tensor(
392394 def _maybe_pad_tensor (
393395 self ,
394396 local_tensor : torch .Tensor ,
395- logical_dim_size : int ,
397+ logical_dim_size : IntLikeType ,
396398 num_chunks : int ,
397399 ) -> torch .Tensor :
398400 from torch .fx .experimental .symbolic_shapes import guard_or_true
@@ -414,7 +416,7 @@ def _maybe_pad_tensor(
414416 def _maybe_unpad_tensor (
415417 self ,
416418 local_tensor : torch .Tensor ,
417- logical_dim_size : int ,
419+ logical_dim_size : IntLikeType ,
418420 num_chunks : int ,
419421 ) -> torch .Tensor :
420422 from torch .fx .experimental .symbolic_shapes import guard_or_true
@@ -434,7 +436,7 @@ def _to_replicate_tensor(
434436 local_tensor : torch .Tensor ,
435437 mesh : DeviceMesh ,
436438 mesh_dim : int ,
437- current_logical_shape : list [ int ],
439+ current_logical_shape : Sequence [ IntLikeType ],
438440 ) -> torch .Tensor :
439441 """
440442 This function all_gather all shards and return a tensor that
@@ -462,7 +464,7 @@ def _replicate_to_shard(
462464 local_tensor : torch .Tensor ,
463465 mesh : DeviceMesh ,
464466 mesh_dim : int ,
465- shard_index : int ,
467+ shard_index : IntLikeType ,
466468 ) -> torch .Tensor :
467469 """
468470 transform from replicated tensor to a sharded tensor on
@@ -489,11 +491,11 @@ def _get_shard_pad_size(
489491
490492 @staticmethod
491493 def _compute_padding_info (
492- current_logical_shape : list [ int ],
494+ current_logical_shape : Sequence [ IntLikeType ],
493495 num_chunks : int ,
494496 old_shard_dim : int ,
495497 new_shard_dim : int ,
496- ) -> tuple [bool , int , int , bool , int , int ]:
498+ ) -> tuple [bool , IntLikeType , int , bool , IntLikeType , int ]:
497499 from torch .fx .experimental .symbolic_shapes import guard_or_true
498500
499501 results = []
@@ -508,7 +510,7 @@ def _compute_padding_info(
508510 @staticmethod
509511 @maybe_run_for_local_tensor
510512 def _pad_for_new_shard_dim (
511- current_logical_shape : list [ int ],
513+ current_logical_shape : Sequence [ IntLikeType ],
512514 local_tensor : torch .Tensor ,
513515 num_chunks : int ,
514516 old_shard_dim : int ,
@@ -543,7 +545,7 @@ def _pad_for_new_shard_dim(
543545 @staticmethod
544546 @maybe_run_for_local_tensor
545547 def _unpad_for_new_shard_dim (
546- current_logical_shape : list [ int ],
548+ current_logical_shape : Sequence [ IntLikeType ],
547549 local_tensor : torch .Tensor ,
548550 num_chunks : int ,
549551 old_shard_dim : int ,
@@ -582,7 +584,7 @@ def _to_new_shard_dim(
582584 local_tensor : torch .Tensor ,
583585 mesh : DeviceMesh ,
584586 mesh_dim : int ,
585- current_logical_shape : list [ int ],
587+ current_logical_shape : Sequence [ IntLikeType ],
586588 new_shard_dim : int ,
587589 ) -> torch .Tensor :
588590 """
@@ -857,7 +859,7 @@ def _select_split_tensor(
857859 self ,
858860 tensor : torch .Tensor ,
859861 num_chunks : int ,
860- index : int ,
862+ index : IntLikeType ,
861863 * ,
862864 with_padding : bool = True ,
863865 contiguous : bool = True ,
@@ -891,7 +893,7 @@ def _to_replicate_tensor(
891893 local_tensor : torch .Tensor ,
892894 mesh : DeviceMesh ,
893895 mesh_dim : int ,
894- current_logical_shape : list [ int ],
896+ current_logical_shape : Sequence [ IntLikeType ],
895897 ) -> torch .Tensor :
896898 """
897899 Replay the replicate-to-shard process to understand how to stitch shards back.
@@ -1050,7 +1052,7 @@ def _replicate_to_strided_shard(
10501052 local_tensor : torch .Tensor ,
10511053 mesh : DeviceMesh ,
10521054 mesh_dim : int ,
1053- shard_index : int ,
1055+ shard_index : IntLikeType ,
10541056 ) -> torch .Tensor :
10551057 """
10561058 Transform from replicated tensor to a strided-sharded tensor on the current rank.
@@ -1097,7 +1099,7 @@ def _local_shard_size_and_offset(
10971099 @maybe_run_for_local_tensor
10981100 def local_shard_size_and_offset (
10991101 self ,
1100- curr_local_size : int ,
1102+ curr_local_size : IntLikeType ,
11011103 num_chunks : int ,
11021104 rank : RankType ,
11031105 return_first_offset : bool = True ,
@@ -1384,7 +1386,9 @@ def __init__(
13841386 @staticmethod
13851387 @maybe_run_for_local_tensor
13861388 def _mask_tensor (
1387- tensor : torch .Tensor , local_offset_on_dim : int , local_shard_size : int
1389+ tensor : torch .Tensor ,
1390+ local_offset_on_dim : IntLikeType ,
1391+ local_shard_size : IntLikeType ,
13881392 ) -> tuple [torch .Tensor , torch .Tensor ]:
13891393 # Build the input mask and save it for the current partial placement
13901394 # this is so that the output of embedding op can reuse the same partial
0 commit comments