Generate: remove most decoder-only LLMs prepare_inputs_for_generation #33870
gante merged 15 commits into huggingface:main

Conversation
Hey! 🤗 Thanks for your contribution to the transformers library! Before merging this pull request, the slow tests CI should be triggered. To enable this:
(For maintainers) The documentation for the slow tests CI on PRs is here.
src/transformers/generation/utils.py
Outdated
Not all models expect this input. We now inspect the signature to determine whether we need to generate it on the fly.
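A minimal sketch of this signature check, with a hypothetical standalone `forward` (the real check runs against the model's bound `forward` method):

```python
import inspect

# Hypothetical forward, standing in for a model's real `forward` method
def forward(input_ids, attention_mask=None, position_ids=None):
    return input_ids

# Only create `position_ids` on the fly if the model actually accepts them
accepts_position_ids = "position_ids" in inspect.signature(forward).parameters
print(accepts_position_ids)  # → True
```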
src/transformers/generation/utils.py
Outdated
These are moved to kwargs. We now forward kwargs to the model inputs :)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
akshit397a
left a comment
This is up to the mark and working efficiently.
zucchini-nlp
left a comment
Wow, so much code killed, thanks!
Just curious: does that mean blenderbot cannot generate from inputs embeds and it cannot be fixed? I see many models touched here didn't pass inputs embeds further, so that means after this PR all of them will support generation from embeddings. It would be interesting to see why this model failed.
> I see many models touched here didn't pass inputs embeds further, so that means after this PR all of them will support generation from embeddings.
Precisely! Many models will get this feature for free as part of these deletions 💛
> Just curious: does that mean blenderbot cannot generate from inputs embeds and it cannot be fixed?
No clue, I didn't dive deeper :) It failed the inputs_embeds tests -> I pasted this comment. I don't think these model/feature combos are worth the dive, so I left this low-information (but better than nothing) note.
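Generation from embeddings, the feature mentioned above, looks roughly like this in standard `generate` usage (a hedged sketch using distilgpt2, which appears elsewhere in this thread; not the PR's test code):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

inputs = tokenizer("Hello world", return_tensors="pt")
# Build the embeddings ourselves instead of passing token ids
inputs_embeds = model.get_input_embeddings()(inputs.input_ids)
out = model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=inputs.attention_mask,
    max_new_tokens=5,
)
# `out` contains only the newly generated token ids
```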
Actually the test was just flaky! I've added flakiness protection to the failing test and deleted a few more cases :)
I think this was marked flaky for VLMs in one of the other PRs
With this PR, it becomes a failure every time 👀 I have no idea why (didn't dive)
Super sad, I started diving a while ago and it seems related to paligemma's weird masking for prefix/suffix. I'll see if I can find time to spot the bug.
tests/test_modeling_common.py
Outdated
(this test calls generate)
```python
if (
    attention_mask is not None
    and kwargs.get("position_ids") is None
    and "position_ids" in set(inspect.signature(self.forward).parameters.keys())
```
Quick Q: how fast is this / is it slowing down generation?
- we can store the inspect result if needed!
It's not too bad, but it can be improved, yes. On my machine, this adds 0.024ms per generated token (small, but not negligible). If we cache the inspect.signature call, we reduce it by 100x.
We actually make several inspect.signature(forward) calls in generate and other parts of the codebase, so I think it makes sense to store the result as a cached model property (e.g. model.forward_signature). WDYT? If you agree, I'll open a follow-up PR with this change.
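A minimal sketch of the cached-property idea, where `forward_signature` is the name floated above and `TinyModel` is a hypothetical stand-in for a real model:

```python
import inspect
from functools import cached_property

class TinyModel:
    """Hypothetical stand-in for a transformers model."""

    def forward(self, input_ids, attention_mask=None, position_ids=None):
        return input_ids

    @cached_property
    def forward_signature(self):
        # Computed once on first access, then stored on the instance
        return inspect.signature(self.forward)

model = TinyModel()
keys = set(model.forward_signature.parameters.keys())
print("position_ids" in keys)  # → True
```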
For completeness, script to measure the impact of caching this call:
```python
import time
import inspect
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("distilgpt2")

# Fresh inspect
all_times = []
for _ in range(1000):
    start = time.time()
    "position_ids" in set(inspect.signature(model.forward).parameters.keys())
    all_times.append(time.time() - start)
print(sum(all_times) / len(all_times))

# Cached inspect
signature_keys = set(inspect.signature(model.forward).parameters.keys())
all_times = []
for _ in range(1000):
    start = time.time()
    "position_ids" in signature_keys
    all_times.append(time.time() - start)
print(sum(all_times) / len(all_times))
```

```python
model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)
model_inputs["inputs_embeds"] = None
```

```python
# 4. Create missing `position_ids` on the fly
):
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    kwargs["position_ids"] = position_ids  # placed in kwargs for further processing (see below)
```
Seen in other PRs that it needed to be sliced to seq_length, no? `-seq_len:`
Yes, slicing happens in the code block after this one. That code block abstracts slicing to other input names (e.g. token_type_ids needs to be sliced exactly like position_ids -- and we can add more to this list as needed 🤗 )
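To illustrate the on-the-fly `position_ids` creation and the subsequent slicing, a toy example with hand-built tensors (assumes `torch`, as used elsewhere in this thread; not the actual `generate` internals):

```python
import torch

# Left-padded batch: two padding tokens, then three real tokens
attention_mask = torch.tensor([[0, 0, 1, 1, 1]])

# Create position_ids on the fly, as in the reviewed hunk
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)  # → tensor([[1, 1, 0, 1, 2]])

# With a cache, only the last (unprocessed) positions are kept
seq_len = 1  # e.g. one new token during decoding
print(position_ids[:, -seq_len:])  # → tensor([[2]])
```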
```python
for key, value in kwargs.items():
    if key not in model_inputs:
        model_inputs[key] = value
```
not sure this is super efficient TBH!
Its run time is negligible, even if kwargs contains a handful of entries (usually it will only contain one or two). At most 0.001 ms per call :P
On the plus side, this code block will allow us to generalize this function to VLMs 😉 I think that's worth the super small cost.
```python
import time
import torch

all_times = []
for _ in range(1000):
    model_inputs = {str(i): i for i in range(10)}
    kwargs = {'a': 1, 'b': 2, 'c': torch.zeros((100, 100)), "0": 12, "1": 3546}
    start = time.time()
    for key, value in kwargs.items():
        if key not in model_inputs:
            model_inputs[key] = value
    all_times.append(time.time() - start)
print(sum(all_times) / len(all_times))
```
ArthurZucker
left a comment
Okay, good for me; let's fix the generate tests if related.
Before merging, ran locally:
The error in PEFT is occurring after this transformers change: huggingface/transformers#33870. Now, in our tests, some model_kwargs no longer necessarily contain past_key_values, resulting in a KeyError. We now account for this possibility. Affected models were opt and gpt2.
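A hedged sketch of the kind of defensive access described (hypothetical snippet; the actual PEFT fix may differ):

```python
# Before: assumed past_key_values was always present in model_kwargs
# cache = model_kwargs["past_key_values"]  # KeyError after transformers#33870

# After: account for model_kwargs that no longer contain past_key_values
model_kwargs = {"input_ids": [1, 2, 3]}  # e.g. for opt/gpt2 after the change
cache = model_kwargs.get("past_key_values")
print(cache)  # → None
```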
What does this PR do?
Part of step 6 in #32685
Follow-up to #33677
This PR:
- Generalizes `GenerationMixin.prepare_inputs_for_generation` so as to handle models without the `Cache` refactor, prepare `token_type_ids`, and forward arbitrary kwargs

✅ slow tests were run on `llama` and `gpt2`
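For intuition, a simplified pure-Python sketch of what a generalized `prepare_inputs_for_generation` does (a toy `TinyModel` with list inputs; the real implementation operates on tensors and also handles cache slicing):

```python
import inspect

class TinyModel:
    """Hypothetical stand-in for a decoder-only model."""

    def forward(self, input_ids, attention_mask=None, position_ids=None,
                past_key_values=None):
        return input_ids

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      attention_mask=None, **kwargs):
        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
        # Create missing `position_ids` on the fly, but only if `forward` accepts them
        if (
            attention_mask is not None
            and kwargs.get("position_ids") is None
            and "position_ids" in inspect.signature(self.forward).parameters
        ):
            # toy list-based equivalent of `attention_mask.cumsum(-1) - 1`
            kwargs["position_ids"] = [sum(attention_mask[: i + 1]) - 1
                                      for i in range(len(attention_mask))]
        # Forward arbitrary remaining kwargs to the model inputs
        for key, value in kwargs.items():
            if key not in model_inputs:
                model_inputs[key] = value
        return model_inputs

model = TinyModel()
inputs = model.prepare_inputs_for_generation(
    [5, 6, 7], attention_mask=[1, 1, 1], token_type_ids=[0, 0, 0]
)
print(inputs["position_ids"])    # → [0, 1, 2]
print(inputs["token_type_ids"])  # → [0, 0, 0]
```

The signature check is what lets the same method serve models with and without `position_ids`, and the final loop is what forwards extras such as `token_type_ids` unchanged.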