
[Cache] rename dtype attribute 🚨 🚨 #37044

Merged
ArthurZucker merged 5 commits into huggingface:main from gante:rm_dtype_attr_in_cache
Mar 28, 2025
Conversation

@gante
Contributor

@gante gante commented Mar 27, 2025

Fixes #36938
Fixes #36814
[fine-tuning gemma3 or other models with a non-default cache]

🚨 Breaking: renaming of a public attribute in a public class.

accelerate sensibly detects whether a given object is a tensor or tensor-like through its type or, alternatively, through the existence of a `dtype` attribute. Our StaticCache and related objects accept `dtype` at init time and store it as an attribute under the same name. Because of this, accelerate may treat our caches as tensors, leading to downstream problems such as the issues above.

Since self.dtype is only used to initialize tensors, renaming it shouldn't be too breaking 🤞
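The duck-typing clash described above can be sketched in a few lines. This is NOT accelerate's actual code, and `_dtype` is only a hypothetical stand-in for the renamed attribute — the point is just that any object exposing a public `dtype` attribute can trip an attribute-based tensor check:

```python
# Minimal sketch of the duck-typing problem (illustrative names only).

class OldStaticCache:
    """Stands in for the pre-PR StaticCache, which stored `dtype` publicly."""
    def __init__(self, dtype="float32"):
        self.dtype = dtype  # the attribute an attribute-based heuristic keys on

class RenamedStaticCache:
    """After a rename, the public attribute name is gone."""
    def __init__(self, dtype="float32"):
        self._dtype = dtype  # hypothetical renamed attribute

def looks_like_tensor(obj) -> bool:
    """Simplified stand-in for a tensor-or-tensor-like heuristic."""
    return hasattr(obj, "dtype")

print(looks_like_tensor(OldStaticCache()))      # True  -> cache misclassified
print(looks_like_tensor(RenamedStaticCache()))  # False -> treated as a plain object
```

With the rename, the heuristic falls through to type-based checks and the cache is no longer mistaken for a tensor.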


Code for reproduction
from datasets import load_dataset
from PIL import Image

# System message for the assistant
system_message = "You are an expert product description writer for Amazon."

# User prompt that combines the user query and the schema
user_prompt = """Create a Short Product description based on the provided <PRODUCT> and <CATEGORY> and image.
Only return description. The description should be SEO optimized and for a better mobile search experience.

<PRODUCT>
{product}
</PRODUCT>

<CATEGORY>
{category}
</CATEGORY>
"""

# Convert dataset to OAI messages
def format_data(sample):
  return {
      "messages": [
          {
              "role": "system",
              "content": [{"type": "text", "text": system_message}],
          },
          {
              "role": "user",
              "content": [
                  {
                      "type": "text",
                      "text": user_prompt.format(
                          product=sample["Product Name"],
                          category=sample["Category"],
                      ),
                  },
                  {
                      "type": "image",
                      "image": sample["image"],
                  },
              ],
          },
          {
              "role": "assistant",
              "content": [{"type": "text", "text": sample["description"]}],
          },
      ],
  }

def process_vision_info(messages: list[dict]) -> list[Image.Image]:
  image_inputs = []
  # Iterate through each conversation
  for msg in messages:
      # Get content (ensure it's a list)
      content = msg.get("content", [])
      if not isinstance(content, list):
          content = [content]

      # Check each content element for images
      for element in content:
          if isinstance(element, dict) and (
              "image" in element or element.get("type") == "image"
          ):
              # Get the image and convert to RGB
              if "image" in element:
                  image = element["image"]
              else:
                  image = element
              image_inputs.append(image.convert("RGB"))
  return image_inputs

# Load dataset from the hub
dataset = load_dataset("philschmid/amazon-product-descriptions-vlm", split="train")
dataset = dataset.select(range(2))

# Convert dataset to OAI messages
# Need a list comprehension to keep the PIL.Image type; .map converts images to bytes
dataset = [format_data(sample) for sample in dataset]


import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

# Hugging Face model id
model_id = "google/gemma-3-4b-pt" # or `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`

# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] < 8:
  raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")

# Define model init arguments
model_kwargs = dict(
  attn_implementation="flash_attention_2", # Use "flash_attention_2" when running on Ampere or newer GPU
  torch_dtype=torch.bfloat16, # What torch dtype to use, defaults to auto
  device_map="auto", # Let torch decide how to load the model
)

# Load model and tokenizer
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")

from transformers import TrainingArguments

args = TrainingArguments(
  num_train_epochs=1,
  remove_unused_columns=False,
  per_device_train_batch_size=1,
  per_device_eval_batch_size=1,
  bf16=True,
  output_dir="./output",
  eval_strategy="epoch",
  report_to="none",
)

# Create a data collator to encode text and image pairs
def collate_fn(examples):
  texts = []
  images = []
  for example in examples:
      image_inputs = process_vision_info(example["messages"])
      text = processor.apply_chat_template(
          example["messages"], add_generation_prompt=False, tokenize=False
      )
      texts.append(text.strip())
      images.append(image_inputs)

  # Tokenize the texts and process the images
  batch = processor(text=texts, images=images, return_tensors="pt", padding="max_length", max_length=512, truncation=True)

  # The labels are the input_ids, and we mask the padding tokens and image tokens in the loss computation
  labels = batch["input_ids"].clone()

  # Mask image tokens
  image_token_id = [
      processor.tokenizer.convert_tokens_to_ids(
          processor.tokenizer.special_tokens_map["boi_token"]
      )
  ]
  # Mask tokens for not being used in the loss computation
  labels[labels == processor.tokenizer.pad_token_id] = -100
  labels[labels == image_token_id] = -100
  labels[labels == 262144] = -100

  batch["labels"] = labels
  return batch

from transformers import Trainer

trainer = Trainer(
  model=model,
  args=args,
  train_dataset=dataset,
  eval_dataset=dataset,
  processing_class=processor,
  data_collator=collate_fn,
)

# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()

@github-actions github-actions bot marked this pull request as draft March 27, 2025 14:10
@github-actions

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@gante gante changed the title from [Cache] remove dtype attribute to [Cache] remove dtype attribute 🚨 🚨 Mar 27, 2025
@gante gante marked this pull request as ready for review March 27, 2025 14:17
@gante gante requested review from zucchini-nlp and removed request for Rocketknight1 March 27, 2025 14:25
@gante gante changed the title from [Cache] remove dtype attribute 🚨 🚨 to [Cache] rename dtype attribute 🚨 🚨 Mar 27, 2025
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@zucchini-nlp zucchini-nlp left a comment

Thanks for fixing!

Collaborator

@ArthurZucker ArthurZucker left a comment

We did not break this and it's too risky for a patch! Let's affect the least amount of people.

@ArthurZucker ArthurZucker merged commit bab605d into huggingface:main Mar 28, 2025
19 of 20 checks passed
@gante gante removed the for patch label Mar 28, 2025
@gante gante deleted the rm_dtype_attr_in_cache branch March 28, 2025 18:41
@mark-mkv

mark-mkv commented Apr 2, 2025

Looks like there is an issue fine-tuning Gemma3 with the eager attention implementation when transformers is installed from the main branch:

{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 7.500000000000001e-06, 'mean_token_accuracy': 0.005072383000515402, 'epoch': 0.03}

It works with SDPA, though eager is what Gemma3 recommends.
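Given that the SDPA path works, a possible interim workaround is to change the attention implementation at load time. This is a sketch under assumptions — the model id and kwargs simply mirror the reproduction script above, and whether SDPA matches eager numerics for Gemma3 is not confirmed by this thread:

```python
# Hypothetical workaround sketch: swap the attention implementation when
# loading the model (everything else unchanged from the script above).
model_kwargs = dict(
    attn_implementation="sdpa",   # instead of "eager"
    torch_dtype="bfloat16",
    device_map="auto",
)
# model = AutoModelForImageTextToText.from_pretrained(
#     "google/gemma-3-4b-pt", **model_kwargs
# )
```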

@ArthurZucker
Collaborator

Arf 😢 @SunMarc if you can have a look!

@ArthurZucker
Collaborator

We should no longer recommend eager as flex and flash do the proper fix

@gante
Contributor Author

gante commented Apr 2, 2025

@mark-314e I'm assuming it's not directly related to this PR, since this PR fixes a logic error that was preventing training in some circumstances. Would you be able to open a new issue with a self-contained example to reproduce it? 🤗

zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025
soghomon-b pushed a commit to soghomon-b/transformers that referenced this pull request Aug 24, 2025
