[FSDP2] Fixed incorrect tensor meta after `.to(dtype)` #137593

awgu wants to merge 1 commit into gh/awgu/650/base
```
if updated_local_tensor:
    # Only change the local tensor object if needed
    self.sharded_param._local_tensor = local_tensor[: self.sharded_size[0]]
self._sharding_spec = self.sharded_param._spec
```
Question: instead of caching `self._sharding_spec`, would it make sense to have it be a property that always just queries `self.sharded_param._spec`?
took a quick look -- this will require some refactoring, so I will defer this to a separate PR
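For reference, a minimal sketch of the property-based alternative the comment suggests (hypothetical; the actual refactor is deferred to a separate PR):

```
@property
def _sharding_spec(self) -> DTensorSpec:
    # Hypothetical: always derive the spec from the sharded parameter
    # instead of caching it, so it cannot go stale after in-place changes.
    return self.sharded_param._spec
```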
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
## Overview
This PR adds a `shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]]` arg to `fully_shard` that allows users to specify FSDP sharding on a nonzero tensor dim. When sharding on a nonzero dim, that dim's size must be divisible by the FSDP shard world size.
```
# Example: shard each parameter along its largest dim
from typing import Optional

import torch.nn as nn
from torch.distributed.tensor import Shard


def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
    largest_dim = largest_dim_size = -1
    for dim, dim_size in enumerate(param.shape):
        if dim_size > largest_dim_size:
            largest_dim = dim
            largest_dim_size = dim_size
    return Shard(largest_dim)


fully_shard(module, shard_placement_fn=shard_placement_fn)
```
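For example (a hypothetical setup, not part of this PR): with an FSDP shard world size of 4 and a weight of shape `[256, 1024]`, the function above returns `Shard(1)`, so dim 1's size of 1024 must be divisible by 4. After `fully_shard`, each parameter is a `DTensor` whose placement reflects the chosen dim:

```
# Hypothetical inspection after fully_shard; parameter names depend on the module.
for name, param in module.named_parameters():
    print(name, param.placements)  # e.g. (Shard(dim=1),)
```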
## Follow-Ups
- **Copy kernels:** For all-gather copy-out, we currently copy out to temporaries and then chunk-dim-0 -> cat-shard-dim, incurring an extra copy for parameters sharded on a nonzero tensor dim. Similarly, for reduce-scatter copy-in, we currently chunk-shard-dim -> cat-dim-0, incurring an extra copy for gradients sharded on a nonzero tensor dim (see the sketch below). @yifuwang has ideas for adding additional split-size args to the copy ops that would allow fusing these extra copies into the existing all-gather copy-out and reduce-scatter copy-in.
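A rough, self-contained sketch of the extra copy on the all-gather copy-out path for a parameter sharded on dim 1 (the shapes and `world_size` value are made up for illustration; the real kernels operate on flat, fused buffers):

```
import torch

world_size = 4
# The full (unsharded) parameter is [256, 1024], sharded on dim 1,
# so each rank contributes a [256, 256] shard, stacked along dim 0
# in the all-gather output.
all_gather_out = torch.randn(world_size * 256, 256)

# Current copy-out path: chunk along dim 0 (one chunk per rank), then cat
# along the shard dim (dim 1), which materializes an extra temporary copy.
chunks = all_gather_out.chunk(world_size, dim=0)
param = torch.cat(chunks, dim=1)  # shape [256, 1024]
```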
Pull Request resolved: #137496
Approved by: https://github.com/weifengpy
ghstack dependencies: #137593
Stack from ghstack (oldest at bottom):
- `shard_placement_fn` arg #137496
- `.to(dtype)` #137593

This fixes #137522. After a method that changes the module parameters (like `.to(torch.float64)`), we need to update the `DTensorSpec`, whose `TensorMeta`'s dtype may have changed.

cc @XilunWu @H-Huang @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o