Generate tests: modality-agnostic input preparation#33685

Merged
gante merged 16 commits into huggingface:main from gante:get_input_ids_and_config on Oct 3, 2024
Conversation

@gante (Contributor) commented Sep 24, 2024

What does this PR do?

Requirement for #33212
Follow-up to #33663

This PR rewrites the LLM-centric `_get_input_ids_and_config()`, the function that creates random model inputs for tests, into a modality-agnostic `prepare_config_and_inputs_for_generate()`.

The rest of the diff consists of propagating the change. Highlights:

  1. most `_get_input_ids_and_config()` overwrites were deleted as a result of the changes 🔪
  2. most test `generate` calls receive a dictionary of inputs, as opposed to `input_ids`
  3. `_check_outputs` now receives the model's main input, as opposed to `input_ids`
  4. because of the changes above, a few test overwrites that needed to be updated could be deleted instead 🙏

In a follow-up PR: hunt tests that no longer need to be skipped/overwritten as a result of these changes 🎯
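The idea described above can be sketched as follows. This is a hypothetical, simplified illustration (the `DummyTester` class and list-based "tensors" are stand-ins; real testers return torch tensors and the actual helper lives in the transformers test suite), showing the shape of the change: the tester returns the config plus a full inputs dict rather than a `(config, input_ids, attention_mask)` tuple, so the same code path works for any modality.

```python
# Illustrative sketch only, not the real transformers implementation.
def prepare_config_and_inputs_for_generate(tester, batch_size=2):
    # Reuse the tester's common input preparation, then trim to `batch_size`
    # so generation tests stay fast.
    config, inputs_dict = tester.prepare_config_and_inputs_for_common()
    trimmed = {
        k: (v[:batch_size] if isinstance(v, list) else v)
        for k, v in inputs_dict.items()
    }
    return config, trimmed


class DummyTester:
    """Stand-in for a model tester; real testers build torch tensors."""

    def prepare_config_and_inputs_for_common(self):
        config = {"vocab_size": 10}
        inputs = {
            "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
            "attention_mask": [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
        }
        return config, inputs


config, inputs = prepare_config_and_inputs_for_generate(DummyTester())
print(sorted(inputs))            # ['attention_mask', 'input_ids']
print(len(inputs["input_ids"]))  # 2
```

A test can then pass `**inputs` straight to `generate`, with no modality-specific unpacking.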


Comment on lines +2015 to +2017
# in some models we subsample the sequence length in inner layers
if hasattr(self.model_tester, "get_subsampled_output_lengths"):
    seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
@gante (author):
Some models were overwriting `_check_outputs` to apply subsampling to the sequence length. Since this was the single pattern shared by those overwrites, and overloading couldn't be applied here (it changes the internals), I've decided to move the pattern here.

input_mask = None
if self.use_input_mask:
-    input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
+    input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
@gante commented Sep 25, 2024:

Common pattern: `input_mask`, which was then passed around as `attention_mask`, was a `torch.float32` instead of a `torch.long` 👀
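A minimal sketch of the dtype issue in question: `torch.ones(b, s)` defaults to `float32`, while `torch.ones_like(input_ids)` inherits the integer dtype of the token ids, which is what an attention mask should carry. The variable names below are illustrative, not taken from the diff.

```python
import torch

input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])  # dtype: torch.int64

old_mask = torch.tril(torch.ones(2, 3))            # float32 mask (the old pattern)
new_mask = torch.tril(torch.ones_like(input_ids))  # long mask (the fixed pattern)

print(old_mask.dtype)  # torch.float32
print(new_mask.dtype)  # torch.int64
```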

pad_token_id=99,
bos_token_id=99,
num_codebooks=4,
audio_channels=1,
@gante commented Sep 25, 2024:

`audio_channels=1` is the default in the config. In other words, this doesn't change the tests, but it allows us to quickly override the config when needed (see the overloaded test below).

lm_heads = model.get_output_embeddings()
self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))

def _get_input_ids_and_config(self, batch_size=2):
@gante (author):

(all this deleted code corresponds to overwritten functions that no longer need to be overwritten)

Member:

love it!

Member:

That's very nice

@gante gante marked this pull request as ready for review September 25, 2024 14:34
@zucchini-nlp (Member) left a comment:

Awesome! So much code cleaned up, thanks! 💓

Overall looks good to me, just a few questions for my general understanding. I see some VLMs are failing the CI. I remember skipping one of the beam search tests for VLMs earlier, so it's probably that. But let me know if you want me to look at it :)

for model_class in self.all_generative_model_classes:
-    config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
+    config, inputs_dict = self.prepare_config_and_inputs_for_generate()
+    main_input = inputs_dict[self.input_name]
Member:

Wondering what happens if we use `model_cls.main_input_name`? AFAIR from a few months ago, there were some inconsistencies in how a model's main input is defined, and we could do another round of cleanup on that, because `main_input_name` is also used in `generate()`. Maybe we can have a more generalized interface and testing suite?

@gante (author):

I agree we should use the model's main_input_name. I will make the change and see what breaks 🤞

If all tests pass, I'll update it. Otherwise I'll add a TODO for us :)

@gante (author):

Not only did it work, it also allowed us to remove the `input_name` attribute from all testers 💛
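The pattern being adopted can be sketched like this. `main_input_name` is a real class attribute on transformers models, but the tiny classes and input values below are hypothetical stand-ins, used only to show why a class-level attribute lets the test suite fetch the main input generically, without each tester declaring its own `input_name`.

```python
# Hypothetical model classes; real ones inherit from PreTrainedModel.
class TextModel:
    main_input_name = "input_ids"


class SpeechModel:
    main_input_name = "input_features"


def get_main_input(model_class, inputs_dict):
    # One generic lookup replaces per-tester `input_name` attributes.
    return inputs_dict[model_class.main_input_name]


inputs = {"input_ids": [1, 2, 3], "input_features": [0.1, 0.2]}
print(get_main_input(TextModel, inputs))    # [1, 2, 3]
print(get_main_input(SpeechModel, inputs))  # [0.1, 0.2]
```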


config.forced_eos_token_id = None
return config, input_ids, attention_mask, inputs_dict
original_sequence_length = self.model_tester.seq_length
self.model_tester.seq_length = 16
Member:

nice to see a cleaner way

@LysandreJik (Member) left a comment:

Very welcome and clean changes! Nice to see fewer tests being overwritten 👌

@ArthurZucker (Collaborator) left a comment:

Very welcome.

IMO we should have:

  • a dict of common forward input names, with default values that range from 0 to the small model's vocab size
  • a variant for VLMs
  • a variant for image-only models
  • a variant for audio

And then you just take them using inspect to inspect the forward pass.

Now for testing padding and the rest, of course you need something else, but this way people don't have to do this ever again:

def prepare_image():
    img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
    return raw_image


def prepare_dog_img():
    img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png"
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
    return raw_image
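The inspect-based suggestion above could look something like this. This is a toy sketch, not transformers code: the `COMMON_INPUTS` dict, helper, and forward functions are all hypothetical, illustrating how one shared dict of default inputs could be filtered down to whatever a given model's forward signature accepts.

```python
import inspect

# One shared dict of common default inputs (values are toy placeholders).
COMMON_INPUTS = {
    "input_ids": [[1, 2], [3, 4]],
    "attention_mask": [[1, 1], [1, 1]],
    "pixel_values": [[[0.0]]],
    "input_features": [[0.5]],
}


def select_inputs(forward_fn):
    # Keep only the inputs that the forward signature actually accepts.
    params = inspect.signature(forward_fn).parameters
    return {k: v for k, v in COMMON_INPUTS.items() if k in params}


# Toy stand-ins for model forward passes.
def text_forward(input_ids=None, attention_mask=None): ...
def vision_forward(pixel_values=None): ...


print(sorted(select_inputs(text_forward)))    # ['attention_mask', 'input_ids']
print(sorted(select_inputs(vision_forward)))  # ['pixel_values']
```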

@gante commented Oct 3, 2024:

@ArthurZucker that is a cool idea: model-agnostic inputs, potentially dependent on modality (or perhaps simply looking at the signature of the forward pass?)

I took note of it to explore after the current round of refactors, to avoid adding more parallel threads :D

@gante gante merged commit d29738f into huggingface:main Oct 3, 2024
@gante gante deleted the get_input_ids_and_config branch October 3, 2024 13:01
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024