
Can't load a local finetuned state dict anymore without loading the official pretrained weights first  #16672

@laurahanu

Description


Environment info

  • transformers version: 4.18.0
  • Platform: Ubuntu & Mac
  • Python version: 3.9.7

Who can help

@sgugger

Information

Issue first reported here
Model I am using (Bert, XLNet ...): Bert, Roberta

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQUaD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

The code below worked before version 4.18.0.

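  1. Cannot load a fine-tuned state dict (can be downloaded from here) without also loading the official pretrained HF weights. Previously this worked by passing pretrained_model_name_or_path=None:
    import torch
    from transformers import RobertaForSequenceClassification

    # Hypothetical local path; assumed to contain a plain state dict
    state_dict = torch.load("finetuned_model.ckpt", map_location="cpu")

    model = RobertaForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=None,  # skip the official pretrained weights
        config="roberta-base",
        num_labels=16,
        state_dict=state_dict,
    )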

Stack trace:

Exception has occurred: TypeError (note: full exception trace is shown but execution is paused at: _run_module_as_main)
expected str, bytes or os.PathLike object, not NoneType
  File "/detoxify/toxic-env/lib/python3.9/site-packages/torch/serialization.py", line 308, in _check_seekable
    f.seek(f.tell())

During handling of the above exception, another exception occurred:

  File "/detoxify/toxic-env/lib/python3.9/site-packages/transformers/modeling_utils.py", line 349, in load_state_dict
    return torch.load(checkpoint_file, map_location="cpu")
  File "/detoxify/toxic-env/lib/python3.9/site-packages/torch/serialization.py", line 594, in load
    with _open_file_like(f, 'rb') as opened_file:
  File "/detoxify/toxic-env/lib/python3.9/site-packages/torch/serialization.py", line 235, in _open_file_like
    return _open_buffer_reader(name_or_buffer)
  File "/detoxify/toxic-env/lib/python3.9/site-packages/torch/serialization.py", line 220, in __init__
    _check_seekable(buffer)
  File "/detoxify/toxic-env/lib/python3.9/site-packages/torch/serialization.py", line 311, in _check_seekable
    raise_err_msg(["seek", "tell"], e)
  File "/detoxify/toxic-env/lib/python3.9/site-packages/torch/serialization.py", line 304, in raise_err_msg
    raise type(e)(msg)

Expected behavior

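Passing a state_dict together with pretrained_model_name_or_path=None should load the fine-tuned weights, as it did before 4.18.0. The regression appears to have been introduced by #16343, specifically this change (L1444-R1796): the checkpoint file is now loaded even when a state_dict is supplied, so torch.load receives None.
Changing the condition on L1797 to if not is_sharded and state_dict is None: would fix this.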
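To illustrate why the proposed guard helps, here is a simplified, hypothetical sketch of the loading branch (not the actual transformers code). It assumes the 4.18.0 condition is just if not is_sharded:, and it uses a stand-in load_state_dict that, like torch.load, raises a TypeError when given None. The names resolve_weights_buggy and resolve_weights_fixed are made up for this example.

```python
from typing import Any, Dict, Optional

StateDict = Dict[str, Any]


def load_state_dict(checkpoint_file: Optional[str]) -> StateDict:
    # Stand-in for transformers' load_state_dict, which calls torch.load().
    # Like torch.load, it raises a TypeError when the path is None.
    if checkpoint_file is None:
        raise TypeError("expected str, bytes or os.PathLike object, not NoneType")
    return {"loaded_from": checkpoint_file}


def resolve_weights_buggy(resolved_archive_file: Optional[str], is_sharded: bool,
                          state_dict: Optional[StateDict]) -> Optional[StateDict]:
    # Assumed 4.18.0 behavior: the checkpoint file is always read when the
    # model is not sharded, even if the caller already passed a state_dict.
    if not is_sharded:
        state_dict = load_state_dict(resolved_archive_file)
    return state_dict


def resolve_weights_fixed(resolved_archive_file: Optional[str], is_sharded: bool,
                          state_dict: Optional[StateDict]) -> Optional[StateDict]:
    # Proposed guard: only read from disk when no state_dict was supplied.
    if not is_sharded and state_dict is None:
        state_dict = load_state_dict(resolved_archive_file)
    return state_dict


if __name__ == "__main__":
    user_state_dict = {"classifier.out_proj.weight": [0.0]}
    # With pretrained_model_name_or_path=None there is no archive file (None).
    print(resolve_weights_fixed(None, False, user_state_dict) is user_state_dict)
    try:
        resolve_weights_buggy(None, False, user_state_dict)
    except TypeError as err:
        print("TypeError:", err)
```

With the guard, a caller-supplied state dict is returned untouched, while the unguarded branch reproduces the same TypeError seen in the stack trace.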
