Support save/load ckpt for XLA FSDP #32311
Conversation
src/transformers/trainer.py
Outdated
```python
    ckpt_suffix=f"rank*_of_*_{WEIGHTS_NAME}.pth",
    save_model=False,
)
assert isinstance(model, FSDP)
```
We don't use asserts in the codebase. If we're worried this might not hold, we should do an `if not` check and raise a proper, clear error.
OK, I removed this check, since the model should be FSDP XLA when `self.is_fsdp_xla_enabled` is True and `self.is_fsdp_xla_v2_enabled` is False.
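The pattern the reviewer suggests (an explicit check with a raised error instead of `assert`) can be sketched as follows. This is a hypothetical illustration, not the code that landed in the PR; `FSDP` here is a stand-in class for torch_xla's FSDP wrapper.

```python
class FSDP:
    """Stand-in for torch_xla's XLA FSDP wrapper class (illustration only)."""
    pass

def check_model_is_fsdp(model):
    # Instead of `assert isinstance(model, FSDP)`, raise a clear error.
    # Unlike an assert, this survives `python -O` and gives the user an
    # actionable message.
    if not isinstance(model, FSDP):
        raise ValueError(
            f"Expected the wrapped model to be an FSDP instance, got "
            f"{type(model).__name__}. This path should only run when "
            "FSDP XLA v1 is enabled."
        )
```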
src/transformers/trainer.py
Outdated
```python
xm.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME))
# Remove temporary sharded checkpoints
xm.rendezvous("remove_unused_checkpoints")
os.remove(ckpt_path)
```
I'm wary of doing this; it risks race conditions and the like. Can we not do this?
I think we can retain the sharded checkpoints for XLA users to use for their own inference or other scenarios. By keeping the `save_pretrained` logic, the resume-from-checkpoint functionality for FSDP still works.
BTW, we will later add logic to torch-xla to handle FSDP state dicts, similar to what PyTorch offers, to better facilitate saving and loading. For now, however, we need to at least save the complete weights for users.
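The race the reviewer is worried about is why a rendezvous sits between consolidation and shard removal: no rank may delete its shard until consolidation has read every shard. This can be simulated in plain Python with a `threading.Barrier` standing in for `xm.rendezvous` (a toy sketch, not torch_xla code; names are illustrative).

```python
import os
import tempfile
import threading

NUM_RANKS = 4
barrier = threading.Barrier(NUM_RANKS)  # stand-in for xm.rendezvous(...)
consolidated = []

def worker(rank, tmpdir):
    # Each "rank" writes its own shard file.
    shard = os.path.join(tmpdir, f"rank{rank}_of_{NUM_RANKS}_weights.pth")
    with open(shard, "w") as f:
        f.write(f"shard {rank}")
    barrier.wait()  # all shards exist before rank 0 consolidates
    if rank == 0:
        consolidated.extend(sorted(os.listdir(tmpdir)))
    barrier.wait()  # consolidation done before anyone removes a shard
    os.remove(shard)

with tempfile.TemporaryDirectory() as tmpdir:
    threads = [
        threading.Thread(target=worker, args=(r, tmpdir))
        for r in range(NUM_RANKS)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
```

Without the second barrier, a fast rank could delete its shard while rank 0 is still reading it, which is exactly the kind of race being flagged.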
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Hi @muellerzr, can you take another look at this PR? I've made the necessary modifications.
muellerzr
left a comment
Thanks! Overall this looks good to me, bar one nit.
Passing off to @ArthurZucker for final :)
src/transformers/trainer.py
Outdated
```python
if is_torch_xla_available():
    xm.rendezvous("saving_optimizer_states")
    xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
    if self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled:
```
What does it mean if `self.is_fsdp_xla_enabled` is True and `self.is_fsdp_xla_v2_enabled` is False? Are all combinations of these flags possible? e.g.:
- `self.is_fsdp_xla_enabled=True`, `self.is_fsdp_xla_v2_enabled=False`
- `self.is_fsdp_xla_enabled=True`, `self.is_fsdp_xla_v2_enabled=True`
- `self.is_fsdp_xla_enabled=False`, `self.is_fsdp_xla_v2_enabled=False`
- `self.is_fsdp_xla_enabled=False`, `self.is_fsdp_xla_v2_enabled=True`

As this check is repeated several times below, it'd be good to combine it into a single, explicitly named flag, e.g. `fsdp_v1_enabled`.
Good suggestion, I have changed it to `self.is_fsdp_xla_v1_enabled`. `self.is_fsdp_xla_v1_enabled` and `self.is_fsdp_xla_v2_enabled` correspond to different versions of the FSDP XLA implementation.
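The combined flag can be derived once from the two existing ones. A minimal sketch of the logic (function form is illustrative; in the PR it is a property on the trainer):

```python
def is_fsdp_xla_v1_enabled(is_fsdp_xla_enabled: bool,
                           is_fsdp_xla_v2_enabled: bool) -> bool:
    # v2 implies FSDP XLA is enabled overall, so "v1" means
    # FSDP XLA is on but the v2 implementation is not.
    return is_fsdp_xla_enabled and not is_fsdp_xla_v2_enabled
```

Of the four combinations the reviewer lists, only `(True, False)` selects v1; `(False, True)` should not occur in practice, since enabling v2 implies FSDP XLA is enabled.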
amyeroberts
left a comment
Thanks for adding this support!
What does this PR do?
Fixes #32310
Currently, FSDP on XLA only saves the sharded weights for rank 0. This PR enables the saving of complete weights, and also supports the functionality of resuming from a checkpoint. Since FSDP on XLA currently does not support loading a full optimizer into a sharded optimizer, this PR only saves the sharded optimizer and some related sharded metadata to facilitate subsequent loading. With this PR and a corresponding PR for accelerate, FSDP XLA can now support the functionality of resuming from a checkpoint.
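The consolidation step the description refers to (merging per-rank shards into one full set of weights) can be sketched with a toy, framework-free model in which each shard maps a parameter name to that rank's slice of values. This is purely illustrative; for real checkpoints, torch_xla provides its own consolidation helpers.

```python
def consolidate_shards(shards):
    """Merge per-rank sharded 'state dicts' into one full state dict.

    Toy model: each shard maps parameter name -> list of values holding
    that rank's slice; `shards` must be ordered rank 0..N-1 so slices
    are concatenated in rank order.
    """
    full = {}
    for shard in shards:
        for name, values in shard.items():
            full.setdefault(name, []).extend(values)
    return full
```

Usage: `consolidate_shards([{"w": [1, 2]}, {"w": [3, 4]}])` yields a single dict with each parameter's slices concatenated across ranks.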
Who can review?
@muellerzr and @SunMarc