
Flash Attention 2 broken during batched inference #34824

@pspdada

Description

System Info

  • transformers version: 4.46.2
  • Platform: Linux-5.15.0-120-generic-x86_64-with-glibc2.35
  • Python version: 3.10.15
  • Huggingface_hub version: 0.26.2
  • Safetensors version: 0.4.5
  • Accelerate version: 1.1.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: NVIDIA A100-PCIE-40GB

Who can help?

@zucchini-nlp

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I am using the latest vllm==0.6.4.post1 and flash-attn==2.7.0.post2, installed via pip install flash-attn --no-build-isolation.
When I use LLaVA for batch inference, enabling flash_attention_2 makes the results particularly strange.
The code follows the Hugging Face docs: https://huggingface.co/docs/transformers/main/en/model_doc/llava#batched-inference

import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Load the model in half-precision
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    cache_dir="/root/llm-project/utils/models/hub",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",
    device_map="auto")
processor = AutoProcessor.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    cache_dir="/root/llm-project/utils/models/hub",
)
image1 = Image.open("/root/llm-project/LVLM/eval/Extended_CHAIR/images/chair-500/000000006763.jpg")
image2 = Image.open("./demo/material/1.jpg")

# Prepare a batch of two prompts
conversation_1 = [
    {
        "role": "user",
        "content": [
            {
                "type": "image"
            },
            {
                "type": "text",
                "text": "What is shown in this image? Please tell me.",
            },
        ],
    },
]

conversation_2 = [
    {
        "role": "user",
        "content": [
            {
                "type": "image"
            },
            {
                "type": "text",
                "text": "Describe this image."
            },
        ],
    },
]

prompt_1 = processor.apply_chat_template(conversation_1, add_generation_prompt=True)
prompt_2 = processor.apply_chat_template(conversation_2, add_generation_prompt=True)
prompts = [prompt_1, prompt_2]

# We can simply feed images in the order they have to be used in the text prompt
inputs = processor(
    images=[image1, image2], text=prompts, padding=True, return_tensors="pt").to(model.device, torch.float16)

# Generate
generate_ids = model.generate(**inputs, max_new_tokens=200)
out = processor.batch_decode(generate_ids, skip_special_tokens=True)
print(out)

The output is:

['USER:  \nWhat is shown in this image? Please tell me. ASSISTANT: The image shows a man and a woman standing close to each other, posing for a picture. The man is wearing a tie, and they are both smiling for the camera. The scene takes place in a restaurant, as there are dining tables and chairs visible in the background.', 
'USER:  \nDescribe this image. ASSISTANT: The, the image, the image, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the, the,']

If I remove attn_implementation="flash_attention_2", everything works well:

['USER:  \nWhat is shown in this image? Please tell me. ASSISTANT: The image shows a man and a woman standing close to each other, posing for a picture. The man is wearing a tie, and they are both smiling for the camera. The scene takes place in a restaurant, as there are dining tables and chairs visible in the background.', 
'USER:  \nDescribe this image. ASSISTANT: The image features a group of birds perched on a tree branch. There are five birds in total, with some sitting closer to the front of the branch and others further back. The birds are of various sizes and colors, creating a diverse and lively scene. The birds appear to be engaged in conversation or simply enjoying their time together on the branch.']

Passing padding_side="left" when initializing the processor does not make any difference.
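For context on why padding side matters here: decoder-only generation expects left padding, so that every sequence in the batch ends at the same position and newly generated tokens attach directly after real content. A toy sketch (no model needed; pad_id and the token ids below are made up for illustration) of what the two padding sides produce:

```python
def pad_batch(seqs, pad_id=0, side="left"):
    """Pad variable-length token-id sequences to equal length and build attention masks."""
    max_len = max(len(s) for s in seqs)
    ids, masks = [], []
    for s in seqs:
        pad = [pad_id] * (max_len - len(s))
        if side == "left":
            # Real tokens end at the last position in every row.
            ids.append(pad + s)
            masks.append([0] * len(pad) + [1] * len(s))
        else:
            # Right padding leaves pad tokens between the prompt and any generated tokens.
            ids.append(s + pad)
            masks.append([1] * len(s) + [0] * len(pad))
    return ids, masks

ids, masks = pad_batch([[5, 6, 7], [8, 9]], side="left")
# ids   -> [[5, 6, 7], [0, 8, 9]]
# masks -> [[1, 1, 1], [0, 1, 1]]
```

With left padding the shorter prompt is [0, 8, 9], so generation continues right after token 9; with right padding it would be [8, 9, 0] and the model would generate after a pad token. That is why padding_side="left" is the usual recommendation for batched generation, although in this bug it did not help.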

Expected behavior

Batched inference with flash_attention_2 enabled should produce the same correct outputs as the default attention implementation.
