Skip to content

Generate: move prepare_inputs_for_generation in encoder-decoder llms#34048

Merged
gante merged 6 commits intohuggingface:mainfrom
gante:encoder_decoder_prepare
Oct 11, 2024
Merged

Generate: move prepare_inputs_for_generation in encoder-decoder llms#34048
gante merged 6 commits intohuggingface:mainfrom
gante:encoder_decoder_prepare

Conversation

@gante
Copy link
Contributor

@gante gante commented Oct 9, 2024

What does this PR do?

Part of step 6 in #32685
Follow-up to #33870

This PR:

  • Adds a minor change to GenerationMixin.prepare_inputs_for_generation to use decoder_input_ids in encoder-decoder models
  • Deletes almost all prepare_inputs_for_generation in encoder-decoder llms 🔪 😎

@gante
Copy link
Contributor Author

gante commented Oct 9, 2024

@zucchini-nlp this PR may have a conflict with your encoder-decoder+compile PR 👀

Copy link
Member

@zucchini-nlp zucchini-nlp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I will update my PR when this one gets merged. Left a tiny question about Blip-2, overall LGTM as long as the tests don't complain

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it okay we're losing this? Seems like BLIP was forcefully passing this kwarg for later setting the cache?

O think we don't have tests for BlipText, neither for VLM part so we can't rely on tests for BLIP 😭 (I'll work on it soon, rn I'm working on Idefics models and BLIP will be next)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uhmm perhaps -- is_decoder=True is the default everywhere (in forward, in the config), but the user could force it to False. Going to revert

(I suspect this class is never used with is_decoder=True, but too late to fix that :D )

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, blip is a difficult case, better keep it overriden hehe

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧼 🧼 🧼 🧼 Very nice!

@gante gante force-pushed the encoder_decoder_prepare branch 2 times, most recently from ca46d3b to 40d6c34 Compare October 11, 2024 12:16
@gante gante force-pushed the encoder_decoder_prepare branch from 40d6c34 to 369b614 Compare October 11, 2024 13:44
@gante
Copy link
Contributor Author

gante commented Oct 11, 2024

Ran the following slow tests before merging:

  • Llama
  • BART
  • T5 (same failures as main)

@gante gante merged commit 37ac078 into huggingface:main Oct 11, 2024
@gante gante deleted the encoder_decoder_prepare branch October 11, 2024 15:11
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Oct 14, 2024
Don't assume that past_key_values is part of the model_kwargs.

This fix is similar to huggingface#2140 but for encoder-decoder models. It became
necessary after huggingface/transformers#34048
was merged into transformers.
BenjaminBossan added a commit to huggingface/peft that referenced this pull request Oct 14, 2024
Don't assume that past_key_values is part of the model_kwargs.

This fix is similar to #2140 but for encoder-decoder models. It became
necessary after huggingface/transformers#34048
was merged into transformers.
yaswanth19 pushed a commit to yaswanth19/peft that referenced this pull request Oct 20, 2024
Don't assume that past_key_values is part of the model_kwargs.

This fix is similar to huggingface#2140 but for encoder-decoder models. It became
necessary after huggingface/transformers#34048
was merged into transformers.
yaswanth19 pushed a commit to yaswanth19/peft that referenced this pull request Oct 20, 2024
Don't assume that past_key_values is part of the model_kwargs.

This fix is similar to huggingface#2140 but for encoder-decoder models. It became
necessary after huggingface/transformers#34048
was merged into transformers.
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Oct 22, 2024
Don't assume that past_key_values is part of the model_kwargs.

This fix is similar to huggingface#2140 but for encoder-decoder models. It became
necessary after huggingface/transformers#34048
was merged into transformers.
Copy link

@SabaPivot SabaPivot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please reply?
Much appreciated!

Comment on lines +3844 to +3875
def test_prepare_inputs_for_generation_encoder_decoder_llm(self):
"""
Same as `test_prepare_inputs_for_generation_decoder_llm` but for encoder-decoder models. Main difference: we
should look for `decoder_input_ids`, instead of `input_ids`.
"""
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
model = model.to(torch_device)

# 1. Sanity check: the model's `prepare_inputs_for_generation` comes from `GenerationMixin`
self.assertTrue("GenerationMixin" in str(model.prepare_inputs_for_generation))

# 2. If we pass input ids by themselves, we should get back the same input ids -- with the encoder-decoder key
decoder_input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device)
model_inputs = model.prepare_inputs_for_generation(decoder_input_ids)
self.assertTrue(torch.all(model_inputs["decoder_input_ids"] == decoder_input_ids))

# 3. If we pass the attention mask too, we will get back the attention mask. Encoder-decoder models usually
# don't use `position_ids`
decoder_attention_mask = torch.tensor([[1, 1, 1], [1, 1, 1]]).to(torch_device)
model_inputs = model.prepare_inputs_for_generation(
decoder_input_ids, decoder_attention_mask=decoder_attention_mask
)
self.assertTrue(torch.all(model_inputs["decoder_attention_mask"] == decoder_attention_mask))
self.assertTrue("position_ids" not in model_inputs)

# 4. `use_cache` (and other kwargs, like the encoder outputs) are forwarded
self.assertFalse("use_cache" in model_inputs) # From the previous input, there is no `use_cache`
model_inputs = model.prepare_inputs_for_generation(decoder_input_ids, use_cache=True, encoder_outputs="foo")
self.assertTrue(model_inputs["use_cache"] is True)
self.assertTrue(model_inputs["encoder_outputs"] == "foo")
# See the decoder-only test for more corner cases. The code is the same, so we don't repeat it here.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I add this to my
AutoAdapterModel
to generate in adapters using T5?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you mean the tests, you should not need to add it anywhere as it is ran only to test the correctness of new modifications.

In general it is advised to post question in the forum if it is not a bug or feature request

BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025
Don't assume that past_key_values is part of the model_kwargs.

This fix is similar to huggingface#2140 but for encoder-decoder models. It became
necessary after huggingface/transformers#34048
was merged into transformers.
Guy-Bilitski pushed a commit to Guy-Bilitski/UIOrthoLoRA that referenced this pull request Feb 5, 2026
Don't assume that past_key_values is part of the model_kwargs.

This fix is similar to #2140 but for encoder-decoder models. It became
necessary after huggingface/transformers#34048
was merged into transformers.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants