[Generate] Add conditional generation for multimodal models #22424

younesbelkada merged 2 commits into huggingface:main

Conversation
The documentation is not available anymore as the PR was closed or merged.
gante left a comment:

LGTM, but a question for potential simplification! :D
    # conditional generation for multi-modal models.
    if "input_ids" in model_kwargs and model_input_name == "pixel_values":
        input_ids = torch.cat([input_ids, model_kwargs.pop("input_ids")], dim=-1)
Uhmmm, this seems to be the same logic as below (input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")), which is applied to encoder-decoder models.
Perhaps instead of adding these lines, we can remove the else: below?
EDIT: missed that the line below depends on inputs_tensor, which makes fusing hard. Don't worry about it :)
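To make the logic under discussion concrete, here is a minimal, self-contained sketch of the bookkeeping: when the main model input is pixel_values and the caller also passed input_ids, those input_ids are popped from model_kwargs and appended after the decoder start token. This is a toy stand-in using plain Python lists, not the real transformers code; the helper name, BOS id, and dummy token ids are all made up for illustration.

```python
BOS_TOKEN_ID = 0  # assumed decoder start token for this sketch

def prepare_decoder_input_ids(model_input_name, model_kwargs, batch_size):
    """Toy mimic of the decoder prompt preparation discussed above."""
    # Default decoder start: one BOS token per batch element.
    input_ids = [[BOS_TOKEN_ID] for _ in range(batch_size)]
    # Conditional generation for multi-modal models: if the caller passed
    # input_ids alongside pixel_values, pop them from model_kwargs and
    # concatenate them after the BOS token (torch.cat(..., dim=-1) in the PR).
    if "input_ids" in model_kwargs and model_input_name == "pixel_values":
        user_ids = model_kwargs.pop("input_ids")
        input_ids = [bos + row for bos, row in zip(input_ids, user_ids)]
    return input_ids

# A conditional text prefix encoded as dummy ids, one row per batch element:
kwargs = {"input_ids": [[101, 2003, 1997], [101, 2003, 1997]]}
prefix = prepare_decoder_input_ids("pixel_values", kwargs, batch_size=2)
# Each row now starts with BOS followed by the user-provided prefix,
# and "input_ids" has been consumed from kwargs.
```

Popping (rather than reading) the entry matters: it prevents the same input_ids from also being forwarded to the model as a regular keyword argument.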
As this is slightly experimental, I ran the BLIP slow tests, which also include the conditional generation tests, and they all pass. Will merge!
…gface#22424) * add conditional generation * add comments
Hi @younesbelkada, I get the following error during the training stage when providing the decoder_input_ids argument. Does this modification only work for the inference stage, or for training too? I used batch_size 2 and beam_size 10. For a batch of examples (during training or inference), does input_ids have to be the same padded shape even though each example's prefix can be a different length?
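On the shape part of the question above: batched token ids do have to form one rectangular tensor, so prefixes of different lengths are typically padded to a common width (for decoder prompts, usually on the left, so the real tokens sit adjacent to the newly generated ones). The snippet below is a minimal sketch of that padding step with plain lists, not the transformers implementation; the pad id and helper name are assumptions for illustration.

```python
PAD_TOKEN_ID = 0  # assumed pad id for this sketch

def left_pad(batch, pad_id=PAD_TOKEN_ID):
    """Pad every row on the left so the batch has one rectangular shape."""
    width = max(len(row) for row in batch)
    return [[pad_id] * (width - len(row)) + row for row in batch]

# Two prefixes of different lengths become one rectangular batch:
padded = left_pad([[5, 6], [7, 8, 9]])
```

When padding like this, the corresponding attention mask should mark the pad positions as 0 so they are ignored by the model.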
@younesbelkada @gante @sgugger When I inspected the intermediate outputs during the training,
Hey @cramraj8 -- would you be able to open a new issue, containing a short self-contained script so we can reproduce it? :)
Motivation
Some multi-modal models (precisely, image-to-text models) can perform better if conditional text is passed. This simply means that the input_ids created by _prepare_decoder_input_ids_for_generation are concatenated with the input_ids that are passed along model_kwargs.

This PR aims to add support for this feature for VisionEncoderDecoderModel; precisely, this script should now be able to run without any problem:

cc @gante

Related: #22423