
[Generate] Add conditional generation for multimodal models#22424

Merged
younesbelkada merged 2 commits into huggingface:main from younesbelkada:generate-fix-cond-generation
Mar 29, 2023

Conversation

@younesbelkada
Contributor

@younesbelkada commented Mar 28, 2023

Motivation

Some multi-modal models (specifically, image-to-text models) can perform better when conditional text is passed. Concretely, the input_ids created by _prepare_decoder_input_ids_for_generation are concatenated with the input_ids passed along model_kwargs.
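The merge itself is small; here is a minimal pure-Python sketch of the logic (lists stand in for tensors, and merge_conditional_ids is an illustrative name, not the actual transformers internal, which uses torch.cat):

```python
# Illustrative sketch: when the model's main input is pixel_values and the user
# also passed input_ids, append the conditional text ids to the decoder start ids.
def merge_conditional_ids(decoder_start_ids, model_kwargs, model_input_name):
    # conditional generation for multi-modal models (mirrors the PR's branch)
    if "input_ids" in model_kwargs and model_input_name == "pixel_values":
        return decoder_start_ids + model_kwargs.pop("input_ids"), model_kwargs
    return decoder_start_ids, model_kwargs

start = [50256]              # decoder start / BOS id (illustrative value)
cond = [271, 2939, 286]      # ids for a prompt like "an image of" (illustrative)
merged, kwargs = merge_conditional_ids(start, {"input_ids": cond}, "pixel_values")
print(merged)  # [50256, 271, 2939, 286]
```

Note that the popped input_ids are removed from model_kwargs so they are not forwarded to the model a second time.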

This PR adds support for this feature to VisionEncoderDecoderModel; concretely, the following script should now run without any problem:

import torch
import requests
from PIL import Image
from transformers import ViTFeatureExtractor, AutoTokenizer, VisionEncoderDecoderModel


loc = "ydshieh/vit-gpt2-coco-en"

feature_extractor = ViTFeatureExtractor.from_pretrained(loc)
tokenizer = AutoTokenizer.from_pretrained(loc)
model = VisionEncoderDecoderModel.from_pretrained(loc)
model.eval()


def predict(image, text):
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    input_ids = tokenizer(text, return_tensors="pt").input_ids

    with torch.no_grad():
        output_ids = model.generate(pixel_values, input_ids=input_ids, max_length=16, num_beams=4, return_dict_in_generate=True).sequences

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]

    return preds


# We will verify our results on an image of cute cats
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
text = "an image of"
with Image.open(requests.get(url, stream=True).raw) as image:
    preds = predict(image, text)

print(preds)
# ['an image of two cats sleeping on a bed']

cc @gante

Related: #22423

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Mar 28, 2023

The documentation is not available anymore as the PR was closed or merged.

Contributor

@gante left a comment

LGTM, but a question for potential simplification! :D

Comment on lines +1292 to +1294
# conditional generation for multi-modal models.
if "input_ids" in model_kwargs and model_input_name == "pixel_values":
    input_ids = torch.cat([input_ids, model_kwargs.pop("input_ids")], dim=-1)
Contributor

@gante commented Mar 29, 2023

Uhmmm this seems to be the same logic as below (input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")), which is applied on encoder-decoder models.

Perhaps instead of adding these lines, we can remove the else: below?

EDIT: missed that the line below depends on inputs_tensor, which makes fusing hard. Don't worry about it :)

@younesbelkada requested a review from sgugger March 29, 2023 11:21
Collaborator

@sgugger left a comment

Thanks for the fix!

@younesbelkada
Contributor Author

As this is slightly experimental, I ran the BLIP slow tests, which also include the conditional generation tests, and they all pass. Will merge!

@younesbelkada merged commit 8252e24 into huggingface:main Mar 29, 2023
@younesbelkada deleted the generate-fix-cond-generation branch March 29, 2023 13:35
raghavanone pushed a commit to raghavanone/transformers that referenced this pull request Apr 5, 2023
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
@cramraj8

cramraj8 commented Jul 11, 2023

Hi @younesbelkada, I get the following error during the training stage when providing the decoder_input_ids argument. Does this modification only work for the inference stage, or for training too?

-> 3029 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
ValueError: Expected input batch_size (18) to match target batch_size (512).

I used batch_size=2 and beam_size=10.

For a batch of examples (during training or inference), do the input_ids have to be padded to the same shape even though each example's prefix can be a different length?
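On the padding question, a common convention (an assumption here, not something this PR specifies) is to left-pad variable-length prefixes to a common width, so that generation continues from the right edge of every row. A toy sketch:

```python
PAD = 0  # illustrative pad token id

def left_pad(batch, pad_id=PAD):
    # pad every sequence on the left so all rows share the longest length
    width = max(len(seq) for seq in batch)
    return [[pad_id] * (width - len(seq)) + seq for seq in batch]

print(left_pad([[271, 2939, 286], [257, 2495, 2415, 287]]))
# [[0, 271, 2939, 286], [257, 2495, 2415, 287]]
```

A real batched setup would also mask the pad positions via the attention mask.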

Does input_ids = tokenizer(text, return_tensors="pt").input_ids have to exclude special tokens by passing add_special_tokens=False?
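As a toy illustration of why add_special_tokens=False can matter (this uses a made-up tokenizer, not the real one): if generate() already supplies the decoder start token, a tokenizer that also prepends BOS would duplicate it in the merged sequence:

```python
BOS = 50256  # illustrative BOS / decoder-start id

def toy_tokenize(ids, add_special_tokens=True):
    # mimic a tokenizer that prepends BOS unless told otherwise
    return [BOS] + list(ids) if add_special_tokens else list(ids)

decoder_start = [BOS]
print(decoder_start + toy_tokenize([271, 2939, 286]))
# [50256, 50256, 271, 2939, 286]  <- BOS appears twice
print(decoder_start + toy_tokenize([271, 2939, 286], add_special_tokens=False))
# [50256, 271, 2939, 286]
```

Whether the duplication actually hurts depends on the model and tokenizer at hand.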

@cramraj8

@younesbelkada @gante @sgugger When I inspected the intermediate outputs during training, decoder_input_ids has shape [2, 8] and logits has shape [2, 8, 64002], where the batch size is 2 and the prefix length is 8. It looks like decoder_outputs = self.decoder() is not predicting anything beyond the prefix given in decoder_input_ids.
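For reference, a generic shape-check sketch (not specific to this model): sequence cross-entropy is typically computed after flattening, so the flattened logits and labels must agree in their first dimension, and a mismatch there produces exactly the kind of "Expected input batch_size (...) to match target batch_size (...)" error quoted above:

```python
def flattened_sizes(logits_shape, labels_shape):
    # logits: (batch, seq_len, vocab) -> (batch * seq_len, vocab)
    # labels: (batch, seq_len)        -> (batch * seq_len,)
    b, t, _vocab = logits_shape
    return b * t, labels_shape[0] * labels_shape[1]

print(flattened_sizes((2, 8, 64002), (2, 8)))  # (16, 16) -> shapes line up
```

When the two numbers differ, the logits cover a different number of positions than the labels, usually because the labels still include positions the decoder never produced logits for.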

@gante
Contributor

gante commented Jul 11, 2023

Hey @cramraj8 -- would you be able to open a new issue, containing a short self-contained script so we can reproduce it? :)
