Safetensors serialization by default#27064
Conversation
```python
if (
    is_safetensors_available()
    and isinstance(resolved_archive_file, str)
    and resolved_archive_file.endswith(".safetensors")
):
    with safe_open(resolved_archive_file, framework="pt") as f:
        metadata = f.metadata()

    if metadata.get("format") == "pt":
        pass
    elif metadata.get("format") == "tf":
        from_tf = True
        logger.info("A TensorFlow safetensors file is being loaded in a PyTorch model.")
    elif metadata.get("format") == "flax":
        from_flax = True
        logger.info("A Flax safetensors file is being loaded in a PyTorch model.")
    else:
        raise ValueError(
            f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax'] but {metadata.get('format')}"
        )

from_pt = not (from_tf | from_flax)
```
This is necessary to enable loading safetensors files saved from TensorFlow/Jax into PyTorch models.
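The dispatch in the diff above can be sketched as a small standalone function. This is a hedged rewrite for illustration, not the actual transformers code, and `resolve_source_framework` is a hypothetical name:

```python
# Hypothetical standalone sketch of the dispatch shown in the diff: the
# "format" key in a safetensors file's metadata records which framework
# saved it, and the loader sets the from_tf/from_flax flags accordingly.
def resolve_source_framework(metadata):
    fmt = metadata.get("format")
    if fmt == "tf":
        from_tf, from_flax = True, False
    elif fmt == "flax":
        from_tf, from_flax = False, True
    elif fmt == "pt":
        from_tf, from_flax = False, False
    else:
        raise ValueError(
            f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax'] but {fmt}"
        )
    # Mirrors `from_pt = not (from_tf | from_flax)` in the diff
    from_pt = not (from_tf or from_flax)
    return from_pt, from_tf, from_flax
```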
Narsil
left a comment
Very nice!
Not that bad for such a big change.
For the testing part I see loading in PT from PT/TF/Flax, but not the other ways:
- TF → TF
- TF → Flax
- TF → PT
- Flax → Flax
- Flax → TF
- Flax → PT
From your initial comment I understand it's not possible, but it's not entirely clear to me why. You mention sharded weights: is that the only restriction? If so, from what I read it should be okay-ish to at least support loading for those, no?
sanchit-gandhi
left a comment
Thanks for working on this @LysandreJik!
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
@Narsil, this is what is currently supported and not supported:
I mention this in the PR description:
It should be pretty straightforward to enable it, but I suspect extremely little usage for a TF <> Flax conversion where no PyTorch conversion exists. I'm planning to add this to the documentation and IMO we can work on it afterwards if there are requests.
amyeroberts
left a comment
Very nice! 🔥
Just some small nits and Qs for my own understanding
src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
```python
def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_hub(self):
    # This should not raise even if there are two types of sharded weights
    # This should discard the safetensors weights in favor of the msgpack sharded weights
    FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded")
```
Just to make sure I've understood correctly: TF and Flax models can't load sharded weights from safetensors. So, if this passes, we know the model has successfully loaded the msgpack sharded weights?
Yes, that's exactly right! This was raised by Sanchit: a previous version of the implementation prioritized safetensors, realized they were sharded, and errored out. But if sharded msgpack weights are also in the repo, we want to load those first.
Thanks for explaining!
```diff
  model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)

- if from_pt:
+ if from_pt or safetensors_from_pt:
```
Exposing my lack of knowledge about safetensors here: if `safetensors_from_pt` is True here, is the reason we do `load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)` that the serialized weights are in "pytorch format" and therefore can't be loaded using `cls.load_flax_weights`?
Yes, that's correct! This way we call `load_pytorch_checkpoint_in_flax_state_dict` with the safetensors file, and that method checks whether the file ends with `.safetensors` to load it the PyTorch way.
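That routing decision can be summarized in a tiny sketch (a hypothetical helper for illustration only; the real control flow lives inline in `from_pretrained`):

```python
# Hypothetical sketch: a PyTorch-format checkpoint (either a .bin file or a
# safetensors file whose metadata says "pt") cannot be read by the native
# Flax loader, so it is routed through the PyTorch-to-Flax conversion path.
def pick_flax_loader(from_pt, safetensors_from_pt):
    if from_pt or safetensors_from_pt:
        return "load_pytorch_checkpoint_in_flax_state_dict"
    return "load_flax_weights"
```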
Thanks for explaining!
```diff
  with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
      mismatched_layers = []
-     weight_names = [format_weight_name(w.name, _prefix=_prefix) for w in model.weights]
+     weight_names = [strip_model_name_and_prefix(w.name, _prefix=_prefix) for w in model.weights]
```
As:
- The previous function `format_weight_name` added `_prefix`, whereas the new function `strip_model_name_and_prefix` removes `_prefix` if it's present.
- `weight_names` are compared to those in `safetensors_archive`.
Can we load previously saved safetensors (from before this PR) into our TF models?
@amyeroberts I believe the previous code was just completely incorrect! Essentially, the relevant workflow (saving TF composite encoder-decoder models as safetensors and then reloading the checkpoint in TF) was not being tested, and actually didn't work because of the name prefix bug.
As such, I don't think there's a backwards compatibility issue here, because previous checkpoints weren't working at all. My suspicion is that not many people were saving encoder-decoder models in TF, and not many TF users were saving safetensors, so the intersection of that Venn diagram was tiny enough that no one noticed the bug for a long time!
Also, just to clarify: _prefix is almost always None or "" when this function is called, in which case the behaviour is unchanged after this bugfix. _prefix is only defined when loading composite models like EncoderDecoder, which is the workflow that was broken before this.
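A rough sketch of what the new helper does (an approximation for illustration, not the exact transformers implementation):

```python
# Approximate sketch of strip_model_name_and_prefix: first drop the optional
# _prefix (only set for composite models like EncoderDecoder), then drop the
# leading model-name component (e.g. "tf_bert_model/") so the weight name
# lines up with the keys stored in the safetensors archive.
def strip_model_name_and_prefix(name, _prefix=None):
    if _prefix is not None and name.startswith(_prefix):
        name = name[len(_prefix):].lstrip("/")
    if "model." not in name and len(name.split("/")) > 1:
        name = "/".join(name.split("/")[1:])
    return name
```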
Thanks for explaining!
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
I will proceed to merge this and write a small explanatory doc tomorrow. I would like for the slow tests to run on this before the release.

Awesome! Thanks a LOT for this.
* Safetensors serialization by default
* First pass on the tests
* Second pass on the tests
* Third pass on the tests
* Fix TF weight loading from TF-format safetensors
* Specific encoder-decoder fixes for weight crossloading
* Add VisionEncoderDecoder fixes for TF too
* Change filename test for pt-to-tf
* One missing fix for TFVisionEncoderDecoder
* Fix the other crossload test
* Support for flax + updated tests
* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Sanchit's comments
* Sanchit's comments 2
* Nico's comments
* Fix tests
* cleanup
* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This PR aims to do one thing but is larger than expected. I'm happy to break it down into smaller PRs if it helps for reviewing.
This PR aims to switch safe serialization to `True` by default for `torch` models. In doing so, it revealed a few bugs in the existing implementation and `safetensors` support that this PR fixes.

Additionally, support for `safetensors` for Flax models is added so that models saved from PyTorch after merging this PR can be used in both TensorFlow and Flax, and so that models saved from TensorFlow/Flax can be loaded in PyTorch models.

The following should be worked on shortly to enable switching to safetensors by default for TensorFlow and Flax as well:

Additionally, I'll contribute some documentation making the following clear:
Thanks, @Rocketknight1, for the help on TensorFlow's side.