Generate tests: modality-agnostic input preparation #33685
gante merged 16 commits into huggingface:main
Conversation
```python
# in some models we subsample the sequence length in inner layers
if hasattr(self.model_tester, "get_subsampled_output_lengths"):
    seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
```
Some models were overwriting `_check_outputs` to apply subsampling on the sequence length. Since this was a single pattern shared across the overwrites, and overloading couldn't be applied here (it changes the internals), I've decided to move the pattern here.
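The pattern above can be sketched with a minimal, hypothetical tester (the class and the halving factor below are assumptions for illustration; only `get_subsampled_output_lengths` comes from the diff): a model that shrinks the sequence length in inner layers exposes the mapping, and the common test adjusts its expected shape only when the hook exists.

```python
# Hypothetical tester: a model that halves the sequence length in inner
# layers (e.g. via a stride-2 convolution) exposes the length mapping.
class ModelTesterWithSubsampling:
    seq_length = 16

    def get_subsampled_output_lengths(self, input_length):
        # assumption for this sketch: one stride-2 layer halves the length
        return input_length // 2


tester = ModelTesterWithSubsampling()
seq_length = tester.seq_length
# the common test only subsamples when the tester opts in via the hook
if hasattr(tester, "get_subsampled_output_lengths"):
    seq_length = tester.get_subsampled_output_lengths(seq_length)
print(seq_length)  # 8
```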
```diff
 input_mask = None
 if self.use_input_mask:
-    input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
+    input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
```
Common pattern: `input_mask`, which was then passed around as `attention_mask`, was `torch.float32` instead of `torch.long` 👀
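The dtype mismatch is easy to reproduce: `torch.ones` defaults to `torch.float32`, while `torch.ones_like(input_ids)` inherits the integer dtype of the token ids. A minimal sketch:

```python
import torch

# input_ids are integer token indices; torch.randint defaults to torch.int64
input_ids = torch.randint(0, 100, (2, 5))

# old pattern: torch.ones defaults to float32, so the mask is float32
old_mask = torch.tril(torch.ones(2, 5))
# new pattern: ones_like inherits the dtype of input_ids
new_mask = torch.tril(torch.ones_like(input_ids))

print(old_mask.dtype)  # torch.float32
print(new_mask.dtype)  # torch.int64
```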
```python
pad_token_id=99,
bos_token_id=99,
num_codebooks=4,
audio_channels=1,
```
`audio_channels=1` is the default in the config. In other words, this doesn't change the tests, but it allows us to quickly override the config when needed (see the overloaded test below).
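The value of making a default explicit in the tester can be shown with a small, purely illustrative sketch (the class name and parameter are placeholders, not the actual tester): once the parameter exists in `__init__`, an overloaded test can instantiate the tester with a different value without touching the common test body.

```python
# Illustrative sketch, not the real tester: exposing the config default as a
# tester parameter makes it overridable from an overloaded test.
class AudioModelTesterSketch:
    def __init__(self, audio_channels=1):
        self.audio_channels = audio_channels


default_tester = AudioModelTesterSketch()            # common tests: mono
stereo_tester = AudioModelTesterSketch(audio_channels=2)  # overloaded test: stereo
print(default_tester.audio_channels, stereo_tester.audio_channels)  # 1 2
```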
```python
lm_heads = model.get_output_embeddings()
self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))

def _get_input_ids_and_config(self, batch_size=2):
```
(all this deleted code corresponds to overwritten functions that no longer need to be overwritten)
zucchini-nlp
left a comment
Awesome! So much code cleaned up, thanks! 💓
Overall looks good to me, just a few questions for my general understanding. I see some VLMs are failing the CI. I remember skipping one of the beam search tests for VLMs earlier, so it's probably that. But let me know if you want me to look at it :)
tests/generation/test_utils.py (Outdated)
```diff
 for model_class in self.all_generative_model_classes:
-    config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+    config, inputs_dict = self.prepare_config_and_inputs_for_generate()
+    main_input = inputs_dict[self.input_name]
```
Wondering what happens if we use `model_cls.main_input_name`? AFAIR from a few months ago, there were some inconsistencies in how a model's main input is defined, and we could do another round of cleanup there, because `main_input_name` is also used in `generate()`. Maybe we can have a more generalized interface and testing suite?
I agree we should use the model's main_input_name. I will make the change and see what breaks 🤞
If all tests pass, I'll update it. Otherwise I'll add a TODO for us :)
Not only did it work, it also allowed us to remove the `input_name` attribute from all testers 💛
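The idea is simple to sketch (the dummy class below is an assumption for illustration; in `transformers`, `main_input_name` is a class attribute on models, e.g. `"input_ids"`, `"pixel_values"`, or `"input_features"`): the common test reads the key from the model class rather than from a per-tester attribute.

```python
# Sketch with a stand-in model class: fetch the main input from the
# prepared inputs dict via the class-level main_input_name.
inputs_dict = {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}


class DummyModelClass:
    # real transformers models define this as a class attribute
    main_input_name = "input_ids"


main_input = inputs_dict[DummyModelClass.main_input_name]
print(main_input)  # [[1, 2, 3]]
```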
```python
config.forced_eos_token_id = None
return config, input_ids, attention_mask, inputs_dict
original_sequence_length = self.model_tester.seq_length
self.model_tester.seq_length = 16
```
LysandreJik
left a comment
Very welcome and clean changes! Nice to see fewer tests being overwritten 👌
ArthurZucker
left a comment
Very welcome.
IMO we should have:
- a dict of common forward input names, with default values that range from 0 to the small model's vocab size
- a variant for VLMs
- a variant for image only
- a variant for audio
And then you just take them, using inspect to inspect the forward pass.
Now for testing the padding and such, of course you need something else, but this way people don't have to do this ever again:
transformers/tests/models/sam/test_modeling_sam.py, lines 448 to 458 at 1dba608
@ArthurZucker that is a cool idea, model-agnostic inputs, potentially dependent on modality (or perhaps simply looking at the signature of the forward pass?). I took note of it to explore after the current round of refactors, to avoid adding more parallel threads :D
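The suggestion above could be sketched like this (the helper name, the input dict, and the toy forward are all assumptions for illustration, not existing `transformers` code): keep one dict of common inputs per modality, then use `inspect.signature` to keep only the keys the model's forward pass actually accepts.

```python
import inspect

# Hypothetical dict of common text inputs shared across testers
COMMON_TEXT_INPUTS = {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}


def filter_inputs_for_forward(model_forward, candidate_inputs):
    """Keep only the candidate inputs accepted by the forward signature."""
    accepted = set(inspect.signature(model_forward).parameters)
    return {k: v for k, v in candidate_inputs.items() if k in accepted}


# toy forward pass that only accepts input_ids
def forward(input_ids=None):
    return input_ids


filtered = filter_inputs_for_forward(forward, COMMON_TEXT_INPUTS)
print(sorted(filtered))  # ['input_ids']
```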
What does this PR do?
Requirement for #33212
Follow-up to #33663
This PR rewrites the LLM-centric `_get_input_ids_and_config()`, the function that creates random model inputs for tests, into a modality-agnostic `prepare_config_and_inputs_for_generate()`. The rest of the diff consists in propagating the change. Highlights:
- Many `_get_input_ids_and_config()` overwrites were deleted as a result of the changes 🔪
- `_check_outputs` now receives the model's main input, as opposed to `input_ids`

In a follow-up PR: hunt tests that no longer need to be skipped/overwritten as a result of these changes 🎯
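The new flow described above can be sketched in a simplified form (the function body below is a hedged sketch of the shape of the interface, not the actual implementation in `tests/generation/test_utils.py`): a single modality-agnostic entry point returns a config plus a dict of inputs, and tests pull what they need from the dict instead of assuming text-only `(input_ids, attention_mask)` tuples.

```python
# Simplified sketch of the modality-agnostic interface: return a config and
# an inputs dict, rather than unpacking text-specific tensors.
def prepare_config_and_inputs_for_generate(batch_size=2):
    config = {"vocab_size": 100}  # stand-in for a real model config
    inputs_dict = {
        "input_ids": [[1, 2, 3]] * batch_size,
        "attention_mask": [[1, 1, 1]] * batch_size,
    }
    return config, inputs_dict


config, inputs_dict = prepare_config_and_inputs_for_generate()
print(len(inputs_dict["input_ids"]))  # 2
```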