Beam search generation with len(eos_token_id) > 1 throws exceptions  #25103

@yonigottesman

Description

System Info

  • transformers version: 4.31.0
  • Platform: Linux-5.13.0-1031-aws-x86_64-with-glibc2.31
  • Python version: 3.10.5
  • Huggingface_hub version: 0.16.4
  • Safetensors version: 0.3.1
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: true
  • Using distributed or parallel set-up in script?: false

Who can help?

@gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import transformers
from transformers import GenerationConfig
import torch

name = "gpt2"
tokenizer = transformers.AutoTokenizer.from_pretrained(name)
model = transformers.AutoModelForCausalLM.from_pretrained(name)

gc = GenerationConfig(
    max_new_tokens=40,
    eos_token_id=tokenizer.encode(" black white green red brown blue yellow purple pink orange"),
    pad_token_id=tokenizer.eos_token_id,
    num_beams=3,
)

input_ids = tokenizer.encode("Hello, I have 3 cats, one of them is colored", return_tensors="pt")
output = model.generate(input_ids, generation_config=gc)
tokenizer.decode(output[0])

Expected behavior

This simple beam search example should work but is throwing this exception:

File /usr/local/lib/python3.10/site-packages/transformers/generation/utils.py:2985, in GenerationMixin.beam_search(self, input_ids, beam_scorer, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)
   2982 next_tokens = next_tokens % vocab_size
   2984 # stateless
-> 2985 beam_outputs = beam_scorer.process(
   2986     input_ids,
   2987     next_token_scores,
   2988     next_tokens,
   2989     next_indices,
   2990     pad_token_id=pad_token_id,
   2991     eos_token_id=eos_token_id,
   2992     beam_indices=beam_indices,
   2993 )
   2995 beam_scores = beam_outputs["next_beam_scores"]
   2996 beam_next_tokens = beam_outputs["next_beam_tokens"]

File /usr/local/lib/python3.10/site-packages/transformers/generation/beam_search.py:297, in BeamSearchScorer.process(self, input_ids, next_scores, next_tokens, next_indices, pad_token_id, eos_token_id, beam_indices, group_index)
    294         break
    296 if beam_idx < self.group_size:
--> 297     raise ValueError(
    298         f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
    299         f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
    300     )
    302 # Check if we are done so that we can save a pad step if all(done)
    303 self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
    304     next_scores[batch_idx].max().item(), cur_len
    305 )
ValueError: At most 3 tokens in tensor([ 2266, 11398,  4171,  4077,    11, 10912]) can be equal to `eos_token_id: [2042, 2330, 4077, 2266, 7586, 4171, 7872, 14032, 11398, 10912]`. Make sure tensor([ 2266, 11398,  4171,  4077,    11, 10912]) are corrected.

I think there is a bug in the check `if beam_idx < self.group_size`: it doesn't take into account that there can be more than one eos_token_id, so each beam may select more than one EOS token among its top-k candidates.
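To make the counting problem concrete, here is a minimal sketch (not the actual transformers code) of the candidate filtering that `BeamSearchScorer.process` effectively performs: with `num_beams` beams it looks at `2 * num_beams` top candidates per batch item, and only candidates that are *not* an EOS token can continue as live beams. With a single EOS id at least `num_beams` non-EOS candidates always remain; with many EOS ids, that guarantee breaks. The token ids below are taken from the error message above.

```python
num_beams = 3
# The eos_token_id list from the reproduction (ids for the color words)
eos_token_ids = {2042, 2330, 4077, 2266, 7586, 4171, 7872, 14032, 11398, 10912}

# The 2 * num_beams top candidate tokens reported in the ValueError:
# more than num_beams of them happen to be EOS tokens.
candidates = [2266, 11398, 4171, 4077, 11, 10912]

# Only non-EOS candidates can continue as live beams.
surviving = [t for t in candidates if t not in eos_token_ids]

# Fewer than num_beams survivors remain, which is exactly the
# condition that triggers the ValueError in beam_search.py.
print(len(surviving))  # 1, i.e. fewer than num_beams
```

This suggests the fix needs to either consider more than `2 * num_beams` candidates when multiple EOS ids are configured, or relax the `beam_idx < self.group_size` check accordingly.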

I would be happy to work on this.
