
[BigModeling] Add missing check for quantized models#1652

Merged
younesbelkada merged 4 commits into main from fix-to-int8 on Jun 28, 2023

Conversation

@younesbelkada (Contributor)

What does this PR do?

Fixes huggingface/transformers#24540 and the failing test: https://github.com/huggingface/accelerate/actions/runs/5396594202/jobs/9800396637

Currently, on the main branch, loading a quantized model on a single GPU fails:

from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import torch

model_path = "facebook/opt-350m"

config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, load_in_8bit=True, device_map="auto")

tokenizer = AutoTokenizer.from_pretrained(model_path)

input_text = "Describe the solar system."
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

outputs = model.generate(input_ids, max_length=100)
print(tokenizer.decode(outputs[0]))

In #1648 it seems that a check was missing before calling .to() on the model in the single-GPU dispatch path. Adding a small check circumvents this issue.

cc @sgugger @SunMarc

else:
    device = list(device_map.values())[0]
-   if device != "disk":
+   if device != "disk" and not getattr(model, "is_quantized", False):
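The guard can be sketched in isolation. This is a minimal mock, not the actual accelerate code: `FakeModel` and `maybe_move` are illustrative stand-ins, with `FakeModel.to()` mimicking the error a quantized transformers model raises when `.to()` is called with a device.

```python
# Minimal sketch of the fix, using a stand-in object instead of a real
# transformers model. `is_quantized` mimics the flag transformers sets
# on 8-bit/4-bit models.
class FakeModel:
    def __init__(self, is_quantized):
        self.is_quantized = is_quantized
        self.moved_to = None

    def to(self, device):
        if self.is_quantized:
            # Real quantized models raise when moved with .to();
            # this sketch mimics that behavior.
            raise ValueError("`.to` is not supported for quantized models.")
        self.moved_to = device
        return self


def maybe_move(model, device_map):
    device = list(device_map.values())[0]
    # The fix: skip .to() entirely when the model is quantized.
    if device != "disk" and not getattr(model, "is_quantized", False):
        model.to(device)
    return model


maybe_move(FakeModel(is_quantized=True), {"": 0})   # no longer raises
maybe_move(FakeModel(is_quantized=False), {"": 0})  # still moved as before
```

Without the `getattr` guard, the first call would hit the `ValueError` that the failing test reproduced.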
@younesbelkada (Contributor Author)

We should maybe also check the attributes is_loaded_in_8bit or is_loaded_in_4bit, as is_quantized was only recently introduced. WDYT?

@HuggingFaceDocBuilderDev commented Jun 28, 2023

The documentation is not available anymore as the PR was closed or merged.

@younesbelkada younesbelkada requested review from SunMarc and sgugger June 28, 2023 06:46
device = list(device_map.values())[0]
if device != "disk":
    # for backward compatibility
    is_quantized = getattr(model, "is_quantized", False) or getattr(model, "is_loaded_in_8bit", False)
@younesbelkada (Contributor Author)

getattr(model, "is_loaded_in_8bit", False) --> for backward compatibility for users that have an old version of transformers (before the 4-bit integration)
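The backward-compatibility trick is just a `getattr` fallback chain: check the new flag first, then the old one, defaulting to `False` when neither exists. A minimal sketch, where the three classes are illustrative stand-ins for models from different transformers versions:

```python
# Stand-ins for models produced by different transformers versions.
class Old8BitModel:
    # old transformers: only sets is_loaded_in_8bit
    is_loaded_in_8bit = True


class NewQuantizedModel:
    # recent transformers: sets is_quantized
    is_quantized = True


class PlainModel:
    # not quantized at all: neither attribute exists
    pass


def is_quantized(model):
    # New flag first, old flag as fallback; getattr's default covers
    # models that predate both attributes.
    return getattr(model, "is_quantized", False) or getattr(
        model, "is_loaded_in_8bit", False
    )


print(is_quantized(Old8BitModel()))       # True
print(is_quantized(NewQuantizedModel()))  # True
print(is_quantized(PlainModel()))         # False
```

This way the check never raises `AttributeError`, regardless of which transformers version produced the model.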

@sgugger (Collaborator) left a comment

Thanks for your PR, added a small comment.

Comment on lines 394 to 396
elif is_quantized:
    pass
else:
@sgugger (Collaborator)


Suggested change
-elif is_quantized:
-    pass
-else:
+elif not is_quantized:
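The suggestion is a pure refactor: an `elif cond: pass / else: body` ladder behaves the same as `elif not cond: body`. A quick check with hypothetical `move_verbose`/`move_compact` helpers (not the accelerate code, just the two branch shapes side by side):

```python
def move_verbose(device, is_quantized):
    # Original shape: an explicit no-op branch for quantized models.
    if device == "disk":
        return "offloaded"
    elif is_quantized:
        pass  # quantized models must not be moved with .to()
    else:
        return f"moved to {device}"


def move_compact(device, is_quantized):
    # Suggested shape: fold the no-op into the condition.
    if device == "disk":
        return "offloaded"
    elif not is_quantized:
        return f"moved to {device}"


# Both return the same result for every combination of inputs.
for args in [("disk", False), ("disk", True), (0, True), (0, False)]:
    assert move_verbose(*args) == move_compact(*args)
```

Dropping the empty branch removes a line and makes the "only move non-quantized models" intent explicit in the condition itself.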

@younesbelkada (Contributor Author)


Ah, I added it already in b93d93c, before your review :D

@SunMarc (Member) left a comment

Thanks for the work!
