
Generate: visit non-llm prepare_inputs_for_generation #34199

Merged
gante merged 5 commits into huggingface:main from gante:last_prepare
Oct 17, 2024

Conversation

@gante gante (Contributor) commented Oct 16, 2024

What does this PR do?

Closes #32685 🙌

This PR does a final pass over the remaining prepare_inputs_for_generation:

  • makes a few adjustments to the general function to handle trivial corner cases
  • removes the cases where the general function is equivalent
  • adds a comment to the functions that can't be removed, so we can quickly see a) that a general function exists and b) why it doesn't work for that model

👉 After this PR, let's aim to overwrite prepare_inputs_for_generation as few times as possible, so we can quickly roll out model-agnostic upgrades 🏎️ and minimize bugs 🐛
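The comment convention from the last bullet above can be sketched roughly like this (the model class and the reason are made up for illustration; the real comments live in each model's modeling file):

```python
# Hypothetical sketch of the "# Overwritten -- ..." comment convention.
# MadeUpModelForConditionalGeneration is not a real transformers class.
class MadeUpModelForConditionalGeneration:
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Overwritten -- (hypothetical reason) this model has mutually
        # exclusive inputs that the general
        # GenerationMixin.prepare_inputs_for_generation cannot arbitrate.
        return {"input_ids": input_ids, **kwargs}
```

The point of the convention is that anyone reading the override immediately knows a general function exists and why this model can't use it.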

@gante gante (Contributor, Author) commented Oct 16, 2024

@zucchini-nlp / @ylacombe: I've tagged you both so that Yoach can double-check the audio models, and Raushan the others 🤗

device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
@gante gante (Contributor, Author) commented Oct 16, 2024

Some models have extra kwargs. With **kwargs we can make the generalization in GenerationMixin.prepare_inputs_for_generation :)
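The pattern being discussed can be sketched as follows (a simplified, hypothetical stand-in for the real method, with plain lists instead of tensors): accepting **kwargs lets one shared prepare step forward model-specific extras untouched.

```python
# Simplified sketch (not the actual transformers implementation): a shared,
# model-agnostic prepare step that passes model-specific kwargs through.
def prepare_inputs_for_generation(input_ids, cache_position, **kwargs):
    model_inputs = {"input_ids": input_ids, "cache_position": cache_position}
    # Extra, model-specific kwargs (e.g. pixel_values) flow through unchanged,
    # so models with extra inputs don't need their own override.
    model_inputs.update(kwargs)
    return model_inputs
```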

@ArthurZucker ArthurZucker (Collaborator) left a comment

🧼 super clean!

@zucchini-nlp zucchini-nlp (Member) left a comment

Great clean up!

Comment on lines -447 to +452
-    dtype=self.get_output_embeddings().weight.dtype,
+    dtype=self.dtype,
Member

for my own understanding, is this needed for multi-gpu setting?

Contributor Author

This change is needed because Moshi doesn't have get_output_embeddings(), and creating get_output_embeddings() there would be a bit ambiguous.

self.dtype works just as well and is more versatile :)
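The difference can be illustrated with a toy class (purely hypothetical, not the transformers implementation): reading the dtype from the model itself works even when there is no output embedding layer to inspect.

```python
# Hypothetical sketch: why self.dtype is more robust than
# self.get_output_embeddings().weight.dtype. Strings stand in for torch dtypes.
class TinyModel:
    def __init__(self, param_dtype="float16", output_embeddings=None):
        self._param_dtype = param_dtype
        self._output_embeddings = output_embeddings  # may be None (e.g. Moshi)

    @property
    def dtype(self):
        # In transformers, PreTrainedModel.dtype is derived from the
        # parameters, so it exists for every model.
        return self._param_dtype

    def get_output_embeddings(self):
        return self._output_embeddings

model = TinyModel()
# model.get_output_embeddings().weight.dtype would fail here, because there
# is no output embedding layer to read from -- but model.dtype still works.
```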

Member

oh dtype, not device 😄

Comment on lines +1668 to +1670
# Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take
# precedence is moved to the model, we can remove this fn)

Member

I think most VLMs do something similar, since we pass pixel values only in the prefill stage. For Idefics it could also be a check on cache_position[0] == 0, as we don't support multi-turn dialogues. So I think we can find a way to generalize this for VLMs in a subsequent PR :)

About moving the logic to modeling, I think we want to discourage anyone from passing both

Contributor Author

Next I will be working on separating prefill from non-prefill.

Perhaps if I add a flag prefill: bool, we can sort most VLMs!
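The prefill check mentioned earlier in the thread can be sketched like this (hypothetical helper name, with a plain list standing in for the cache_position tensor): pixel values are only consumed on the first forward pass, when the cache positions start at 0.

```python
# Hypothetical sketch of the prefill check discussed above.
def is_prefill(cache_position):
    # cache_position holds the positions being processed in this step;
    # the prompt (prefill) pass is the only one that starts at position 0.
    return cache_position[0] == 0
```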

@ylacombe ylacombe (Contributor) left a comment

LGTM, thanks for this @gante!
Most of the audio models' prepare_inputs_for_generation were untouched, but for the ones that were, it looks like it'll work.

Special thanks for Moshi: the new prepare_inputs_for_generation covers every small edge case (no base model, no get_output_embeddings, post-processing) and is much cleaner to read now.

Are you planning to run slow tests for every model, btw?

@gante gante (Contributor, Author) commented Oct 17, 2024

Are you planning to run slow tests for every model, btw?

Yes :) Now that it's approved, I will run the slow tests before merging @ylacombe

gante and others added 2 commits October 17, 2024 14:13
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@gante gante (Contributor, Author) commented Oct 17, 2024

Ran slow tests on all models with significant changes, no regressions -- merging :)

@gante gante merged commit f51ac9e into huggingface:main Oct 17, 2024
@gante gante deleted the last_prepare branch October 17, 2024 15:53
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
…34199)

* tmp

* all visited

* test all

* Update src/transformers/models/moshi/modeling_moshi.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* delete another one :D

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

tracker: move prepare_inputs_for_generation into the generation mixin 🧹

4 participants