[FSDPv2] Support MultiSlice#7044
Conversation
088da28 to
76c7435
Compare
| # another helper for the output. | ||
| partition_spec[0] = "fsdp" | ||
| if extra_data_axis: | ||
| partition_spec[0] = ("fsdp", extra_data_axis) |
There was a problem hiding this comment.
We usually have this reversed for DCN - (extra_data_axis, 'fsdp'). The axes should be in order of increasing network intensity in the mesh, and the order in the partition spec will impact the sharding.
| shard_output: Optional[Callable] = None, | ||
| auto_wrap_policy: Optional[Callable] = None, | ||
| auto_wrapper_callable: Optional[Callable] = None, | ||
| extra_data_axis: Optional[str] = None, |
There was a problem hiding this comment.
What do you think of calling it replica_axis instead of extra_data_axis?
There was a problem hiding this comment.
I think replica_axis is too tied to the underneath technology while users may only be familiar with data parallel, fsdp, and tensor parallel.
| xs.mark_sharding(x, mesh, (('data', 'fsdp'), None)) | ||
| output = model(x) | ||
| # Make sure output are sharded. | ||
| annotation = '{devices=[4,1]0,2,1,3}' |
There was a problem hiding this comment.
This would be different than x's sharding - x should have [4,1]0,1,2,3 with the iota mesh. I left a comment below, I think we should reverse the order in _prepare_spmd_partition_spec.
There was a problem hiding this comment.
My mistake. Thanks for pointing it out.
|
@JackCaoG The TPU CI doesn't seem running even with the label. |
|
yea I think they only check the label when CI is being run. It is OK if you have any changes and repush it will run, otherwise we can let head to check. |
|
I'm landing it. If the master TPU CI breaks, let's deal with that later. |
Summary:
This pull request adds the multi-slice support for FSDPv2. Basically, the default setup is to use the dcn axis as the data axis, and it means we only do data parallel over multi-slices. In the future, we could also support FSDP over mutli-slices.
Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_fsdp_v2.py