Description
Feature request
I would like to change the output returned by sample_beam when num_return_sequences > 1. The current implementation executes the same process num_return_sequences times.
transformers/src/transformers/generation/utils.py
Lines 1692 to 1707 in 080a971
beam_scorer = BeamSearchScorer(
    batch_size=batch_size * generation_config.num_return_sequences,
    num_beams=generation_config.num_beams,
    device=inputs_tensor.device,
    length_penalty=generation_config.length_penalty,
    do_early_stopping=generation_config.early_stopping,
    max_length=generation_config.max_length,
)
# 13. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
    input_ids=input_ids,
    expand_size=generation_config.num_beams * generation_config.num_return_sequences,
    is_encoder_decoder=self.config.is_encoder_decoder,
    **model_kwargs,
)
Therefore, when the following code is executed, 3 of the 5 returned sequences are exactly the same.
code:
from transformers import AutoTokenizer, AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pt")
outputs = model.generate(**inputs, num_beams=5, num_return_sequences=5, do_sample=True)
print("\n".join(tokenizer.batch_decode(outputs, skip_special_tokens=True)))
output:
The full name of Donald is Donald J. Trump Jr., the son-in-law and senior
The full name of Donald is Donald J. Trump Jr., the president's son-in-law
The full name of Donald is Donald J. Trump Jr., the president's son-in-law
The full name of Donald is Donald J. Trump Jr., the son-in-law of the
The full name of Donald is Donald J. Trump Jr., the president's son-in-law
This behavior is undesirable; it should instead behave like regular beam_search, which extracts multiple distinct sentences from the beam candidates.
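To make the difference concrete, here is a self-contained toy sketch (no transformers dependency; the context-free token log-probabilities and the helper names are made up for illustration, and the toy search is deterministic so the duplication is exact rather than probabilistic). It contrasts "take the top N hypotheses from one beam" with "run the whole search N times and take the top-1 each time":

```python
import heapq

# Hypothetical, context-free "model": each token has a fixed log-probability.
# A real model conditions on context; this is only for illustration.
LOG_PROBS = {"a": -0.1, "b": -0.5, "c": -1.0}

def toy_beam_search(log_probs, num_beams, length, num_return_sequences):
    """Deterministic beam search returning the top hypotheses from the beam."""
    beams = [((), 0.0)]  # (token sequence, cumulative log-prob)
    for _ in range(length):
        candidates = [
            (seq + (tok,), score + lp)
            for seq, score in beams
            for tok, lp in log_probs.items()
        ]
        beams = heapq.nlargest(num_beams, candidates, key=lambda c: c[1])
    # Like regular beam_search: distinct candidates straight from the beam.
    return [seq for seq, _ in beams[:num_return_sequences]]

def replicated_search(log_probs, num_beams, length, num_return_sequences):
    """Mimic the reported behavior: run the search N times, keep top-1 each time."""
    return [
        toy_beam_search(log_probs, num_beams, length, 1)[0]
        for _ in range(num_return_sequences)
    ]

distinct = toy_beam_search(LOG_PROBS, num_beams=5, length=3, num_return_sequences=3)
repeated = replicated_search(LOG_PROBS, num_beams=5, length=3, num_return_sequences=3)
print(distinct)  # three different sequences from the beam
print(repeated)  # the single best sequence, three times
```

In the real sample_beam the per-run process is stochastic, so independent runs merely often coincide (3 of 5 above) rather than always.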
Motivation
In the current implementation, num_return_sequences has no reason to exist, since almost the same result can be obtained by running the generation N times with num_return_sequences=1. The two code snippets below do much the same thing.
from transformers import AutoTokenizer, AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pt")
outputs = model.generate(**inputs, num_beams=5, num_return_sequences=5, do_sample=True)
print("\n".join(tokenizer.batch_decode(outputs, skip_special_tokens=True)))

from transformers import AutoTokenizer, AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer(["The full name of Donald is Donald"]*5, return_tensors="pt")
outputs = model.generate(**inputs, num_beams=5, num_return_sequences=1, do_sample=True)
print("\n".join(tokenizer.batch_decode(outputs, skip_special_tokens=True)))

If we set num_return_sequences > 1, we want all returned outputs to be guaranteed different, so it is preferable to behave like regular beam_search.
Please let me know if there is a reason for the current implementation.
Your contribution
The fix is only about 3 lines, and I can submit it if you'd like.
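For reference, one possible shape of that ~3-line change (only a sketch, not a tested patch: it assumes BeamSearchScorer's num_beam_hyps_to_keep argument, which regular beam search already uses to return multiple hypotheses per batch item):

```diff
 beam_scorer = BeamSearchScorer(
-    batch_size=batch_size * generation_config.num_return_sequences,
+    batch_size=batch_size,
     num_beams=generation_config.num_beams,
+    num_beam_hyps_to_keep=generation_config.num_return_sequences,
     device=inputs_tensor.device,
     length_penalty=generation_config.length_penalty,
     do_early_stopping=generation_config.early_stopping,
     max_length=generation_config.max_length,
 )
 # 13. interleave input_ids with `num_beams` additional sequences per batch
 input_ids, model_kwargs = self._expand_inputs_for_generation(
     input_ids=input_ids,
-    expand_size=generation_config.num_beams * generation_config.num_return_sequences,
+    expand_size=generation_config.num_beams,
     is_encoder_decoder=self.config.is_encoder_decoder,
     **model_kwargs,
 )
```

This keeps a single beam of num_beams hypotheses per input and asks the scorer to keep num_return_sequences of them, instead of interleaving num_return_sequences independent copies of the search.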