Skip to content

Update PT/TF weight conversion after #24030#24547

Merged
ydshieh merged 3 commits into
mainfrom
oh_my_weight
Jun 28, 2023
Merged

Update PT/TF weight conversion after #24030#24547
ydshieh merged 3 commits into
mainfrom
oh_my_weight

Conversation

@ydshieh

@ydshieh ydshieh commented Jun 28, 2023

Copy link
Copy Markdown
Collaborator

What does this PR do?

Update PT/TF weight conversion due to the change in #24030.

(can do PT/Flax too in the same PR, but request a review first anyway)

Code snippet to show issues and verify this PR's effect

(Failing for main + nightly torch. Pass for PR + nightly torch and main/PR + stable torch)

import transformers

from transformers import TFWav2Vec2Model
from tests.models.wav2vec2.test_modeling_tf_wav2vec2 import TFWav2Vec2ModelTest

self = TFWav2Vec2ModelTest()
self.setUp()

model_class = TFWav2Vec2Model
allow_missing_keys = False

config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)

tf_model = model_class(config)
pt_model = pt_model_class(config)

tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)

# Check we can load pt model in tf and vice-versa with model => model functions
try:
    _tf_model = transformers.load_pytorch_model_in_tf2_model(
        tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
    )
except:
    _tf_model = None

try:
    _pt_model = transformers.load_tf2_model_in_pytorch_model(
        pt_model, tf_model, allow_missing_keys=allow_missing_keys
    )
except:
    _pt_model = None

if _tf_model is None:
    print("_tf_model fails")
else:
    print("_tf_model OK")

if _pt_model is None:
    print("_pt_model fails")
else:
    print("_pt_model OK")

Comment thread src/transformers/modeling_tf_pytorch_utils.py Outdated
Comment thread src/transformers/modeling_tf_pytorch_utils.py
Comment thread src/transformers/modeling_tf_pytorch_utils.py
@ydshieh ydshieh requested a review from sgugger June 28, 2023 13:10
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jun 28, 2023

Copy link
Copy Markdown

The documentation is not available anymore as the PR was closed or merged.

@sgugger sgugger left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this! Just a few comments on the variable names.

Comment thread src/transformers/modeling_tf_pytorch_utils.py Outdated
Comment thread src/transformers/modeling_tf_pytorch_utils.py Outdated
@ydshieh ydshieh requested a review from sgugger June 28, 2023 14:16
@ydshieh ydshieh merged commit 6c57ce1 into main Jun 28, 2023
@ydshieh ydshieh deleted the oh_my_weight branch June 28, 2023 14:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants