
[dtensor] have DTensorSpec report how many shards on each tensor dimension#130587

Closed
XilunWu wants to merge 2 commits into gh/XilunWu/89/base from gh/XilunWu/89/head

Conversation

XilunWu (Contributor) commented Jul 11, 2024

Stack from ghstack (oldest at bottom):

Summary
Add a new property `num_shards_map` to `DTensorSpec` that reports how many shards each tensor dimension has. This is needed to construct the `_StridedShard` placement when calling `distribute_tensor(dtensor_tp, dp_device_mesh, [Shard(0)])`: the `split_factor` argument is simply the number of shards on the sharded tensor dimension.
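The idea behind the new property can be sketched in plain Python. This is a hypothetical helper, not the actual `DTensorSpec` implementation: it derives a `num_shards_map`-style list from a mesh shape and, for each mesh dim, the tensor dim it shards (`None` standing in for a replicated or partial placement).

```python
from typing import List, Optional

def num_shards_map(ndim: int, mesh_shape: List[int],
                   shard_dims: List[Optional[int]]) -> List[int]:
    # Hypothetical sketch, not DTensorSpec's code: one entry per
    # tensor dim, counting how many shards that dim is split into.
    counts = [1] * ndim
    for mesh_size, tensor_dim in zip(mesh_shape, shard_dims):
        if tensor_dim is not None:
            # Shard counts multiply when one tensor dim is sharded
            # on more than one mesh dim.
            counts[tensor_dim] *= mesh_size
    return counts

# Example from the review discussion: a 3-D tensor on a (2, 4) mesh
# with placements [Shard(1), Shard(0)].
print(num_shards_map(3, [2, 4], [1, 0]))  # [4, 2, 1]
```

With this list in hand, the `split_factor` for `_StridedShard` on a given tensor dim would just be read off as `counts[dim]`.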

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o


pytorch-bot bot commented Jul 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130587

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit bcdbc23 with merge base dc7725c:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@property
def num_shards_map(self) -> List[int]:
    """
    dim_map is a property we derive from `placements` of
Collaborator review comment:
nit: I think we may prefer to start the comment by saying directly what num_shards_map is (at least at a high level) and then compare it with dim_map. Or do you think it is a requirement for users to know what dim_map is before understanding what num_shards_map is?

For example, we have a dist tensor of shape [18, 20, 30],
a device_mesh ([[0, 1, 2, 3], [4, 5, 6, 7]]), and placements
([Shard(1), Shard(0)]), the num_shards_map of this distributed tensor
would be: [4, 2, 1].
Contributor review comment:

Can we also add a test? Maybe we can just use this example as the test.

Contributor review comment:

This is great! Could we add one more example showing that when a tensor dim is sharded multiple times, the shard count is computed globally? I think this is the information that dim_map is not able to capture but num_shards_map can.
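As a hedged illustration of the case the reviewer describes (the shapes and variable names are assumed for illustration, not DTensor internals): sharding tensor dim 0 on both dims of a (2, 4) mesh yields 2 * 4 = 8 global shards on that dim. dim_map can only record that dim 0 is sharded; num_shards_map records how many shards it is split into overall.

```python
# Hypothetical illustration, not DTensor code: when one tensor dim
# is sharded on several mesh dims, its shard count is the product
# of those mesh dim sizes.
mesh_shape = [2, 4]   # device_mesh [[0, 1, 2, 3], [4, 5, 6, 7]]
shard_dims = [0, 0]   # placements [Shard(0), Shard(0)]

counts = [1, 1, 1]    # one entry per dim of a 3-D tensor
for size, dim in zip(mesh_shape, shard_dims):
    counts[dim] *= size

print(counts)  # [8, 1, 1]
```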

wanchaol (Collaborator) left a comment:

LGTM. Could we merge this PR together with the PR that actually uses it (i.e. the FSDP2 integration PR)? That way we can add an end-to-end test to cover this code.

XilunWu added a commit that referenced this pull request Jul 22, 2024
…rrect full_tensor() result

ghstack-source-id: b2c26d2
Pull Request resolved: #130760

[dtensor] have DTensorSpec report how many shards on each tensor dimension

ghstack-source-id: b2c26d2
Pull Request resolved: #130587
XilunWu added a commit that referenced this pull request Aug 1, 2024
…rrect full_tensor() result

ghstack-source-id: 519abb0
Pull Request resolved: #130760

[dtensor] have DTensorSpec report how many shards on each tensor dimension

ghstack-source-id: 519abb0
Pull Request resolved: #130587
XilunWu added a commit that referenced this pull request Aug 6, 2024
…rrect full_tensor() result

ghstack-source-id: be8392f
Pull Request resolved: #130760

[dtensor] have DTensorSpec report how many shards on each tensor dimension

ghstack-source-id: be8392f
Pull Request resolved: #130587
XilunWu added a commit that referenced this pull request Aug 6, 2024
…rrect full_tensor() result

ghstack-source-id: 6f66d4c
Pull Request resolved: #130760

[dtensor] have DTensorSpec report how many shards on each tensor dimension

ghstack-source-id: 6f66d4c
Pull Request resolved: #130587
XilunWu added a commit that referenced this pull request Aug 7, 2024
…rrect full_tensor() result

ghstack-source-id: 0e83706
Pull Request resolved: #130760

[dtensor] have DTensorSpec report how many shards on each tensor dimension

ghstack-source-id: 0e83706
Pull Request resolved: #130587
XilunWu added a commit that referenced this pull request Aug 7, 2024
…rrect full_tensor() result

ghstack-source-id: a89062b
Pull Request resolved: #130760

[dtensor] have DTensorSpec report how many shards on each tensor dimension

ghstack-source-id: a89062b
Pull Request resolved: #130587
XilunWu added a commit that referenced this pull request Aug 8, 2024
…rrect full_tensor() result

ghstack-source-id: 1699a7c
Pull Request resolved: #130760

[dtensor] have DTensorSpec report how many shards on each tensor dimension

ghstack-source-id: 1699a7c
Pull Request resolved: #130587
@XilunWu XilunWu closed this Aug 26, 2024
@github-actions github-actions bot deleted the gh/XilunWu/89/head branch September 28, 2024 02:05
injiiiiil pushed a commit to injiiiiil/654 that referenced this pull request Oct 1, 2024

Labels

ciflow/inductor, oncall: distributed (add this issue/PR to the distributed oncall triage queue)


5 participants