add weight-only int8 QAT scheme and update tests for torchao 0.15.0#3859
Conversation
Summary of ChangesHello @electroglyph, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a new weight-only int8 Quantization-Aware Training (QAT) scheme, specifically designed for models like Gemma Q8_0, by updating the underlying Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Code Review
This pull request introduces a new weight-only int8 QAT scheme and updates the tests accordingly. The changes are well-structured and correctly implement the new functionality. The test file tests/utils/test_qat.py is updated to cover the new scheme, and some existing type hint bugs have been fixed, which improves code correctness.
I have a couple of suggestions to improve maintainability:
- In
tests/utils/test_qat.py, refactoring theif/elifblock for QAT scheme configurations into a dictionary would make the code cleaner and easier to extend. - In
unsloth/models/_utils.py, thegroup_sizeof 128 is used across multiple schemes. Defining it as a constant would reduce duplication.
Overall, this is a solid contribution that extends the quantization capabilities.
| weight_only = False | ||
| if qat_scheme == "fp8-int4": | ||
| act_fq_class = Float8FakeQuantizer | ||
| weight_fq_class = Int4WeightPreshuffledFakeQuantizer | ||
| weight_fq_class = Int4WeightFakeQuantizer | ||
| min_in_features = 128 | ||
| elif qat_scheme == "fp8-fp8": | ||
| act_fq_class = Float8FakeQuantizer | ||
| weight_fq_class = Float8FakeQuantizer | ||
| min_in_features = -1 | ||
| elif qat_scheme == "int8": | ||
| act_fq_class = None | ||
| weight_fq_class = IntxFakeQuantizer | ||
| min_in_features = 128 | ||
| weight_only = True | ||
| else: | ||
| raise ValueError(f"Unknown qat_scheme: {qat_scheme}") |
There was a problem hiding this comment.
For better maintainability and readability, consider refactoring this if/elif chain into a dictionary that maps qat_scheme to its configuration. This will make it easier to add or modify schemes in the future.
SCHEME_CONFIGS = {
"fp8-int4": {
"act_fq_class": Float8FakeQuantizer,
"weight_fq_class": Int4WeightFakeQuantizer,
"min_in_features": 128,
"weight_only": False,
},
"fp8-fp8": {
"act_fq_class": Float8FakeQuantizer,
"weight_fq_class": Float8FakeQuantizer,
"min_in_features": -1,
"weight_only": False,
},
"int8": {
"act_fq_class": None,
"weight_fq_class": IntxFakeQuantizer,
"min_in_features": 128,
"weight_only": True,
},
}
config = SCHEME_CONFIGS.get(qat_scheme)
if config is None:
raise ValueError(f"Unknown qat_scheme: {qat_scheme}")
act_fq_class = config["act_fq_class"]
weight_fq_class = config["weight_fq_class"]
min_in_features = config["min_in_features"]
weight_only = config["weight_only"]|
Hey @electroglyph did you test the changes? If yes do you have any initial numbers on the VRAM usage and training speed as compared to the baseline (16bit)? |
|
i confirmed training/saving/loading/inference works. i also added an "int8" option to the test code which confirms the quantized bits are injected into the model and executed.
|
|
could't find a working version of fbgemm-gpu-genai for torch 2.9.1 from unsloth import FastModel
import torch
max_seq_length = 2048
model, tokenizer = FastModel.from_pretrained(
model_name = "unsloth/gemma-3-270m-it",
max_seq_length = max_seq_length,
load_in_4bit = False,
load_in_8bit = False,
load_in_16bit = True,
# full_finetuning = True,
)
model = FastModel.get_peft_model(
model,
r = 64,
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",],
lora_alpha = 64,
lora_dropout = 0,
bias = "none",
use_gradient_checkpointing = "unsloth",
random_state = 3407,
use_rslora = False,
loftq_config = None,
qat_scheme = "int8",
)
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(
tokenizer,
chat_template = "gemma3",
)
from datasets import load_dataset
dataset = load_dataset("Thytu/ChessInstruct", split = "train[:1000]")
def convert_to_chatml(example):
return {
"conversations": [
{"role": "system", "content": example["task"]},
{"role": "user", "content": example["input"]},
{"role": "assistant", "content": example["expected_output"]}
]
}
dataset = dataset.map(
convert_to_chatml
)
def formatting_prompts_func(examples):
convos = examples["conversations"]
texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False).removeprefix('<bos>') for convo in convos]
return { "text" : texts, }
dataset = dataset.map(formatting_prompts_func, batched = True)
from trl import SFTTrainer, SFTConfig
trainer = SFTTrainer(
model = model,
tokenizer = tokenizer,
train_dataset = dataset,
eval_dataset = None,
args = SFTConfig(
dataset_text_field = "text",
per_device_train_batch_size = 1,
gradient_accumulation_steps = 1,
max_steps = 100,
learning_rate = 5e-5,
logging_steps = 1,
optim = "adamw_8bit",
weight_decay = 0.001,
lr_scheduler_type = "linear",
seed = 3407,
output_dir="outputs",
report_to = "none",
),
)
from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
trainer,
instruction_part = "<start_of_turn>user\n",
response_part = "<start_of_turn>model\n",
)
trainer_stats = trainer.train()
# model.save_pretrained_merged("gemma-3-finetune", tokenizer, save_method = "merged_16bit")
# from torchao.quantization import quantize_
# from torchao.quantization.qat import QATConfig
# quantize_(model, QATConfig(step = "convert"))
# model.save_pretrained_torchao("qat", tokenizer = tokenizer)
peak_memory = torch.cuda.max_memory_allocated()
print(f"Peak CUDA memory usage: {peak_memory / 1024 / 1024:.2f} MB") |
|
Memory usage is same vs BF16 baseline you mean? |
|
i was comparing unsloth + QAT schemes to unsloth without QAT, is that what you wanted? |
|
Try a 7B model. |
|
PerGroup fails for layers not divisible by group size when using IntxWeightOnlyConfig, changed to PerAxis to fix it. i guess PerGroup for int8 isn't really a thing, oops this is the best i can do right now: i'm going to do an embeddinggemma QAT Q8_0 unquantized finetune overnight and try to see if anything is terribly wrong |
for more information, see https://pre-commit.ci
|
Something seems not right. If 4B model loaded in 8bit vs 16bit should save ~4GB worth of VRAM right... |
|
the weights stay 16 bit, but during training they are quantized to 8 bit range. on save (save_pretrained_torchao) they're actually quantized. i'm pretty sure this is okay, but i'm preparing a QAT vs non-QAT analysis to show the difference in accuracy. |
|
Oh okay then I might have misunderstood |
it's QAT, the weights are just fake quantized during training |
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
|
@electroglyph what do you mean fake quantized during training? If you mean we quantize just before matmul, then what is the advantage/point of it? |
|
|
Ok great, I apologise for my confusion thus far. This seems good and clear now. Thanks :) |
|
If you can address the other comments, then I'll be able to approve the PR :) |
okay, sorry for disappearing and thanks for your time! |
for more information, see https://pre-commit.ci
|
Yes looks good to me as well. |
…nslothai#3859) * add int8 weight-only QAT scheme, add test, fix tests for current torchao version * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change quantization to PerAxis * lambda =/ * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add torchao messages, remove group_size from int8 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * raise exception on missing torchao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * touch up the torchao imports * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
this adds a new QAT scheme "int8" which is weight only
should be a good choice for Gemma Q8_0 QAT models
closes #3845
closes #3858