Add utility functions for distributed checkpointing #5128
Conversation
cc @yashs97
- // Clamp the end of the slice to the tensor shape to accurately reflect
+ // Clamp the slice bounds to the tensor shape to accurately reflect
  // the shard size without padding.
  int start = std::min(n_j * shard_shape[j], tensor_shape[j]);
Did you catch a bug, or is this more of a safeguard?
If n_j * shard_shape[j] is greater than tensor_shape[j], then (n_j + 1) * shard_shape[j] is certainly larger than tensor_shape[j] as well. In that scenario, start will equal end, both clamped to tensor_shape[j], and that slice seems meaningless.
I would call this a latent bug, but it wasn't breaking anything, because torch indexing treats negative-length slices as empty. It just breaks the expectation that stop - start reflects the size of the unpadded shard, which we rely on in distributed checkpointing.
You're right - these index slices will end up empty, but that is the desired outcome when the shard consists entirely of padding.
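The behavior discussed above can be illustrated with plain Python list slicing, which handles out-of-range and negative-length slices the same way torch indexing does. This is a sketch; the 10-element tensor and shard size of 4 are made-up values, not taken from the PR:

```python
t = list(range(10))

# Without clamping, a shard that is entirely padding produces a slice
# whose start exceeds its stop. Indexing still works -- the slice is
# simply empty -- but stop - start goes negative, so it no longer
# reflects the unpadded shard size.
assert t[8:5] == []

# With both bounds clamped to the tensor shape, start == stop for an
# all-padding shard, and stop - start equals the unpadded size of
# every shard.
tensor_dim, shard_dim = 10, 4  # 3 shards; the last holds 2 real elements
for n_j in range(3):
    start = min(n_j * shard_dim, tensor_dim)
    stop = min((n_j + 1) * shard_dim, tensor_dim)
    assert len(t[start:stop]) == stop - start
```

With this clamping, the unpadded shard sizes come out as 4, 4, and 2, and a hypothetical fourth shard would clamp to an empty slice rather than a negative-length one.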
    return ShardingType::TUPLE;
  case xla::OpSharding::OTHER:
    // OTHER sharding can indicate either PARTIAL or TILED sharding.
    return sharding.replicate_on_last_tile_dim() ? ShardingType::PARTIAL
This seems pretty hacky. But I guess we don't have any other way around it?
My understanding is that we distinguish partial replication as a different sharding type, whereas XLA treats PARTIAL and TILED as the same type, OTHER. @yeounoh could you confirm?
Yes, this is actually the correct way; the compiler treats TILED and PARTIAL the same, as the OTHER type. The difference between the two is how the tile shards are assigned to different devices.
@jonb377 @alanwaketan is this PR ready to merge?
Yes, I'll merge after TPU CI finishes.
This change adds a few utility functions to support distributed checkpointing. The following changes are included:
- `sharding_type` property on `XLAShardedTensor` to get the `ShardingType`
- `wrap_if_sharded` to convert a `torch.Tensor` into an `XLAShardedTensor` if the underlying data is sharded.
- Remove the `devices` parameter from `_get_local_shard_indices` and instead always return the shard indices in the order of the shards
- Clamp the `start` bound of the index slices to the tensor's size.
- `unpadded_data` property of `XLAShard`