TFOPTForCausalLM Attention mask size mismatch exception #24637

System Info

  • transformers version: 4.30.2
  • Platform: Linux-5.15.107+-x86_64-with-glibc2.31
  • Python version: 3.10.12
  • Huggingface_hub version: 0.15.1
  • Safetensors version: 0.3.1
  • PyTorch version (GPU?): 2.0.1+cu118 (False)
  • Tensorflow version (GPU?): 2.12.0 (False)
  • Flax version (CPU?/GPU?/TPU?): 0.6.11 (cpu)
  • Jax version: 0.4.10
  • JaxLib version: 0.4.10
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm trying to write my own decoding logic so I can export the model to TFLite. (The app runs the decoding loop itself, calling into the TFLite model with past_key_values and input_ids; the code for that is a little more involved.)

I'm not sure if I'm missing something important here, but I was able to successfully export Whisper before with this sort of pattern.
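For context, the export pattern I mean looks roughly like the following. This is a simplified sketch, not my actual export code: the real wrapper also takes and returns the flattened past_key_values tensors (omitted here), and DecodeStep/serving are just placeholder names.

import tensorflow as tf
from transformers import TFOPTForCausalLM

model = TFOPTForCausalLM.from_pretrained("facebook/opt-125m")

class DecodeStep(tf.Module):
    """Wraps a single forward step so it can be converted to TFLite."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    @tf.function(input_signature=[tf.TensorSpec([1, 1], tf.int32, name="input_ids")])
    def serving(self, input_ids):
        # The real version also accepts the flattened past_key_values and
        # returns the updated cache; omitted to keep the sketch short.
        outputs = self.model(input_ids, use_cache=True, return_dict=True)
        return {"logits": outputs.logits}

module = DecodeStep(model)
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [module.serving.get_concrete_function()], module
)
# Transformer models usually need the TF select-ops fallback.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()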

I've reduced the problem to this example:

Colab Link

import tensorflow as tf
from transformers import AutoTokenizer, TFOPTForCausalLM, TFGPT2LMHeadModel

def decoding_example(model, tokenizer):
  input_ids = tf.convert_to_tensor([[1]]) * int(tokenizer.bos_token_id)
  outputs = model(input_ids, return_dict=True, use_cache=True, past_key_values=None)

  past_key_values = outputs.past_key_values
  max_new_tokens = 8
  for i in range(max_new_tokens):
    print(i)
    decoded_next_token = 123 # placeholder; real decoding would pick this from outputs.logits

    input_ids = tf.convert_to_tensor([[1]]) * decoded_next_token

    outputs = model(input_ids, return_dict=True, use_cache=True, past_key_values=past_key_values)
    past_key_values = outputs.past_key_values
  
  print("Finished, all OK")

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = TFOPTForCausalLM.from_pretrained("facebook/opt-125m")

decoding_example(model, tokenizer) # fails

Output
0
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-5-07105bf5f115> in <cell line: 4>()
      2 model = TFOPTForCausalLM.from_pretrained("facebook/opt-125m")
      3 
----> 4 decoding_example(model, tokenizer) # fails

9 frames
<ipython-input-3-94ad2e4e3e50> in decoding_example(model, tokenizer)
     11     input_ids = tf.convert_to_tensor([[1]]) * decoded_next_token
     12 
---> 13     outputs = model(input_ids, return_dict=True, use_cache=True, past_key_values=past_key_values)
     14     past_key_values = outputs.past_key_values
     15 

/usr/local/lib/python3.10/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     68             # To get the full stack trace, call:
     69             # `tf.debugging.disable_traceback_filtering()`
---> 70             raise e.with_traceback(filtered_tb) from None
     71         finally:
     72             del filtered_tb

/usr/local/lib/python3.10/dist-packages/transformers/modeling_tf_utils.py in run_call_with_unpacked_inputs(self, *args, **kwargs)
    440 
    441         unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs)
--> 442         return func(self, **unpacked_inputs)
    443 
    444     # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This

/usr/local/lib/python3.10/dist-packages/transformers/models/opt/modeling_tf_opt.py in call(self, input_ids, past_key_values, attention_mask, position_ids, head_mask, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, training, **kwargs)
    956         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    957 
--> 958         outputs = self.model(
    959             input_ids=input_ids,
    960             past_key_values=past_key_values,

/usr/local/lib/python3.10/dist-packages/transformers/modeling_tf_utils.py in run_call_with_unpacked_inputs(self, *args, **kwargs)
    440 
    441         unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs)
--> 442         return func(self, **unpacked_inputs)
    443 
    444     # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This

/usr/local/lib/python3.10/dist-packages/transformers/models/opt/modeling_tf_opt.py in call(self, input_ids, attention_mask, head_mask, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, training, **kwargs)
    730         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    731 
--> 732         outputs = self.decoder(
    733             input_ids,
    734             attention_mask=attention_mask,

/usr/local/lib/python3.10/dist-packages/transformers/modeling_tf_utils.py in run_call_with_unpacked_inputs(self, *args, **kwargs)
    440 
    441         unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs)
--> 442         return func(self, **unpacked_inputs)
    443 
    444     # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This

/usr/local/lib/python3.10/dist-packages/transformers/models/opt/modeling_tf_opt.py in call(self, input_ids, inputs_embeds, attention_mask, head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, training)
    657             past_key_value = past_key_values[idx] if past_key_values is not None else None
    658 
--> 659             hidden_states, layer_self_attn, present_key_value = decoder_layer(
    660                 hidden_states,
    661                 attention_mask=attention_mask,

/usr/local/lib/python3.10/dist-packages/transformers/models/opt/modeling_tf_opt.py in call(self, hidden_states, attention_mask, layer_head_mask, past_key_value, training, output_attentions, use_cache)
    323 
    324         # add present self-attn cache to positions 1,2 of present_key_value tuple
--> 325         hidden_states, self_attn_weights, present_key_value = self.self_attn(
    326             hidden_states=hidden_states,
    327             past_key_value=self_attn_past_key_value,

/usr/local/lib/python3.10/dist-packages/transformers/models/opt/modeling_tf_opt.py in call(self, hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, training)
    217 
    218         if attention_mask is not None:
--> 219             tf.debugging.assert_equal(
    220                 shape_list(attention_mask),
    221                 [bsz, 1, tgt_len, src_len],

InvalidArgumentError: Exception encountered when calling layer 'self_attn' (type TFOPTAttention).

Attention mask should be of size (1, 1, 0, 1), but is [1, 1, 1, 2]
Condition x == y did not hold.
Indices of first 2 different values:
[[2]
 [3]]
Corresponding x values:
[1 2]
Corresponding y values:
[0 1]
First 3 elements of x:
[1 1 1]
First 3 elements of y:
[1 1 0]

Call arguments received by layer 'self_attn' (type TFOPTAttention):
  • hidden_states=tf.Tensor(shape=(1, 0, 768), dtype=float32)
  • key_value_states=None
  • past_key_value=('tf.Tensor(shape=(1, 12, 1, 64), dtype=float32)', 'tf.Tensor(shape=(1, 12, 1, 64), dtype=float32)')
  • attention_mask=tf.Tensor(shape=(1, 1, 1, 2), dtype=float32)
  • layer_head_mask=None
  • training=False

Expected behavior

I expect it to work like it does with GPT-2:

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")

decoding_example(model, tokenizer) # works
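
For what it's worth, a possible workaround I haven't fully verified: passing an explicit attention_mask covering the cached tokens plus the new token seems like it should sidestep the shape check, since the model then doesn't have to infer the mask from input_ids alone. Inside the loop of decoding_example above, that would look roughly like this (past_length and the explicit mask are my own additions, not anything the library requires):

    # key/value cache tensors are (batch, num_heads, seq_len, head_dim),
    # so dim 2 is the number of tokens already in the cache
    past_length = past_key_values[0][0].shape[2]
    attention_mask = tf.ones((1, past_length + 1), dtype=tf.int32)

    outputs = model(
        input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        use_cache=True,
        return_dict=True,
    )
    past_key_values = outputs.past_key_values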
