[Llava] Phi text model produces ValueError: Attention mask should be of size (1, 1, 1, 230), but is torch.Size([1, 1, 1, 8]) when using past_key_values in generate #30809

@xenova

Description

System Info

  • transformers version: 4.38.2
  • Platform: Linux-6.1.58+-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.23.0
  • Safetensors version: 0.4.3
  • Accelerate version: 0.30.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.1+cu121 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): 0.8.3 (cpu)
  • Jax version: 0.4.26
  • JaxLib version: 0.4.26
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help?

@gante (generate) @susnato (phi implementation) @younesbelkada (llava implementation)

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Following the multi-round conversation tutorial from here, I put together this minimal reproduction to show that switching Llava to use a Phi text model (instead of, e.g., Llama) results in an error when reusing past key values.

Running:

from PIL import Image
import requests
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Load model and processor

# THIS WORKS
# model_id = "Xenova/tiny-random-LlavaForConditionalGeneration"
# model = LlavaForConditionalGeneration.from_pretrained(model_id)

# THIS DOESN'T WORK
model_id = "Xenova/tiny-random-LlavaForConditionalGeneration_phi"
model = LlavaForConditionalGeneration.from_pretrained(model_id, attn_implementation="eager")

processor = AutoProcessor.from_pretrained(model_id)

# Define inputs
prompt = "<image>Hi"
url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/white-image.png?download=true"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=prompt, images=image,
                   return_tensors="pt", padding=True)

# Generate w/o past_key_values
output = model.generate(
  **inputs,
  max_new_tokens=3,
  return_dict_in_generate=True,
  do_sample=False,
)

decoded = processor.batch_decode(
    output["sequences"], skip_special_tokens=False)

# Prepare new inputs
new_inputs = processor(decoded, return_tensors="pt", padding=True)

# Generate w/ past_key_values
generate_ids = model.generate(
    **new_inputs,
    do_sample=False,
    past_key_values=output['past_key_values'],
    max_new_tokens=20,
)
print(f'{generate_ids=}')

decoded2 = processor.batch_decode(
    generate_ids, skip_special_tokens=False)
print(f'{decoded2=}')

results in this error:

Traceback (most recent call last):
  File "/content/transformers.js/../test.py", line 39, in <module>
    generate_ids = model.generate(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1544, in generate
    return self.greedy_search(
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2404, in greedy_search
    outputs = self(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llava/modeling_llava.py", line 469, in forward
    outputs = self.language_model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/phi/modeling_phi.py", line 1046, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/phi/modeling_phi.py", line 925, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/phi/modeling_phi.py", line 666, in forward
    attn_outputs, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/phi/modeling_phi.py", line 375, in forward
    raise ValueError(
ValueError: Attention mask should be of size (1, 1, 1, 230), but is torch.Size([1, 1, 1, 8])
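For context, here is my reading of the failure (an interpretation, not a confirmed diagnosis): the second `generate` call re-tokenizes the decoded text, so its `attention_mask` covers only the 8 new token ids, while Phi's attention layer expects a mask spanning the full key/value length, i.e., the cached positions (expanded image tokens plus the first prompt) and the new tokens together. The arithmetic can be sketched without the model; the 222/8 split below is inferred from the error message:

```python
import torch

# Numbers inferred from the error message: the KV cache holds 222 positions
# (expanded image tokens + first prompt), and re-tokenizing the decoded text
# contributes 8 new positions, for a key/value length of 230.
past_len, new_len = 222, 8

# Mask produced by re-tokenizing only the decoded text:
mask_new = torch.ones(1, new_len, dtype=torch.long)

# Mask the attention layer actually needs: ones over the cached positions,
# concatenated with the mask for the new tokens.
mask_full = torch.cat(
    [torch.ones(1, past_len, dtype=torch.long), mask_new], dim=-1
)

print(mask_new.shape)   # torch.Size([1, 8])
print(mask_full.shape)  # torch.Size([1, 230])
```

If this reading is right, extending `new_inputs["attention_mask"]` with ones over the cached positions before the second `generate` call might work around the error, but I have not verified that against the Phi code path.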

Expected behavior

If you run the same script with a Llama-based text model (e.g., here; see the commented-out lines in the script above), it works correctly.

Labels: Generation, WIP