main_input_name is None if predict_with_generate in keras_callbacks.py for encoder-decoder(Bert-Bert) TF models #24872

@saichandrapandraju

Description

System Info

  • transformers version: 4.30.2
  • Platform: Linux-5.15.109+-x86_64-with-glibc2.31
  • Python version: 3.10.12
  • Huggingface_hub version: 0.16.4
  • Safetensors version: 0.3.1
  • PyTorch version (GPU?): 2.0.1+cu118 (True)
  • Tensorflow version (GPU?): 2.12.0 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.7.0 (gpu)
  • Jax version: 0.4.13
  • JaxLib version: 0.4.13
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@gante , @Rocketknight1

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

KeyError                                  Traceback (most recent call last)
<ipython-input-49-089aabb58b9b> in <cell line: 1>()
----> 1 history = model.fit(x=tf_train_set, validation_data=tf_validation_set, epochs=num_epochs, callbacks=callbacks)

1 frames
/usr/local/lib/python3.10/dist-packages/transformers/keras_callbacks.py in on_epoch_end(self, epoch, logs)
    217             if self.predict_with_generate:
    218                 if isinstance(batch, dict):
--> 219                     generation_inputs = batch[main_input_name]
    220                     attention_mask = batch.get("attention_mask", None)
    221                 else:

KeyError: None

Here's the Colab link to reproduce the error.

This happens because of the following code in keras_callbacks.py (annotated with >>>>> ... <<<<< for clarity):

#### in keras_callbacks.py (in `on_epoch_end`, ~line 191)
main_input_name = None
if self.predict_with_generate:
    # This dense conditional recognizes the case where we have an encoder-decoder model, but
    # avoids getting tangled up when we just have a model with a layer called 'encoder'
    if hasattr(self.model, "encoder") and hasattr(self.model.encoder, "main_input_name"):
        # >>>>>>> If this condition is not satisfied (which is the case here), `main_input_name` remains None <<<<<<<<
        if self.model.encoder.main_input_name != self.model.main_input_name:
            main_input_name = self.model.encoder.main_input_name
    else:
        main_input_name = getattr(self.model, "main_input_name", "input_ids")

    if self.use_xla_generation and self.generation_function is None:

        def generation_function(inputs, attention_mask):
            return self.model.generate(inputs, attention_mask=attention_mask, **self.generate_kwargs)

        self.generation_function = tf.function(generation_function, jit_compile=True)

prediction_list = []
label_list = []

# The whole predict/generate loop is handled inside this method
for batch in self.eval_dataset:
    if isinstance(batch, tuple):
        batch, labels = batch
    else:
        labels = None
    if self.predict_with_generate:
        if isinstance(batch, dict):
            generation_inputs = batch[main_input_name]  # >>>>>>>>>>>> `main_input_name` remains None here (~line 219) <<<<<<<<<<<<
            attention_mask = batch.get("attention_mask", None)
        else:
            generation_inputs = batch
            attention_mask = None
        if self.use_xla_generation:
            predictions = self.generation_function(generation_inputs, attention_mask=attention_mask)
        else:
            predictions = self.model.generate(
                generation_inputs, attention_mask=attention_mask, **self.generate_kwargs
            )
    
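The failure mode can be reproduced in isolation, without TensorFlow, using plain-Python stand-ins for the model objects (the `DummyEncoder`/`DummyModel` classes below are hypothetical, but the conditional is copied from the callback):

```python
# Minimal stand-ins mimicking an encoder-decoder model where the encoder's
# main_input_name matches the model's (e.g. "input_ids" for Bert-Bert).
class DummyEncoder:
    main_input_name = "input_ids"

class DummyModel:
    main_input_name = "input_ids"
    encoder = DummyEncoder()

model = DummyModel()

# Same logic as in keras_callbacks.py: the inner branch never fires because
# the two names are equal, and the outer else is skipped because the encoder
# attributes exist -- so main_input_name stays None.
main_input_name = None
if hasattr(model, "encoder") and hasattr(model.encoder, "main_input_name"):
    if model.encoder.main_input_name != model.main_input_name:
        main_input_name = model.encoder.main_input_name
else:
    main_input_name = getattr(model, "main_input_name", "input_ids")

batch = {"input_ids": [[101, 2023, 102]], "attention_mask": [[1, 1, 1]]}
try:
    generation_inputs = batch[main_input_name]
except KeyError as e:
    print(f"KeyError: {e}")  # the lookup with key None fails, as in the traceback
```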

Expected behavior

`main_input_name` should resolve to `input_ids`.

To achieve that, the following block can be modified:

if hasattr(self.model, "encoder") and hasattr(self.model.encoder, "main_input_name"):
    if self.model.encoder.main_input_name != self.model.main_input_name:
        main_input_name = self.model.encoder.main_input_name

to something like:

if hasattr(self.model, "encoder") and hasattr(self.model.encoder, "main_input_name"):
    main_input_name = self.model.encoder.main_input_name
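With that change, the encoder's input name is used unconditionally, which also covers the case where it equals the model's. A standalone sketch of the proposed logic (again using the hypothetical `DummyModel` stand-in, not the real TF model):

```python
# Hypothetical stand-ins for an encoder-decoder model whose encoder shares
# the model's main input name (as with a Bert-Bert EncoderDecoderModel).
class DummyEncoder:
    main_input_name = "input_ids"

class DummyModel:
    main_input_name = "input_ids"
    encoder = DummyEncoder()

model = DummyModel()

if hasattr(model, "encoder") and hasattr(model.encoder, "main_input_name"):
    # Take the encoder's input name unconditionally; it is the right key
    # whether or not it differs from the model's own main_input_name.
    main_input_name = model.encoder.main_input_name
else:
    main_input_name = getattr(model, "main_input_name", "input_ids")

batch = {"input_ids": [[101, 2023, 102]], "attention_mask": [[1, 1, 1]]}
generation_inputs = batch[main_input_name]  # no KeyError: key is "input_ids"
```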
