
[Falcon] forward pass will fail if use_cache is automatically flipped to False  #26327

@yundai424

Description


System Info

  • transformers version: 4.33.2
  • Platform: Linux-5.15.111.1-rolling-lts-linkedin-x86_64-with-glibc2.17
  • Python version: 3.10.2
  • Huggingface_hub version: 0.14.1
  • Safetensors version: 0.3.1
  • Accelerate version: 0.21.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.0a0+gitf998869 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: 8 A100 GPUs
  • Using distributed or parallel set-up in script?: Nah just torchrun

Who can help?

@ArthurZucker and @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Full-parameter fine-tune Falcon 180B on any kind of task with activation checkpointing enabled (--gradient_checkpointing).

When calling model = transformers.AutoModelForCausalLM.from_pretrained(), don't set use_cache=False; leave it at its default.

This results in the use_cache flag being flipped to False. However, the presents tuple used for the cache is not reset to None, so this code branch, which should only be reached when use_cache=True, is still entered and hits the following error:

  File "/home/jobuser/.local/lib/python3.10/site-packages/accelerate/utils/operations.py", line 636, in forward
    return model_forward(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/accelerate/utils/operations.py", line 624, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 1031, in forward
    transformer_outputs = self.transformer(
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 952, in forward
    presents = self._convert_cache_to_standard_format(presents, batch_size)
  File "/home/jobuser/.local/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 705, in _convert_cache_to_standard_format
    batch_size_times_num_heads, kv_length, head_dim = past_key_value[0][0].shape
IndexError: tuple index out of range
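The failure mechanism can be sketched in isolation. This is a simplified stand-in for FalconModel.forward, not the real modeling_falcon.py code; the function names mirror the traceback but the bodies are stubs:

```python
# Standalone sketch of the failure mechanism (assumption: a simplified
# stand-in for FalconModel.forward, not the actual transformers code).

def _convert_cache_to_standard_format(past_key_value):
    # Mirrors the line in the traceback: with an empty presents tuple,
    # past_key_value[0] already raises IndexError.
    batch_size_times_num_heads, kv_length, head_dim = past_key_value[0][0].shape
    return past_key_value

def forward(use_cache=True, gradient_checkpointing=True, training=True):
    # presents is initialized to () because use_cache is still True here
    presents = () if use_cache else None

    for _ in range(2):  # stand-in for the layer loop
        if gradient_checkpointing and training and use_cache:
            # gradient checkpointing flips use_cache, but presents stays ()
            use_cache = False
        if use_cache:
            presents = presents + (("key", "value"),)

    # The guard tests presents, not use_cache; () is not None, so the
    # cache-conversion branch runs on an empty tuple and blows up.
    if presents is not None:
        presents = _convert_cache_to_standard_format(presents)
    return presents

try:
    forward()
except IndexError as exc:
    print("IndexError:", exc)  # IndexError: tuple index out of range
```

Because presents stays an empty tuple rather than None, the `presents is not None` guard passes and the conversion helper indexes into nothing.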

Expected behavior

When use_cache is flipped to False, the presents tuple should also be reset to None so that the cache-conversion branch is never entered.
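One possible shape of that fix can be sketched as follows (again a simplified stand-in, not an actual transformers patch): flip use_cache before presents is initialized, so presents becomes None whenever caching is disabled.

```python
# Sketch of the suggested fix (assumption: simplified stand-in for the
# real forward pass, not the actual patch).

def _convert_cache_to_standard_format(past_key_value):
    past_key_value[0]  # would raise IndexError on an empty tuple
    return past_key_value

def forward(use_cache=True, gradient_checkpointing=True, training=True):
    if gradient_checkpointing and training and use_cache:
        use_cache = False  # flipped before presents exists

    presents = () if use_cache else None  # now None under checkpointing

    for _ in range(2):  # stand-in for the layer loop
        if use_cache:
            presents = presents + (("key", "value"),)

    if presents is not None:
        # only reached when caching was genuinely enabled
        presents = _convert_cache_to_standard_format(presents)
    return presents

print(forward())                                    # None: no crash
print(len(forward(gradient_checkpointing=False)))   # 2: cache still works
```

With the flip hoisted above the initialization, presents and use_cache can no longer disagree, so the guard behaves as intended in both the checkpointing and the inference paths.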
