Support sharded safetensors in TF #29350
This should be ready for review! cc @ArthurZucker @a8nova
LysandreJik left a comment:
Looks good! When were you thinking of eventually switching to safetensors serialization by default?
```python
if tf_model._keys_to_ignore_on_load_unexpected is not None:
    for pat in tf_model._keys_to_ignore_on_load_unexpected:
        unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if not skip_logger_warnings:
```
It is, unfortunately! The reason is that this function is used to load both sharded and non-sharded checkpoints. When it's loading a non-sharded checkpoint, we want to log missing keys immediately. When it's loading a shard, there will always be lots of "missing" keys, but we don't want to log those - instead, we only want to log keys that are missing from every shard, which we can only know after all shards have been loaded. This is handled in the sharded loading function.
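(To illustrate the pattern described here - a hypothetical sketch, not the actual transformers code; `load_one_file` stands in for the single-checkpoint loader:)

```python
def load_sharded_checkpoint(model, shard_files, load_one_file):
    """Hypothetical sketch: aggregate missing-key warnings across shards."""
    missing_per_shard = []
    for shard_file in shard_files:
        # Each shard holds only a subset of the weights, so most keys look
        # "missing" to any single shard -- suppress per-shard warnings.
        missing = load_one_file(model, shard_file, skip_logger_warnings=True)
        missing_per_shard.append(set(missing))
    # Only warn about keys that no shard provided.
    truly_missing = set.intersection(*missing_per_shard) if missing_per_shard else set()
    if truly_missing:
        print(f"Some weights were not found in any shard: {sorted(truly_missing)}")
```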
| ): | ||
| all_loading_infos = [] | ||
| for shard in safetensors_shards: | ||
| with safe_open(shard, framework="tf") as safetensors_archive: |
There was a problem hiding this comment.
Shouldn't this load from the PT framework if we're "loading pytorch shards in tensorflow models"?
`safe_open(framework="tf")` just loads the tensors as `tf.Tensor` instead of `torch.Tensor` - the actual value of the tensor is unchanged. However, we still need to handle weight renaming + transposes, so we still need a pt-to-tf function.
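(A minimal illustration of this behavior; the shard filename below is a placeholder:)

```python
import tensorflow as tf
from safetensors import safe_open

# framework="tf" controls only the returned tensor type, not the stored values.
with safe_open("model-00001-of-00002.safetensors", framework="tf") as f:
    for key in f.keys():
        tensor = f.get_tensor(key)
        assert isinstance(tensor, tf.Tensor)
        # The keys are still the PyTorch weight names, and e.g. Linear kernels
        # keep their PyTorch layout, so renaming/transposing still happens in a
        # separate pt-to-tf step.
```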
```python
for p1, p2 in zip(model.weights, ref_model.weights):
    assert np.allclose(p1.numpy(), p2.numpy())
```

```python
@require_safetensors
```
`safetensors` is now a base dependency, so maybe we should eventually just remove all of these `@require_safetensors` decorators.
Makes sense - do you want me to just do it in this PR?
@LysandreJik I think now that we have proper support, we can switch to safetensors by default immediately, either in this PR or in a follow-up.

I'd wait a few weeks just to ensure we don't get failure reports, and switch for the next version. WDYT?

Sounds good to me!
Force-pushed from 46e5c56 to efdb604.
@LysandreJik is there anything else to be resolved before I merge this? (Except for the failing test, but that's not specific to this PR.)
Force-pushed from a1001c8 to a9f240e.
amyeroberts left a comment:
Thanks for working on this and enabling sharded support!
Just a few small comments / questions
```python
# This should not raise even if there are two types of sharded weights
# This should discard the safetensors weights in favor of the .h5 sharded weights
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")
```
Don't we still want this test, to make sure things are backwards compatible for now - i.e. that I can load sharded h5 files even if safetensors weights are available?
tests/test_modeling_tf_utils.py (Outdated)
```python
# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
# the size asked for (since we count parameters)
if size >= max_size_int + 50000:
    with h5py.File(shard_file, "r") as state_file:
```
What does this represent here?
Good catch - that was copied from the h5 test, and wouldn't work for safetensors - we just got lucky that it wasn't called in the tests anyway. I removed it!
```python
mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], [])

if not skip_logger_warnings:
    if len(unexpected_keys) > 0:
```
AFAICT, all these checks are the same as the ones above. Can we abstract these out to e.g. `validate_keys(unexpected_keys, missing_keys, mismatched_keys)` and call that in both of the functions?
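(For reference, a sketch of what such a helper could look like - hypothetical code; in the later diff below this suggestion appears to have become `_log_key_warnings`:)

```python
import logging

logger = logging.getLogger(__name__)

def validate_keys(unexpected_keys, missing_keys, mismatched_keys, class_name):
    # Emit the warnings in one place instead of duplicating them in both loaders.
    if unexpected_keys:
        logger.warning(f"Some weights in the checkpoint were not used by {class_name}: {unexpected_keys}")
    if missing_keys:
        logger.warning(f"Some weights of {class_name} were not found in the checkpoint: {missing_keys}")
    if mismatched_keys:
        logger.warning(f"Some weights of {class_name} had shapes that did not match the checkpoint: {mismatched_keys}")
```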
```python
unexpected_keys = sum([info["unexpected_keys"] for info in all_loading_infos], [])
mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], [])

if not skip_logger_warnings:
```
My understanding is that we want `skip_logger_warnings=True` when calling `load_pytorch_state_dict_in_tf2_model` here, but I don't see why we're enabling skipping here? Silencing the logging warnings should really be hidden functionality (maybe with a param `_skip_logger_warnings`), and not something people calling either function use.
You're right, actually - this function is only called once and warnings are always emitted, so the argument isn't needed at all. I removed it!
Force-pushed from 55519b6 to e7a2c24.
All comments addressed @amyeroberts! I think we should be ready now.
amyeroberts left a comment:
Thanks for adding this support!
```python
    return tf_model


def _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name):
```
nit - the definition should go above the code where it's used, i.e. before `load_pytorch_state_dict_in_tf2_model`
```diff
-def tf_shard_checkpoint(weights, max_shard_size="10GB"):
+def tf_shard_checkpoint(weights, max_shard_size="10GB", weights_name: str = TF2_WEIGHTS_NAME):
```
Should we have the default shard size match the one in `load_tf_weights`?
I got here by following the torch code, which also has the same issue! Specifically, `save_pretrained()` has a default size of 5GB, but the actual checkpoint sharding methods have a default size of 10GB. In general, though, the value passed from `save_pretrained()` will override those values.
It's a very minor detail either way, since I think both 5GB and 10GB shards work fine! We could consider standardizing everything at some point, but I don't think it's a high priority.
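(For example - a hedged sketch, with the repo name and path as placeholders - the explicit argument takes precedence over both defaults:)

```python
from transformers import TFBertModel

model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# The max_shard_size passed here overrides both the save_pretrained() default
# (5GB) and the sharding helper's default (10GB).
model.save_pretrained("./checkpoint", max_shard_size="2GB")
```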
```diff
 with safe_open(resolved_archive_file, framework="tf") as f:
     safetensors_metadata = f.metadata()
-    if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]:
+    if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
```
Looks like we might want a constant, e.g. `SUPPORTED_SAFE_FORMATS = ["pt", "tf", "flax", "mlx"]`, so we don't have to update this in several locations here and for PT (for another PR).
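(Something like the following - a hypothetical sketch of the suggested constant, with the helper name made up for illustration:)

```python
from safetensors import safe_open

# Hypothetical constant, per the suggestion above; shared by every call site
# that validates checkpoint metadata.
SUPPORTED_SAFE_FORMATS = ["pt", "tf", "flax", "mlx"]

def check_safetensors_format(resolved_archive_file):
    """Raise if the checkpoint does not declare a supported safetensors format."""
    with safe_open(resolved_archive_file, framework="tf") as f:
        metadata = f.metadata()
    if metadata is None or metadata.get("format") not in SUPPORTED_SAFE_FORMATS:
        raise OSError(
            f"{resolved_archive_file} does not declare a supported safetensors format "
            f"(expected one of {SUPPORTED_SAFE_FORMATS})."
        )
```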
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This looks ready to go, cc @a8nova! We have a branch cut + release planned on Monday, though, and since this touches a lot of core code I don't want to merge it right before a release. Instead, I suggest merging it right after the branch cut, and then we can finalize and merge the PRs that are blocked by it: TF-IDEFICS, TF-Gemma, and possibly the Mistral/Mixtral PRs if @ariG23498 can have one of them ready by then (no stress, obviously!). Then we could launch all the new Keras models together in the following release and do a section in the release notes about them, crediting @a8nova and @ariG23498?
cc @a8nova and @ariG23498, this has now been merged. If you rebase your PRs, that should resolve any issues with sharded safetensors loading! |
Right now our TF safetensors loading doesn't support sharded checkpoints, which is a problem as more and more big models move to safetensors-only weights! This is currently blocking @a8nova's PR at #26870.
As sharded safetensors saving for TF was also missing, this PR adds that as well, and expands the tests to cover both.
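(An end-to-end sketch of what this enables - the repo name and shard size below are illustrative placeholders, not taken from this PR's tests:)

```python
import numpy as np
from transformers import TFBertModel

model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

# A small max_shard_size forces several .safetensors shards plus an index file.
model.save_pretrained("./tiny-bert-sharded", max_shard_size="50KB", safe_serialization=True)

# Loading sharded safetensors checkpoints now works too.
reloaded = TFBertModel.from_pretrained("./tiny-bert-sharded")
for p1, p2 in zip(model.weights, reloaded.weights):
    assert np.allclose(p1.numpy(), p2.numpy())
```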
TODO: