[dtensor] have DTensorSpec report how many shards on each tensor dimension #130587
XilunWu wants to merge 2 commits into gh/XilunWu/89/base
Conversation
[dtensor] have DTensorSpec report how many shards on each tensor dimension [ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130587
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit bcdbc23 with merge base dc7725c. The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Update on "[dtensor] have DTensorSpec report how many shards on each tensor dimension"

**Summary** Add a new property `num_shards_map` to `DTensorSpec` denoting how many shards each tensor dimension has. This is necessary for constructing the `_StridedShard` placement when we call `distribute_tensor(dtensor_tp, dp_device_mesh, [Shard(0)])`: the `split_factor` argument is just the number of shards on the tensor dim being sharded.

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o [ghstack-poisoned]
```python
@property
def num_shards_map(self) -> List[int]:
    """
    dim_map is a property we derive from `placements` of
```
nit: I think we may prefer to start the comment with what `num_shards_map` is directly (at least at a high level) and then compare it with `dim_map`. Or do you think it is a requirement for users to know what `dim_map` is first, before understanding `num_shards_map`?
```python
    For example, we have a dist tensor of shape [18, 20, 30],
    a device_mesh ([[0, 1, 2, 3], [4, 5, 6, 7]]), and placements
    ([Shard(1), Shard(0)]); the num_shards_map of this distributed tensor
    would be: [4, 2, 1].
```
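The quoted example can be checked with a small standalone helper. This is a hypothetical sketch that mirrors the property's logic in plain Python (list-based stand-ins for `DeviceMesh` and `Shard` placements), not the actual `DTensorSpec` implementation:

```python
from typing import List, Optional

def num_shards_map(ndim: int, mesh_shape: List[int],
                   shard_dims: List[Optional[int]]) -> List[int]:
    """Count how many shards each tensor dimension has.

    shard_dims[i] is the tensor dim sharded on mesh dim i,
    or None for a Replicate placement (a stand-in for real
    Placement objects).
    """
    counts = [1] * ndim
    for mesh_dim, tensor_dim in enumerate(shard_dims):
        if tensor_dim is not None:
            # each Shard placement multiplies the shard count on that
            # tensor dim by the size of the mesh dim it lives on
            counts[tensor_dim] *= mesh_shape[mesh_dim]
    return counts

# mesh [[0, 1, 2, 3], [4, 5, 6, 7]] has shape [2, 4];
# placements [Shard(1), Shard(0)] become shard_dims [1, 0]
print(num_shards_map(3, [2, 4], [1, 0]))  # [4, 2, 1]
```

Mesh dim 0 (size 2) shards tensor dim 1 and mesh dim 1 (size 4) shards tensor dim 0, which reproduces the `[4, 2, 1]` result from the docstring.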
Can we also add a test? Maybe we can just use this example as the test.
```python
    For example, we have a dist tensor of shape [18, 20, 30],
    a device_mesh ([[0, 1, 2, 3], [4, 5, 6, 7]]), and placements
    ([Shard(1), Shard(0)]); the num_shards_map of this distributed tensor
    would be: [4, 2, 1].
```
This is great! Could we add one more example showing that when a tensor dim is sharded multiple times, the shard count is computed globally? I think this is the information that `dim_map` is not able to capture but `num_shards_map` can.
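The multi-shard case the reviewer asks about can be illustrated with a toy helper. This is a hypothetical pure-Python mock of the property (`shard_dims[i]` is the tensor dim sharded on mesh dim `i`, `None` for Replicate), not the real implementation:

```python
from typing import List, Optional

def num_shards_map(ndim: int, mesh_shape: List[int],
                   shard_dims: List[Optional[int]]) -> List[int]:
    # shard counts start at 1 (unsharded) for every tensor dim
    counts = [1] * ndim
    for mesh_dim, tensor_dim in enumerate(shard_dims):
        if tensor_dim is not None:
            # repeated sharding of the same tensor dim multiplies
            # the counts, so the result is a global shard count
            counts[tensor_dim] *= mesh_shape[mesh_dim]
    return counts

# placements [Shard(0), Shard(0)] on a [2, 4] mesh: tensor dim 0
# is sharded by both mesh dims, giving 2 * 4 = 8 shards globally
print(num_shards_map(3, [2, 4], [0, 0]))  # [8, 1, 1]
```

A tensor-dim-to-mesh-dim `dim_map` can only record one mesh dim per tensor dim, so it cannot represent this doubly-sharded state; the multiplied count in `num_shards_map` can.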
wanchaol left a comment:
LGTM. Could we merge this PR together with the PR that actually uses it (i.e. the FSDP2 integration PR)? That way we can add an end-to-end test to cover this code.
[dtensor] have DTensorSpec report how many shards on each tensor dimension

ghstack-source-id: 82263e6
Pull Request resolved: pytorch/pytorch#130587
Stack from ghstack (oldest at bottom):
Summary
Add a new property `num_shards_map` to `DTensorSpec` denoting how many shards each tensor dimension has. This is necessary for constructing the `_StridedShard` placement when we call `distribute_tensor(dtensor_tp, dp_device_mesh, [Shard(0)])`: the `split_factor` argument is just the number of shards on the tensor dim being sharded.

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o
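A rough sketch of how the `split_factor` mentioned in the summary could be read off `num_shards_map`. The variable names and the assumed starting state are hypothetical, and the actual `_StridedShard` construction in the FSDP2 integration may differ:

```python
# Suppose a TP-sharded DTensor reports num_shards_map == [2, 1, 1],
# i.e. tensor dim 0 is already split in 2 across the TP mesh (assumed
# state for illustration). When distribute_tensor(dtensor_tp,
# dp_device_mesh, [Shard(0)]) shards dim 0 again, the _StridedShard
# placement's split_factor is just the pre-existing shard count on
# that dim:
num_shards_map = [2, 1, 1]  # assumed: reported by the TP DTensorSpec
shard_dim = 0               # the dim the DP mesh is about to shard
split_factor = num_shards_map[shard_dim]
print(split_factor)  # 2
```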