
T5/Flan-T5 text generation with load_in_8bit=True gives error expected scalar type Float but found Half #21391

@steve-marmalade

Description

System Info

  • transformers version: 4.27.0.dev0
  • Platform: Linux-5.10.147+-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • Huggingface_hub version: 0.12.0
  • PyTorch version (GPU?): 1.14.0a0+410ce96 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

@younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Start a container with the latest NVIDIA PyTorch Docker Image and an A100 GPU
  2. Install the latest transformers from this github repo
  3. Run the snippet from the official example:

from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base", device_map="auto", load_in_8bit=True)

input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0]))

Throws

RuntimeError                              Traceback (most recent call last)
Cell In[23], line 9
      6 input_text = "translate English to German: How old are you?"
      7 input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
----> 9 outputs = model.generate(input_ids)
     10 print(tokenizer.decode(outputs[0]))

File /usr/local/lib/python3.8/dist-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File /usr/local/lib/python3.8/dist-packages/transformers/generation/utils.py:1255, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, **kwargs)
   1247         logger.warning(
   1248             "A decoder-only architecture is being used, but right-padding was detected! For correct "
   1249             "generation results, please set `padding_side='left'` when initializing the tokenizer."
   1250         )
   1252 if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
   1253     # if model is encoder decoder encoder_outputs are created
   1254     # and added to `model_kwargs`
-> 1255     model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
   1256         inputs_tensor, model_kwargs, model_input_name
   1257     )
   1259 # 5. Prepare `input_ids` which will be used for auto-regressive generation
   1260 if self.config.is_encoder_decoder:

File /usr/local/lib/python3.8/dist-packages/transformers/generation/utils.py:617, in GenerationMixin._prepare_encoder_decoder_kwargs_for_generation(self, inputs_tensor, model_kwargs, model_input_name)
    615 encoder_kwargs["return_dict"] = True
    616 encoder_kwargs[model_input_name] = inputs_tensor
--> 617 model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
    619 return model_kwargs

File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1423, in Module._call_impl(self, *input, **kwargs)
   1418 # If we don't have any hooks, we want to skip the rest of the logic in
   1419 # this function, and just call forward.
   1420 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1421         or _global_backward_pre_hooks or _global_backward_hooks
   1422         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1423     return forward_call(*input, **kwargs)
   1424 # Do not call functions when jit is used
   1425 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/local/lib/python3.8/dist-packages/accelerate/hooks.py:158, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    156         output = old_forward(*args, **kwargs)
    157 else:
--> 158     output = old_forward(*args, **kwargs)
    159 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.8/dist-packages/transformers/models/t5/modeling_t5.py:1055, in T5Stack.forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
   1042     layer_outputs = checkpoint(
   1043         create_custom_forward(layer_module),
   1044         hidden_states,
   (...)
   1052         None,  # past_key_value is always None with gradient checkpointing
   1053     )
   1054 else:
-> 1055     layer_outputs = layer_module(
   1056         hidden_states,
   1057         attention_mask=extended_attention_mask,
   1058         position_bias=position_bias,
   1059         encoder_hidden_states=encoder_hidden_states,
   1060         encoder_attention_mask=encoder_extended_attention_mask,
   1061         encoder_decoder_position_bias=encoder_decoder_position_bias,
   1062         layer_head_mask=layer_head_mask,
   1063         cross_attn_layer_head_mask=cross_attn_layer_head_mask,
   1064         past_key_value=past_key_value,
   1065         use_cache=use_cache,
   1066         output_attentions=output_attentions,
   1067     )
   1069 # layer_outputs is a tuple with:
   1070 # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
   1071 if use_cache is False:

File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1423, in Module._call_impl(self, *input, **kwargs)
   1418 # If we don't have any hooks, we want to skip the rest of the logic in
   1419 # this function, and just call forward.
   1420 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1421         or _global_backward_pre_hooks or _global_backward_hooks
   1422         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1423     return forward_call(*input, **kwargs)
   1424 # Do not call functions when jit is used
   1425 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/local/lib/python3.8/dist-packages/accelerate/hooks.py:158, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    156         output = old_forward(*args, **kwargs)
    157 else:
--> 158     output = old_forward(*args, **kwargs)
    159 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.8/dist-packages/transformers/models/t5/modeling_t5.py:687, in T5Block.forward(self, hidden_states, attention_mask, position_bias, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias, layer_head_mask, cross_attn_layer_head_mask, past_key_value, use_cache, output_attentions, return_dict)
    684 else:
    685     self_attn_past_key_value, cross_attn_past_key_value = None, None
--> 687 self_attention_outputs = self.layer[0](
    688     hidden_states,
    689     attention_mask=attention_mask,
    690     position_bias=position_bias,
    691     layer_head_mask=layer_head_mask,
    692     past_key_value=self_attn_past_key_value,
    693     use_cache=use_cache,
    694     output_attentions=output_attentions,
    695 )
    696 hidden_states, present_key_value_state = self_attention_outputs[:2]
    697 attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1423, in Module._call_impl(self, *input, **kwargs)
   1418 # If we don't have any hooks, we want to skip the rest of the logic in
   1419 # this function, and just call forward.
   1420 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1421         or _global_backward_pre_hooks or _global_backward_hooks
   1422         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1423     return forward_call(*input, **kwargs)
   1424 # Do not call functions when jit is used
   1425 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/local/lib/python3.8/dist-packages/accelerate/hooks.py:158, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    156         output = old_forward(*args, **kwargs)
    157 else:
--> 158     output = old_forward(*args, **kwargs)
    159 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.8/dist-packages/transformers/models/t5/modeling_t5.py:592, in T5LayerSelfAttention.forward(self, hidden_states, attention_mask, position_bias, layer_head_mask, past_key_value, use_cache, output_attentions)
    582 def forward(
    583     self,
    584     hidden_states,
   (...)
    590     output_attentions=False,
    591 ):
--> 592     normed_hidden_states = self.layer_norm(hidden_states)
    593     attention_output = self.SelfAttention(
    594         normed_hidden_states,
    595         mask=attention_mask,
   (...)
    600         output_attentions=output_attentions,
    601     )
    602     hidden_states = hidden_states + self.dropout(attention_output[0])

File /usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py:1423, in Module._call_impl(self, *input, **kwargs)
   1418 # If we don't have any hooks, we want to skip the rest of the logic in
   1419 # this function, and just call forward.
   1420 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1421         or _global_backward_pre_hooks or _global_backward_hooks
   1422         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1423     return forward_call(*input, **kwargs)
   1424 # Do not call functions when jit is used
   1425 full_backward_hooks, non_full_backward_hooks = [], []

File /usr/local/lib/python3.8/dist-packages/accelerate/hooks.py:158, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    156         output = old_forward(*args, **kwargs)
    157 else:
--> 158     output = old_forward(*args, **kwargs)
    159 return module._hf_hook.post_forward(module, output)

File /usr/local/lib/python3.8/dist-packages/apex/normalization/fused_layer_norm.py:386, in FusedRMSNorm.forward(self, input)
    383     return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)
    385 if self.elementwise_affine:
--> 386     return fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps)
    387 else:
    388     return fused_rms_norm(input, self.normalized_shape, self.eps)

File /usr/local/lib/python3.8/dist-packages/apex/normalization/fused_layer_norm.py:189, in fused_rms_norm_affine(input, weight, normalized_shape, eps)
    187 args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps)
    188 with torch.cuda.amp.autocast(enabled=False):
--> 189     return FusedRMSNormAffineFunction.apply(*args)

File /usr/local/lib/python3.8/dist-packages/apex/normalization/fused_layer_norm.py:69, in FusedRMSNormAffineFunction.forward(ctx, input, weight, normalized_shape, eps)
     67 input_ = input.contiguous()
     68 weight_ = weight.contiguous()
---> 69 output, invvar = fused_layer_norm_cuda.rms_forward_affine(
     70     input_, ctx.normalized_shape, weight_, ctx.eps)
     71 ctx.save_for_backward(input_, weight_, invvar)
     72 return output

RuntimeError: expected scalar type Float but found Half
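
For what it's worth, the last frames are inside apex.normalization.fused_layer_norm rather than the transformers T5LayerNorm, so the container's apex install appears to have replaced the stock T5 layer norm. The following sketch (reusing the model object from the snippet above; the attribute path follows the T5 module layout shown in the traceback) prints which implementation is actually in use and which dtypes meet inside it:

layer_norm = model.encoder.block[0].layer[0].layer_norm
print(type(layer_norm))              # apex FusedRMSNorm here vs. transformers T5LayerNorm
print(layer_norm.weight.dtype)       # dtype of the layer-norm weight
print(model.encoder.embed_tokens.weight.dtype)  # embedding dtype, a proxy for the hidden states fed in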

Expected behavior

The model should generate a German translation of the input prompt.
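
One thing that might sidestep the mismatch, purely as a speculative sketch and not a confirmed fix: the apex code in the traceback casts its arguments with _cast_if_autocast_enabled before turning autocast off, so running generation under autocast could make the fused kernel see a single dtype:

import torch

with torch.cuda.amp.autocast(dtype=torch.float16):
    outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0]))

Loading the model in an environment where apex is not importable (so transformers falls back to its own T5LayerNorm) would be another way to test whether the fused RMSNorm path is the culprit.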
