Fix load from PT-formatted checkpoint in composite TF models#20661
Conversation
ydshieh left a comment
This issue is similar to this one, where the fix was implemented in the TF composite models' `from_encoder_decoder_pretrained`.
This PR works (for this issue), but it introduces a difference between `from_pt` and `safetensors_from_pt` (regarding whether the loading happens before or after `load_weight_prefix` is applied). I think it's better to treat them equally.
At that time, I wrote
I felt it would be better to modify `load_pytorch_weights_in_tf2_model` to address this situation, but I tried to avoid modifying this Hugging Face TF core method.
I am going to approve, however, as I don't want to change the `from_pt` part (at least not in this PR), and moving `safetensors_from_pt` into `from_encoder_decoder_pretrained` doesn't look clean in the first place (and is not super easy either).
…face#20661)
* Fix load from PT-formatted checkpoint in composite TF models
* Leave the from_pt part as it was
What does this PR do?
This PR fixes the slow test
`TFViT2GPT2EncoderDecoderModelTest::test_real_model_save_load_from_pretrained`, which was broken by the new `safetensors` integration. The main problem was that this model loads GPT-2 as its decoder, which has a safetensors checkpoint stored in a PyTorch-like format, and that checkpoint was loaded with wrong weight names. Moving the variable-scope code before we try to load PyTorch-like checkpoints fixes the issue.
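To illustrate the naming problem, here is a minimal sketch (not the actual `transformers` implementation; `apply_load_weight_prefix` is a hypothetical helper and the prefix string is an assumption): when the composite model's decoder variables are created inside a name scope, PyTorch-style checkpoint keys only match if that scope prefix is applied to the weight names before matching, which is why the scope must be entered before loading.

```python
def apply_load_weight_prefix(pt_state_dict, prefix):
    """Hypothetical helper: prepend the composite model's weight prefix
    to every PyTorch-style key and switch '.' separators to '/', so the
    renamed keys line up with TF variable names created inside the
    decoder's name scope. If the scope is entered only after loading,
    this renaming never happens and the weights fail to match."""
    return {
        f"{prefix}/{key.replace('.', '/')}": value
        for key, value in pt_state_dict.items()
    }


# Example: a couple of GPT-2 decoder weights in a composite model.
pt_weights = {
    "transformer.wte.weight": [0.1],
    "transformer.h.0.attn.bias": [0.2],
}
renamed = apply_load_weight_prefix(pt_weights, "decoder")
print(sorted(renamed))
```

The prefix value and key layout here are illustrative only; the point is the ordering: the scope-derived prefix must exist before the PyTorch-like checkpoint keys are matched against TF variable names.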