System Info
- `transformers` version: 4.30.2
- Platform: Linux-5.15.107+-x86_64-with-glibc2.31
- Python version: 3.10.12
- Huggingface_hub version: 0.15.1
- Safetensors version: 0.3.1
- PyTorch version (GPU?): 2.0.1+cu118 (False)
- Tensorflow version (GPU?): 2.12.0 (False)
- Flax version (CPU?/GPU?/TPU?): 0.6.11 (cpu)
- Jax version: 0.4.10
- JaxLib version: 0.4.10
- Using GPU in script?: No
- Using distributed or parallel set-up in script?: No
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
I'm trying to write my own decoding logic so I can export to TFLite (the app runs the decoding loop itself, calling into the TFLite model with `past_key_values` and `input_ids`, but the code for that is a little more involved).
I'm not sure if I'm missing something important here, but I was able to successfully export Whisper before with this sort of pattern.
I've reduced the problem to this example:
```python
import tensorflow as tf
from transformers import AutoTokenizer, TFOPTForCausalLM, TFGPT2LMHeadModel

def decoding_example(model, tokenizer):
    input_ids = tf.convert_to_tensor([[1]]) * int(tokenizer.bos_token_id)
    outputs = model(input_ids, return_dict=True, use_cache=True, past_key_values=None)
    past_key_values = outputs.past_key_values

    max_new_tokens = 8
    for i in range(max_new_tokens):
        print(i)
        decoded_next_token = 123  # just an example; this would depend on outputs.last_hidden_state
        input_ids = tf.convert_to_tensor([[1]]) * decoded_next_token
        outputs = model(input_ids, return_dict=True, use_cache=True, past_key_values=past_key_values)
        past_key_values = outputs.past_key_values
    print("Finished, all OK")

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = TFOPTForCausalLM.from_pretrained("facebook/opt-125m")
decoding_example(model, tokenizer)  # fails
```

Output:
0
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-5-07105bf5f115> in <cell line: 4>()
2 model = TFOPTForCausalLM.from_pretrained("facebook/opt-125m")
3
----> 4 decoding_example(model, tokenizer) # fails
9 frames
<ipython-input-3-94ad2e4e3e50> in decoding_example(model, tokenizer)
11 input_ids = tf.convert_to_tensor([[1]]) * decoded_next_token
12
---> 13 outputs = model(input_ids, return_dict=True, use_cache=True, past_key_values=past_key_values)
14 past_key_values = outputs.past_key_values
15
/usr/local/lib/python3.10/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
68 # To get the full stack trace, call:
69 # `tf.debugging.disable_traceback_filtering()`
---> 70 raise e.with_traceback(filtered_tb) from None
71 finally:
72 del filtered_tb
/usr/local/lib/python3.10/dist-packages/transformers/modeling_tf_utils.py in run_call_with_unpacked_inputs(self, *args, **kwargs)
440
441 unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs)
--> 442 return func(self, **unpacked_inputs)
443
444 # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
/usr/local/lib/python3.10/dist-packages/transformers/models/opt/modeling_tf_opt.py in call(self, input_ids, past_key_values, attention_mask, position_ids, head_mask, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, training, **kwargs)
956 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
957
--> 958 outputs = self.model(
959 input_ids=input_ids,
960 past_key_values=past_key_values,
/usr/local/lib/python3.10/dist-packages/transformers/modeling_tf_utils.py in run_call_with_unpacked_inputs(self, *args, **kwargs)
440
441 unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs)
--> 442 return func(self, **unpacked_inputs)
443
444 # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
/usr/local/lib/python3.10/dist-packages/transformers/models/opt/modeling_tf_opt.py in call(self, input_ids, attention_mask, head_mask, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, training, **kwargs)
730 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
731
--> 732 outputs = self.decoder(
733 input_ids,
734 attention_mask=attention_mask,
/usr/local/lib/python3.10/dist-packages/transformers/modeling_tf_utils.py in run_call_with_unpacked_inputs(self, *args, **kwargs)
440
441 unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs)
--> 442 return func(self, **unpacked_inputs)
443
444 # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
/usr/local/lib/python3.10/dist-packages/transformers/models/opt/modeling_tf_opt.py in call(self, input_ids, inputs_embeds, attention_mask, head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, training)
657 past_key_value = past_key_values[idx] if past_key_values is not None else None
658
--> 659 hidden_states, layer_self_attn, present_key_value = decoder_layer(
660 hidden_states,
661 attention_mask=attention_mask,
/usr/local/lib/python3.10/dist-packages/transformers/models/opt/modeling_tf_opt.py in call(self, hidden_states, attention_mask, layer_head_mask, past_key_value, training, output_attentions, use_cache)
323
324 # add present self-attn cache to positions 1,2 of present_key_value tuple
--> 325 hidden_states, self_attn_weights, present_key_value = self.self_attn(
326 hidden_states=hidden_states,
327 past_key_value=self_attn_past_key_value,
/usr/local/lib/python3.10/dist-packages/transformers/models/opt/modeling_tf_opt.py in call(self, hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, training)
217
218 if attention_mask is not None:
--> 219 tf.debugging.assert_equal(
220 shape_list(attention_mask),
221 [bsz, 1, tgt_len, src_len],
InvalidArgumentError: Exception encountered when calling layer 'self_attn' (type TFOPTAttention).
Attention mask should be of size (1, 1, 0, 1), but is [1, 1, 1, 2]
Condition x == y did not hold.
Indices of first 2 different values:
[[2]
[3]]
Corresponding x values:
[1 2]
Corresponding y values:
[0 1]
First 3 elements of x:
[1 1 1]
First 3 elements of y:
[1 1 0]
Call arguments received by layer 'self_attn' (type TFOPTAttention):
• hidden_states=tf.Tensor(shape=(1, 0, 768), dtype=float32)
• key_value_states=None
• past_key_value=('tf.Tensor(shape=(1, 12, 1, 64), dtype=float32)', 'tf.Tensor(shape=(1, 12, 1, 64), dtype=float32)')
• attention_mask=tf.Tensor(shape=(1, 1, 1, 2), dtype=float32)
• layer_head_mask=None
• training=False
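Reading the assertion above: with one cached token and one new input token, the 4-D mask the attention layer receives is `[1, 1, 1, 2]`, but the layer computed a target length of 0 (note `hidden_states` has shape `(1, 0, 768)`), as if the new token had been sliced away. A dependency-free sketch of the shape I believe the mask should have in incremental decoding (my reading, using a hypothetical helper, not the library's code):

```python
def expected_mask_shape(batch_size, past_len, new_tokens=1):
    """4-D attention-mask shape for one step of KV-cache decoding.

    Hypothetical helper for illustration: with `past_len` cached positions
    and `new_tokens` fresh input ids, the queries cover only the new tokens,
    while the keys cover the cache plus the new tokens.
    """
    tgt_len = new_tokens             # query positions: just the new tokens
    src_len = past_len + new_tokens  # key positions: cache + new tokens
    return [batch_size, 1, tgt_len, src_len]

# First cached step in the repro: 1 token in the cache, 1 new token
print(expected_mask_shape(1, 1))  # [1, 1, 1, 2] — matches the mask OPT built
```

So the mask itself looks consistent with the cache; it is the target length of 0 on the OPT side that seems wrong.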
Expected behavior
I expect it to work like it does with GPT-2:
```python
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
decoding_example(model, tokenizer)  # works
```
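Stripped of TensorFlow and the model itself, the bookkeeping the loop relies on is small: each call consumes the new ids and extends the cache by exactly that many positions. A minimal stand-in (hypothetical `fake_model`, no `transformers` involved) showing the lengths I expect at each step:

```python
def fake_model(input_ids, past_key_values):
    # Stand-in for model(...): consumes len(input_ids) new tokens and
    # returns a cache extended by exactly that many positions.
    return {"past_key_values": past_key_values + list(input_ids)}

cache = []                                           # no cache on the first call
cache = fake_model([1], cache)["past_key_values"]    # BOS token
for _ in range(8):                                   # max_new_tokens steps
    cache = fake_model([123], cache)["past_key_values"]
print(len(cache))  # 9 — 1 prompt token + 8 decoded tokens
```

This is the behavior the GPT-2 path exhibits; with OPT, the second call appears to process zero new tokens instead of one.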