Implement SPMDSavePlanner to take distributed checkpoints#5170
Conversation
self.sharded_state_dict[fqn].load_local_shards_(local_shards)
def _create_write_item_from_indices(fqn: str, shard_index: int,
Just noticed that you created all the helpers as private functions of the module instead of private methods within each class. Just curious why? You don't need to change this.
Ah, no real reason; I was just following the pattern from the upstream implementation, where the helpers that generate Read/WriteItems are module-private for the DefaultPlanners.
I still wouldn't call myself a Python expert, which is why I ask questions about Python programming patterns when they contradict my C++ instincts.
alanwaketan
left a comment
Generally LGTM. Thanks for shaping the distributed checkpointing story so quickly.
if index.fqn in self.unsharded_state_dict:
    return self.unsharded_state_dict[index.fqn]

if index.fqn not in self._local_shards:
It seems to me that the logic of having _local_shards is not necessary. You are basically just using it to keep track of duplicated writes. Will that happen? For LoadPlanner, it is necessary because you only want to load the shards to XLAShardedTensor when all the shards are in the host memory and then you need to count it. I'm not sure why the SavePlanner needs this logic. Am I missing anything?
The motivation is that each call to XLAShardedTensor.local_shards moves all of the shards from device to CPU, and since we only need one shard at a time, it's more efficient to do the transfer only once. I'll expand the comment in __init__ to explain the need here.
That makes sense to me now. Thanks for the explanation!
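To make the caching behavior discussed above concrete, here is a minimal plain-Python sketch (not the PR's actual code) of the pattern: transfer all of a tensor's shards to host once, then serve and release them shard-by-shard. `fetch_local_shards` is a hypothetical stand-in for the device-to-CPU transfer performed by XLAShardedTensor.local_shards.

```python
class ShardCache:
    def __init__(self, fetch_local_shards):
        # fetch_local_shards(fqn) performs the expensive device-to-host
        # transfer; we want to call it at most once per tensor.
        self._fetch = fetch_local_shards
        self._local_shards = {}

    def resolve(self, fqn, shard_index):
        # First access for this fqn: transfer all shards to host at once.
        if fqn not in self._local_shards:
            self._local_shards[fqn] = dict(enumerate(self._fetch(fqn)))
        # Pop the requested shard so host memory is released after each
        # shard has been written; each shard is served exactly once.
        return self._local_shards[fqn].pop(shard_index)
```

A quick usage example: two shards are resolved in separate calls, but only one device-to-host transfer happens.

```python
calls = []

def fetch(fqn):
    calls.append(fqn)
    return ["shard0", "shard1"]

cache = ShardCache(fetch)
cache.resolve("layer.weight", 1)  # triggers the single transfer
cache.resolve("layer.weight", 0)  # served from the host-side cache
```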
@alanwaketan FYI I removed the distributed checkpointing test from CI; some of the APIs we depend on aren't stable, and we may need to adjust the imports once upstream exposes them. Once they're stable, I'll add the test back to CI.
Can you be more specific? And have you talked to @kumpera?
@alanwaketan Across both planners, the APIs we're taking a dependency on are:
I've spoken with @kumpera, and he's looking into which ones he can make stable. They're mostly small helper functions, so we can reimplement them here if they can't be stabilized upstream.
Thanks for the updates. Can we dup the code now, and leave a GH issue and a TODO in the code to follow up? I'm hesitant to call the MVP feature-complete without a test.
@alanwaketan I've pulled all of the unstable dependencies into

Thanks, @jonb377!
This implements the SavePlanner interface from torch.distributed.checkpoint. The implementation directly handles only sharded tensors and relies on the default planner's logic for everything else.
A high-level overview of each of the SavePlanner interface methods:
- set_up_planner: Called with the state_dict to be checkpointed. Our implementation splits the state_dict into a sharded and an unsharded portion so that we can defer to the default planner logic for the unsharded part.
- create_local_plan: WriteItems are generated for every item in the state_dict. The default planner is used for the unsharded objects, and we generate a WriteItem for each shard of each XLAShardedTensor with a non-REPLICATED sharding type.
- create_global_plan: The coordinator process makes any global decisions for the restoration. There is no custom logic here.
- finish_plan: The process can adjust its plan after global coordination. Again, no custom logic here.
- resolve_data: Return the data to be written for a given WriteItem. We return the local shard for sharded tensors or relevant portions

This change also enables distributed checkpointing tests in CI.
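The set_up_planner split described above can be sketched in a few lines of plain Python. This is an illustration, not the PR's code: `is_sharded` is a hypothetical stand-in for the real check (an isinstance test against XLAShardedTensor plus a non-REPLICATED sharding-type check).

```python
def split_state_dict(state_dict, is_sharded):
    """Partition a state_dict into the portion this planner handles itself
    (sharded tensors) and the portion deferred to the default planner."""
    sharded, unsharded = {}, {}
    for fqn, value in state_dict.items():
        (sharded if is_sharded(value) else unsharded)[fqn] = value
    return sharded, unsharded
```

After the split, create_local_plan would emit one WriteItem per shard for each entry in `sharded`, while `unsharded` flows through the default planner unchanged.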