use_flash_attention_2=True for Llama2 breaks generation #26697

@markovalexander

Description

System Info

  • transformers version: 4.34.0
  • Platform: Linux-5.15.0-1042-oracle-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • Huggingface_hub version: 0.16.4
  • Safetensors version: 0.3.2
  • Accelerate version: 0.21.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

text models: @ArthurZucker and @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Loading Llama-2 with use_flash_attention_2=True completely breaks generation: the output no longer matches what the same checkpoint produces with the default attention implementation.

[Screenshot: generation output with use_flash_attention_2=True diverges from the default-attention output for the same prompt]
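
Since the screenshot isn't recoverable, here is a minimal sketch of the comparison it showed, assuming meta-llama/Llama-2-7b-hf, an illustrative prompt, and greedy decoding (none of these details are confirmed by the original report):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"   # assumed; the checkpoint in the screenshot is unknown
prompt = "The capital of France is"      # assumed illustrative prompt

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

# Baseline: default (eager) attention implementation.
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16
).to("cuda")
baseline = tokenizer.decode(
    model.generate(**inputs, max_new_tokens=32, do_sample=False)[0]
)

# Free the baseline model before loading the second copy.
del model
torch.cuda.empty_cache()

# Same checkpoint with Flash Attention 2 enabled (transformers 4.34 kwarg).
model_fa2 = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, use_flash_attention_2=True
).to("cuda")
fa2 = tokenizer.decode(
    model_fa2.generate(**inputs, max_new_tokens=32, do_sample=False)[0]
)

print("baseline:", baseline)
print("fa2:     ", fa2)  # on 4.34.0 this output diverges from the baseline
```

With greedy decoding the two outputs should be token-for-token identical, so any divergence points at the Flash Attention 2 code path rather than sampling noise.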

Expected behavior

Generations with and without use_flash_attention_2=True should match.
