Implement SPMDLoadPlanner to enable distributed checkpoint loading#5130
Conversation
cc @yashs97
alanwaketan left a comment
Thanks @jonb377. This is great. Left a few comments.
# Flatten the state_dict to allow separating sharded XLA tensors from
# types that can be handled by the default planner, and ensure all sharded
# tensors are wrapped in XLAShardedTensor
state_dict, self.mappings = flatten_state_dict(state_dict)
It looks like the upstream DefaultLoadPlanner doesn't always flatten the state_dict? Do you know why and what are the downsides of flattening the state_dict?
Hmm I actually don't know why it's optional in the DefaultLoadPlanner, I'll follow up on that. The flattened state_dict is easier for us to work with since we need to split the input into sharded/unsharded parts, which is difficult if the state_dict is nested.
One downside of flattening is that we aren't operating directly on the input state_dict, so non-tensor items need to be mapped back to the original state_dict.
But that down side is handled by the default planner helpers, right?
It's optional on DefaultLoadPlanner mostly to help with testing. Its default value is True, so most users will always have it enabled.
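To make the flattening trade-off concrete, here is a minimal pure-Python sketch (not the upstream implementation; this `flatten_state_dict` and `set_element` are only loosely modeled on the upstream helpers) of how nested keys become dotted FQNs plus a mapping that lets non-tensor values be written back into the original nested structure:

```python
# Sketch of flattening: nested keys become dotted FQNs, and `mappings`
# records the path of each value so it can be restored into place.
def flatten_state_dict(state_dict):
    flat, mappings = {}, {}

    def visit(path, value):
        if isinstance(value, dict):
            for key, child in value.items():
                visit(path + (key,), child)
        else:
            fqn = ".".join(path)
            flat[fqn] = value
            mappings[fqn] = path  # remember where this value came from

    visit((), state_dict)
    return flat, mappings


def set_element(state_dict, path, value):
    # Write a restored value back through the recorded path,
    # rebuilding the nested dicts as needed.
    node = state_dict
    for key in path[:-1]:
        node = node.setdefault(key, {})
    node[path[-1]] = value
```

This illustrates the downside mentioned above: since the planner operates on the flat dict, non-tensor items must be mapped back through `mappings` after loading.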
self.assertFalse(self._same_shard_data(xtensor.local_shards, old_shards))

def test_load_state_dict(self):
  dist.init_process_group(
Why do we need the process_group?
torch.distributed.checkpoint requires a process group to coordinate merging the local plans into a global plan. Even though this is just a single instance, we still need to initialize the PG to use the save_state_dict and load_state_dict APIs.
Hmmm, is this because we are saving the unsharded CPU tensor? If it's required all the time, I feel like it's not a good UX.
A PG is always required, but I'll need to check if we can use the XLA backend - my first attempt to replace gloo with an xla group errored out when creating the global plan.
If we build a higher-level interface, this could be hidden from the user.
Most of our own distributed APIs don't require a PG. So maybe we should try to keep that consistency. Do you know why they require a PG?
It's used to centrally coordinate the global plans, e.g. all worker plans are sent to the coordinator using collectives in the process group.
Actually, I see we have the option to run without a PG by setting no_dist=True in the (save|load)_state_dict functions, but we lose out on the global coordination. This shouldn't be an issue for the LoadPlanner, and I suspect we can make the SavePlanner work without global planning as well - we'll just need to track replica ranks and dedupe tensors locally.
Thanks for bringing this up Jiewen, I'll experiment with it and update the PR.
Synced offline - when taking a checkpoint, we would lose the consistency provided by the process group (the coordinator waits until all workers are finished before writing metadata, which depends on a PG).
I'll use no_dist=True in the tests to simplify the test code, but we will still expect users to have a CPU process group when taking a checkpoint.
tmpdir = tempfile.mkdtemp()

# Save an unsharded model using the default planner in dist_cp.save_state_dict
model = self.SimpleLinear().to(xm.xla_device())
Can you illustrate more on how a saved unsharded model can be loaded into a sharded model?
This is a feature from the upstream - resharding is handled transparently for us when we use the create_read_items_for_chunk_list API to create our ReadItems. Even sharded checkpoints should be able to be restored into different device meshes or unsharded models.
@jonb377 One small thing I would suggest is to use
alanwaketan left a comment
Generally speaking, it LGTM! Thanks for getting this complicated work done so quickly. I left a bunch of comments, mostly for my own education, to learn in more depth how this planner is supposed to be used by the load_state_dict API.
Some of the missing pieces that I imagine load_state_dict will fill in around the planners are:
- Loading the actual metadata and storage.
- Loading the storage and slicing it into properly sized tensors for the ReadItems.
It will be great if you can elaborate more on the E2E flow on how the high level APIs are interacting with our low level implementation. I also commented this on your design doc.
def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
  offsets = read_item.dest_offsets
  index = read_item.dest_index
  if index.fqn in self.sharded_state_dict:
Can you elaborate this a little bit more?
Also, I'm confused about why we want to do the narrowing.
I'll add a comment - the storage layer expects that the tensor returned from resolve_tensor matches the shape of the ReadItem's lengths field, so the tensor must be narrowed here.
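To illustrate why the narrowing matters: the storage layer copies exactly read_item.lengths elements, so resolve_tensor must hand back a view with that exact shape starting at dest_offsets. A pure-Python 2-D sketch of the per-dimension narrowing (torch.Tensor.narrow does the equivalent on real tensors; `narrow_2d` is a hypothetical helper, not part of the PR):

```python
# Slice a 2-D list-of-lists so the result's shape matches `lengths`,
# starting at `offsets` -- mimicking calling narrow() once per dimension.
def narrow_2d(data, offsets, lengths):
    row0, col0 = offsets
    rows, cols = lengths
    return [row[col0:col0 + cols] for row in data[row0:row0 + rows]]


grid = [[r * 4 + c for c in range(4)] for r in range(4)]
narrow_2d(grid, (1, 2), (2, 2))  # -> [[6, 7], [10, 11]]
```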
x, XLAShardedTensor) and x.sharding_type != ShardingType.REPLICATED

def _unwrap_sharded_tensor(x: Any) -> Any:
So this is only used for the replicated tensors?
Correct, that's a good point - the name is confusing. This is used to ensure the default planner is operating on torch.Tensor instead of XLAShardedTensor. Maybe _unwrap_xla_sharded_tensor?
# Load the checkpoint using the provided load planner
for p1, p2 in zip(model_in.parameters(), model_out.parameters()):
  self.assertFalse(torch.allclose(p1, p2))
dist_cp.load_state_dict(
It looks like the LoadPlanner should be able to handle the resharding of consolidated checkpoints. I wonder how it handles the case where the saved checkpoint has a different sharding spec than the loading model?
Seems like it's handled by create_read_items_for_chunk_list?
That's correct - we just need to generate ChunkStorageMetadata for each local shard, and the upstream API create_read_items_for_chunk_list will convert that chunk list into a list of ReadItem. The ReadItems generated will depend on the device mesh used when taking the checkpoint.
Then is there a simple formula to calculate the number of ReadItems? Or do I have to refer to the create_read_items_for_chunk_list algorithm?
I just added some discussion in another comment, but adding here as well:
The algorithm for create_read_items_for_chunk_list just iterates across all chunks in the storage metadata, identifies overlap with the ChunkStorageMetadata describing our XLAShard, and generates a ReadItem if there is overlap.
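That overlap test can be sketched in a few lines of pure Python (the real create_read_items_for_chunk_list lives upstream in torch.distributed.checkpoint; the dict fields below are illustrative stand-ins for ReadItem, not the actual dataclass):

```python
# Intersect two axis-aligned chunks, each given as per-dimension
# (offsets, sizes); return the intersection or None if they don't overlap.
def intersect(offsets_a, sizes_a, offsets_b, sizes_b):
    offsets, lengths = [], []
    for ao, al, bo, bl in zip(offsets_a, sizes_a, offsets_b, sizes_b):
        lo, hi = max(ao, bo), min(ao + al, bo + bl)
        if hi <= lo:
            return None  # disjoint in this dimension
        offsets.append(lo)
        lengths.append(hi - lo)
    return offsets, lengths


def read_items_for_shard(shard_offsets, shard_sizes, stored_chunks):
    # One ReadItem-like record per stored chunk that overlaps our shard.
    items = []
    for chunk_offsets, chunk_sizes in stored_chunks:
        overlap = intersect(shard_offsets, shard_sizes, chunk_offsets, chunk_sizes)
        if overlap is None:
            continue
        offsets, lengths = overlap
        items.append({
            "storage_offsets": [o - c for o, c in zip(offsets, chunk_offsets)],
            "dest_offsets": [o - s for o, s in zip(offsets, shard_offsets)],
            "lengths": lengths,
        })
    return items
```

For example, a (4, 4) shard at the origin of an 8x8 tensor stored as four (2, 8) row chunks overlaps the first two chunks, yielding two records.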
plan = planner.create_local_plan()
parameter_count = len(list(model.parameters()))
if self.n_devices > 1:
  # When the model is sharded across devices, fc1.weight will result in
So you suggest the number of ReadItems will match the number of XLAShards for a particular XLAShardedTensor. Does this hold true for the following resharding examples?
- stored shards <= loading shards, i.e., the model is sharded 2 ways but needs to be loaded 4 ways.
- stored shards > loading shards, i.e., the model is sharded 4 ways but needs to be loaded 2 ways.
I can see that it holds true for case 1, but it's hard for me to imagine case 2. Can you elaborate? Commented in your design as well.
That's true - I'll clarify in the comment that this count comes from the unsharded checkpoint metadata used when creating the planner in _get_load_planner.
I'm still curious about case 2 if you can explain to me?
If we are loading into a coarser mesh, e.g. loading a checkpoint taken on a (4, 4) mesh into a model sharded across (2, 2), each local shard in the (2, 2) mesh will require multiple ReadItems. On our side, we will still generate a single ChunkStorageMetadata for each shard, and the utility function create_read_items_for_chunk_list translates these into the ReadItems necessary by finding the overlap of ChunkStorageMetadata with the chunks in the storage metadata.
The algorithm for create_read_items_for_chunk_list just iterates across all chunks in the storage metadata, identifies overlap with the ChunkStorageMetadata describing our XLAShard, and generates a ReadItem if there is overlap.
At a high level, the number of ReadItems will then be determined by max(number of shards in the loading model, number of shards in the saving model)?
We can contrive an example where this won't hold, e.g. saving from a (4, 1) mesh and loading into a (2, 2) mesh would require 8 total ReadItems based on a quick test. It depends a lot on how the storage layer handles the checkpoint, but the worst case number of ReadItems would be O(global shards when saving the model) for each shard being loaded.
Perfect. Thanks, Jon!
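The (4, 1)-to-(2, 2) counter-example above can be reproduced with a small overlap count, assuming an 8x8 tensor split into four row chunks when saving and four quarters when loading (the geometry is an assumption for illustration; the quick test in the PR may have used different shapes):

```python
# True iff two axis-aligned chunks overlap in every dimension.
def overlaps(a_off, a_len, b_off, b_len):
    return all(max(ao, bo) < min(ao + al, bo + bl)
               for ao, al, bo, bl in zip(a_off, a_len, b_off, b_len))

stored = [((r, 0), (2, 8)) for r in range((0), 8, 2)]          # (4, 1) mesh: row chunks
loading = [((r, c), (4, 4)) for r in (0, 4) for c in (0, 4)]   # (2, 2) mesh: quarters
count = sum(overlaps(load_off, load_len, st_off, st_len)
            for load_off, load_len in loading
            for st_off, st_len in stored)
# count == 8: each loading shard spans two stored row chunks, so 4 * 2 = 8,
# exceeding max(4, 4) and breaking the proposed formula.
```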
This implements the LoadPlanner interface from torch.distributed.checkpoint. This implementation only directly handles sharded tensors and relies on the default planner's logic for everything else.
A high-level overview of each of the Planner interface methods:
- set_up_planner: Called with the state_dict to be restored and metadata from the checkpoint. Our implementation will split the state_dict into a sharded and an unsharded portion so that we can defer to the default planner logic for the unsharded part.
- create_local_plan: ReadItems are generated for every item in the state_dict. The default planner is used for the unsharded objects, and we generate a ReadItem for each shard of each XLAShardedTensor with non-REPLICATED sharding type.
- create_global_plan: The coordinator process makes any global decisions for the restoration. There is no custom logic here.
- finish_plan: The process can adjust its plan after global coordination. Again, no custom logic here.
- load_bytes: This is how non-tensor data is restored. We defer to the default planner's logic.
- resolve_tensor: This function returns a tensor to store the read result of the associated ReadItem. If the ReadItem doesn't correspond to a sharded tensor, we defer to the default planner logic. Otherwise, we return the unpadded data of the local shard associated with the ReadItem.
- commit_tensor: This is called after the data has been loaded into the tensor. This is a no-op in the default planner, but for sharded tensors we track each shard that has been committed. Once all shards are committed for a tensor, they are loaded into the XLAShardedTensor.

This PR depends on #5128 for some utility functions.
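The commit_tensor bookkeeping described above can be sketched as follows (`ShardCommitTracker` and its fields are hypothetical names for illustration, not the PR's actual implementation):

```python
# Track how many shards remain uncommitted per tensor FQN; once the last
# shard commits, return the accumulated shards so the caller can load
# them into the XLAShardedTensor in one step.
class ShardCommitTracker:
    def __init__(self, shard_counts):
        # fqn -> number of shards still pending commit
        self._pending = dict(shard_counts)
        self._committed = {fqn: [] for fqn in shard_counts}

    def commit(self, fqn, shard):
        self._committed[fqn].append(shard)
        self._pending[fqn] -= 1
        if self._pending[fqn] == 0:
            return self._committed.pop(fqn)  # all shards ready to load
        return None  # still waiting on other shards
```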