InstructBLIP - FlanT5-XL model Int4/8 quantization broken #24884

@lukealexmiller

Description

System Info

  • transformers version: 4.32.0.dev0
  • Platform: Linux-4.14.314-238.539.amzn2.x86_64-x86_64-with-glibc2.31
  • Python version: 3.10.9
  • Huggingface_hub version: 0.16.4
  • Safetensors version: 0.3.1
  • Accelerate version: 0.22.0.dev0
  • Accelerate config:
    • compute_environment: LOCAL_MACHINE
    • distributed_type: MULTI_GPU
    • mixed_precision: no
    • use_cpu: False
    • num_processes: 4
    • machine_rank: 0
    • num_machines: 1
    • gpu_ids: all
    • rdzv_backend: static
    • same_network: True
    • main_training_function: main
    • downcast_bf16: no
    • tpu_use_cluster: False
    • tpu_use_sudo: False
  • PyTorch version (GPU?): 1.13.1+cu117 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@ArthurZucker @younesbelkada @NielsRogge

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Problem

When specifying load_in_8bit or load_in_4bit for Salesforce/instructblip-flan-t5-xl, I am able to load the model into GPU memory, but calling generate raises an error.

Steps to Reproduce:

torch.bfloat16 working version:

  1. Load the model into memory:

```python
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
import torch
from PIL import Image
import requests

device = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_NAME = "Salesforce/instructblip-flan-t5-xl"
# Load in bfloat16 - the dtype the T5 models were pretrained in (see https://github.com/salesforce/LAVIS/issues/418).
model = InstructBlipForConditionalGeneration.from_pretrained(MODEL_NAME, device_map="auto", torch_dtype=torch.bfloat16)

processor = InstructBlipProcessor.from_pretrained(MODEL_NAME)
```

  2. Run the example VQA:

```python
url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
prompt = "What is unusual about this image?"

# Cast to torch.bfloat16, otherwise we get an error.
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.bfloat16)

outputs = model.generate(
    **inputs,
    do_sample=False,
    num_beams=5,
    max_length=256,
    min_length=1,
    top_p=0.9,
    repetition_penalty=1.5,
    length_penalty=1.0,
    temperature=1,
)

generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
print(generated_text)
```

  3. Observe the generated text: "The image depicts a man ironing clothes on the back of a yellow van in the middle of a busy city street. The unusual aspect of the image is that the man is not wearing a shirt, which may indicate that he is a homeless person or an immigrant. In addition, there are several other vehicles in the background, including taxis, buses, and motorcycles."

load_in_8bit failing version:

  1. Load the model into memory:

```python
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
import torch
from PIL import Image
import requests

device = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_NAME = "Salesforce/instructblip-flan-t5-xl"
# Note: here we no longer specify `torch.bfloat16`.
model = InstructBlipForConditionalGeneration.from_pretrained(MODEL_NAME, device_map="auto", load_in_8bit=True)

processor = InstructBlipProcessor.from_pretrained(MODEL_NAME)
```

  2. Run the example VQA. Note we use the same input dtype as in the test code:

```python
url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
prompt = "What is unusual about this image?"

# Note: here we no longer cast to `torch.bfloat16`, but to `torch.float16`,
# as shown in the test code for Salesforce/instructblip-vicuna-7b.
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)

outputs = model.generate(
    **inputs,
    do_sample=False,
    num_beams=5,
    max_length=256,
    min_length=1,
    top_p=0.9,
    repetition_penalty=1.5,
    length_penalty=1.0,
    temperature=1,
)

generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
print(generated_text)
```
  3. Observe the error:

```python
RuntimeError                              Traceback (most recent call last)
Cell In[4], line 14
     11         if torch.is_floating_point(v):
     12             inputs[k] = v.to(torch.float16)
---> 14 outputs = model.generate(
     15     **inputs,
     16     do_sample=False,
     17     num_beams=5,
     18     max_length=256,
     19     min_length=1,
     20     top_p=0.9,
     21     repetition_penalty=1.5,
     22     length_penalty=1.0,
     23     temperature=1,
     24 )
     25 generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
     26 print(generated_text)

File /usr/lib/python3/dist-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File /usr/lib/python3/dist-packages/transformers/models/instructblip/modeling_instructblip.py:1522, in InstructBlipForConditionalGeneration.generate(self, pixel_values, qformer_input_ids, qformer_attention_mask, input_ids, attention_mask, **generate_kwargs)
   1520     qformer_attention_mask = torch.ones_like(qformer_input_ids)
   1521 qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
-> 1522 query_outputs = self.qformer(
   1523     input_ids=qformer_input_ids,
   1524     attention_mask=qformer_attention_mask,
   1525     query_embeds=query_tokens,
   1526     encoder_hidden_states=image_embeds,
   1527     encoder_attention_mask=image_attention_mask,
   1528     return_dict=True,
   1529 )
   1530 query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :]
   1532 language_model_inputs = self.language_projection(query_output)

File /usr/lib/python3/dist-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/lib/python3/dist-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File /usr/lib/python3/dist-packages/transformers/models/instructblip/modeling_instructblip.py:1169, in InstructBlipQFormerModel.forward(self, input_ids, attention_mask, position_ids, query_embeds, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
   1163 past_key_values_length = (
   1164     past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
   1165 )
   1167 query_length = query_embeds.shape[1] if query_embeds is not None else 0
-> 1169 embedding_output = self.embeddings(
   1170     input_ids=input_ids,
   1171     position_ids=position_ids,
   1172     query_embeds=query_embeds,
   1173     past_key_values_length=past_key_values_length,
   1174 )
   1176 input_shape = embedding_output.size()[:-1]
   1177 batch_size, seq_length = input_shape

File /usr/lib/python3/dist-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/lib/python3/dist-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File /usr/lib/python3/dist-packages/transformers/models/instructblip/modeling_instructblip.py:1041, in InstructBlipQFormerEmbeddings.forward(self, input_ids, position_ids, query_embeds, past_key_values_length)
   1038 else:
   1039     embeddings = query_embeds
-> 1041 embeddings = self.layernorm(embeddings)
   1042 embeddings = self.dropout(embeddings)
   1043 return embeddings

File /usr/lib/python3/dist-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/lib/python3/dist-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File /usr/lib/python3/dist-packages/torch/nn/modules/normalization.py:190, in LayerNorm.forward(self, input)
    189 def forward(self, input: Tensor) -> Tensor:
--> 190     return F.layer_norm(
    191         input, self.normalized_shape, self.weight, self.bias, self.eps)

File /usr/lib/python3/dist-packages/torch/nn/functional.py:2515, in layer_norm(input, normalized_shape, weight, bias, eps)
   2511 if has_torch_function_variadic(input, weight, bias):
   2512     return handle_torch_function(
   2513         layer_norm, (input, weight, bias), input, normalized_shape, weight=weight, bias=bias, eps=eps
   2514     )
-> 2515 return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)

RuntimeError: expected scalar type Float but found Half
```

I am unable to get either load_in_8bit or load_in_4bit to work; both fail with this error.

I have also tried changing the dtype cast when moving the processed inputs to the GPU, but that produces different errors.
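For reference, the kind of per-tensor dtype casting I experimented with looks like the sketch below; the helper name `cast_processor_outputs` is my own invention (not a transformers API), mirroring the `torch.is_floating_point` loop visible in the traceback above.

```python
import torch

def cast_processor_outputs(inputs, device, float_dtype):
    # Cast only the floating-point tensors (e.g. pixel_values) to
    # float_dtype; integer tensors such as input_ids and attention_mask
    # keep their original dtype and are only moved to the device.
    cast = {}
    for name, tensor in inputs.items():
        if torch.is_floating_point(tensor):
            cast[name] = tensor.to(device=device, dtype=float_dtype)
        else:
            cast[name] = tensor.to(device)
    return cast
```

This lets the text tensors stay integer while trying different float dtypes (float16, bfloat16, float32) for the image features, e.g. `inputs = cast_processor_outputs(processor(images=image, text=prompt, return_tensors="pt"), device, torch.float16)`.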

Expected behavior

I expect quantization to work, as it does with the Salesforce/instructblip-vicuna-7b model.

I am able to use the quantized google/flan-t5-xl text-generation model with the same setup, and I have run pip uninstall apex as described in #21391.
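One workaround I have not verified would be to exclude the Q-Former from quantization, so that its LayerNorm parameters and inputs stay in a matching dtype. A sketch using `BitsAndBytesConfig` (the `"qformer"` entry in `llm_int8_skip_modules` is my guess at the relevant module name, untested):

```python
from transformers import BitsAndBytesConfig, InstructBlipForConditionalGeneration

# Hypothetical config sketch: quantize the model to int8 but keep the
# Q-Former submodule in full precision via llm_int8_skip_modules.
quant_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["qformer"],
)

model = InstructBlipForConditionalGeneration.from_pretrained(
    "Salesforce/instructblip-flan-t5-xl",
    device_map="auto",
    quantization_config=quant_config,
)
```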
