add weight-only int8 QAT scheme and update tests for torchao 0.15.0 #3859

Merged
Datta0 merged 13 commits into unslothai:main from electroglyph:int8 on Jan 16, 2026

Conversation

@electroglyph

Contributor

this adds a new QAT scheme "int8" which is weight only
should be a good choice for Gemma Q8_0 QAT models

closes #3845
closes #3858

@gemini-code-assist

Contributor

Summary of Changes

Hello @electroglyph, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new weight-only int8 Quantization-Aware Training (QAT) scheme, specifically designed for models like Gemma Q8_0, by updating the underlying torchao library usage and extending the QAT test suite. The changes ensure that the new quantization method is correctly applied and verified, providing a more efficient quantization option for model weights.

Highlights

  • New Weight-Only Int8 QAT Scheme: Introduced a new 'int8' Quantization-Aware Training (QAT) scheme that focuses solely on weight quantization, making it suitable for models like Gemma Q8_0.
  • TorchAO 0.15.0 Compatibility: Updated fake quantizer imports and test logic to align with changes in the torchao library version 0.15.0, specifically replacing Int4WeightPreshuffledFakeQuantizer with Int4WeightFakeQuantizer and introducing IntxFakeQuantizer.
  • Enhanced QAT Test Coverage: Extended the QAT test suite to include the new 'int8' scheme, ensuring that both full finetuning and LoRA models are correctly tested for weight-only quantization behavior, including proper handling of activation quantizers.
  • Model Preparation Integration: Integrated the 'int8' QAT scheme into the _prepare_model_for_qat utility, configuring it with IntxWeightOnlyConfig for group-wise int8 quantization on linear layers.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a new weight-only int8 QAT scheme and updates the tests accordingly. The changes are well-structured and correctly implement the new functionality. The test file tests/utils/test_qat.py is updated to cover the new scheme, and some existing type hint bugs have been fixed, which improves code correctness.

I have a couple of suggestions to improve maintainability:

  • In tests/utils/test_qat.py, refactoring the if/elif block for QAT scheme configurations into a dictionary would make the code cleaner and easier to extend.
  • In unsloth/models/_utils.py, the group_size of 128 is used across multiple schemes. Defining it as a constant would reduce duplication.

Overall, this is a solid contribution that extends the quantization capabilities.

Comment thread tests/utils/test_qat.py
Comment on lines +53 to 68
    +weight_only = False
     if qat_scheme == "fp8-int4":
         act_fq_class = Float8FakeQuantizer
    -    weight_fq_class = Int4WeightPreshuffledFakeQuantizer
    +    weight_fq_class = Int4WeightFakeQuantizer
         min_in_features = 128
     elif qat_scheme == "fp8-fp8":
         act_fq_class = Float8FakeQuantizer
         weight_fq_class = Float8FakeQuantizer
         min_in_features = -1
    +elif qat_scheme == "int8":
    +    act_fq_class = None
    +    weight_fq_class = IntxFakeQuantizer
    +    min_in_features = 128
    +    weight_only = True
     else:
         raise ValueError(f"Unknown qat_scheme: {qat_scheme}")

Contributor

medium

For better maintainability and readability, consider refactoring this if/elif chain into a dictionary that maps qat_scheme to its configuration. This will make it easier to add or modify schemes in the future.

    SCHEME_CONFIGS = {
        "fp8-int4": {
            "act_fq_class": Float8FakeQuantizer,
            "weight_fq_class": Int4WeightFakeQuantizer,
            "min_in_features": 128,
            "weight_only": False,
        },
        "fp8-fp8": {
            "act_fq_class": Float8FakeQuantizer,
            "weight_fq_class": Float8FakeQuantizer,
            "min_in_features": -1,
            "weight_only": False,
        },
        "int8": {
            "act_fq_class": None,
            "weight_fq_class": IntxFakeQuantizer,
            "min_in_features": 128,
            "weight_only": True,
        },
    }
    config = SCHEME_CONFIGS.get(qat_scheme)
    if config is None:
        raise ValueError(f"Unknown qat_scheme: {qat_scheme}")

    act_fq_class = config["act_fq_class"]
    weight_fq_class = config["weight_fq_class"]
    min_in_features = config["min_in_features"]
    weight_only = config["weight_only"]

@Datta0

Datta0 commented Jan 7, 2026

Collaborator

Hey @electroglyph did you test the changes? If yes do you have any initial numbers on the VRAM usage and training speed as compared to the baseline (16bit)?
Also if you can share a script or notebook to test your changes that would be great

@electroglyph

electroglyph commented Jan 7, 2026

Contributor Author

i confirmed training/saving/loading/inference works. i also added an "int8" option to the test code which confirms the quantized bits are injected into the model and executed.

pytest tests/utils/test_qat.py to run the test. i'll check memory and share a lil example soon

@electroglyph

Contributor Author

couldn't find a working version of fbgemm-gpu-genai for torch 2.9.1
tested with xformers/triton, memory usage is the same:

FFT 16 bit memory usage = 4390.90 MB
with int8               = same
with int4               = same
LoRA rank 64            = 2690.77 MB
with int8               = same
with int4               = same

from unsloth import FastModel
import torch
max_seq_length = 2048

model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/gemma-3-270m-it",
    max_seq_length = max_seq_length,
    load_in_4bit = False,
    load_in_8bit = False,
    load_in_16bit = True,
    # full_finetuning = True,
)

model = FastModel.get_peft_model(
    model,
    r = 64,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 64,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
    qat_scheme = "int8",
)

from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "gemma3",
)
from datasets import load_dataset
dataset = load_dataset("Thytu/ChessInstruct", split = "train[:1000]")

def convert_to_chatml(example):
    return {
        "conversations": [
            {"role": "system", "content": example["task"]},
            {"role": "user", "content": example["input"]},
            {"role": "assistant", "content": example["expected_output"]}
        ]
    }
dataset = dataset.map(
    convert_to_chatml
)

def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False).removeprefix('<bos>') for convo in convos]
    return { "text" : texts, }

dataset = dataset.map(formatting_prompts_func, batched = True)

from trl import SFTTrainer, SFTConfig
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    eval_dataset = None,
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 1,
        max_steps = 100,
        learning_rate = 5e-5,
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.001,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir="outputs",
        report_to = "none",
    ),
)

from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<start_of_turn>user\n",
    response_part = "<start_of_turn>model\n",
)

trainer_stats = trainer.train()
# model.save_pretrained_merged("gemma-3-finetune", tokenizer, save_method = "merged_16bit")
# from torchao.quantization import quantize_
# from torchao.quantization.qat import QATConfig
# quantize_(model, QATConfig(step = "convert"))
# model.save_pretrained_torchao("qat", tokenizer = tokenizer)

peak_memory = torch.cuda.max_memory_allocated()
print(f"Peak CUDA memory usage: {peak_memory / 1024 / 1024:.2f} MB")

@Datta0

Datta0 commented Jan 7, 2026

Collaborator

Memory usage is same vs BF16 baseline you mean?
You're using 270m model? So the difference is perhaps not noticeable.
You should probably try some bigger model like 7B or something and then try measuring the difference

@electroglyph

Contributor Author

i was comparing unsloth + QAT schemes to unsloth without QAT, is that what you wanted?

@Datta0

Datta0 commented Jan 7, 2026

Collaborator

Try a 7B model.
Unsloth + QAT (8bit that you added in the PR) vs Unsloth 16bit

@electroglyph

electroglyph commented Jan 7, 2026

Contributor Author

PerGroup fails for layers not divisible by group size when using IntxWeightOnlyConfig, changed to PerAxis to fix it. i guess PerGroup for int8 isn't really a thing, oops

this is the best i can do right now:

unsloth/gemma-3-4b-it:
16 bit LoRA rank 256 = 13632.72 MB
same + "int8"        = 13658.49 MB

i'm going to do an embeddinggemma QAT Q8_0 unquantized finetune overnight and try to see if anything is terribly wrong
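
The PerGroup failure mentioned above can be illustrated with a minimal pure-Python sketch. It assumes the failure comes from a divisibility requirement on `in_features`; the helper names `per_group_scales` and `per_axis_scale` are hypothetical and are not torchao APIs.

```python
def per_group_scales(row, group_size):
    """Symmetric int8 scales, one per contiguous group of `group_size` weights."""
    if len(row) % group_size != 0:
        # Assumed failure mode: the row cannot be split into equal groups.
        raise ValueError(
            f"in_features={len(row)} is not divisible by group_size={group_size}"
        )
    groups = [row[i:i + group_size] for i in range(0, len(row), group_size)]
    return [max(abs(w) for w in g) / 127 for g in groups]


def per_axis_scale(row):
    """Symmetric int8 scale shared by one whole output row (per-axis)."""
    return max(abs(w) for w in row) / 127


row = [0.5, -1.27, 0.3, 0.9, -0.2, 0.1]  # one output row, in_features = 6

print(per_group_scales(row, group_size=3))  # two groups, so two scales
print(per_axis_scale(row))                  # always a single scale per row

try:
    per_group_scales(row, group_size=4)     # 6 is not divisible by 4
except ValueError as err:
    print("per-group failed:", err)
```

Per-axis scaling has no divisibility constraint, which is consistent with the switch to PerAxis fixing the failing layers.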

@Datta0

Datta0 commented Jan 7, 2026

Collaborator

Something doesn't seem right. A 4B model loaded in 8bit vs 16bit should save ~4GB worth of VRAM, right?
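
The ~4GB figure follows from back-of-envelope arithmetic. A minimal sketch, assuming exactly 4e9 parameters and weights physically stored at the stated bit width (the helper `weight_memory_gb` is hypothetical):

```python
def weight_memory_gb(num_params, bits_per_weight):
    """Memory needed to store the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_weight / 8 / 1e9


params = 4e9  # assumed parameter count for a "4B" model

mem_16bit = weight_memory_gb(params, 16)
mem_8bit = weight_memory_gb(params, 8)

print(f"16-bit weights: {mem_16bit:.1f} GB")            # 8.0 GB
print(f"8-bit weights:  {mem_8bit:.1f} GB")             # 4.0 GB
print(f"saving:         {mem_16bit - mem_8bit:.1f} GB") # 4.0 GB
```

The saving only appears when the weights are actually held in memory as 8-bit values; it covers weights alone and ignores activations, gradients, and optimizer state.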

@electroglyph

Contributor Author

the weights stay 16 bit, but during training they are quantized to 8 bit range. on save (save_pretrained_torchao) they're actually quantized. i'm pretty sure this is okay, but i'm preparing a QAT vs non-QAT analysis to show the difference in accuracy.
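
The behaviour described above can be sketched in a few lines of pure Python. This is an illustration only, assuming symmetric int8 quantization with one scale per row; `fake_quantize_int8` is a hypothetical helper, and torchao's fake quantizers may differ in details.

```python
def fake_quantize_int8(row):
    """Quantize-dequantize: snap values to the int8 grid but return floats."""
    scale = max(abs(w) for w in row) / 127 or 1.0  # avoid a zero scale
    # Integer codes that a real int8 export would store.
    codes = [max(-127, min(127, round(w / scale))) for w in row]
    # Dequantized values used in the forward pass; still ordinary floats.
    dequantized = [c * scale for c in codes]
    return dequantized, codes, scale


weights = [0.5, -1.27, 0.3, 0.004]  # stand-in for one row of 16-bit weights
fq, codes, scale = fake_quantize_int8(weights)

print(codes)  # [50, -127, 30, 0]
print(fq)     # floats restricted to multiples of `scale`
```

Because the forward pass sees floats that merely sit on the int8 grid, storage during training is unchanged, which matches the near-identical memory numbers reported earlier. The integer codes are only what gets stored once the model is converted and saved.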

@Datta0

Datta0 commented Jan 8, 2026

Collaborator

Oh okay then I might have misunderstood
We can't call this QAT (aka Quantization aware training) if we're not training in quantized format
This is just an export format post training (perhaps can be called PTQ or Post Training Quantization)

@electroglyph

Contributor Author

> Oh okay then I might have misunderstood
> We can't call this QAT (aka Quantization aware training) if we're not training in quantized format
> This is just an export format post training (perhaps can be called PTQ or Post Training Quantization)

it's QAT, the weights are just fake quantized during training

@Datta0

Datta0 commented Jan 9, 2026

Collaborator

@electroglyph what do you mean fake quantized during training? If you mean we quantize just before matmul, then what is the advantage/point of it?

@electroglyph

Contributor Author

> @electroglyph what do you mean fake quantized during training? If you mean we quantize just before matmul, then what is the advantage/point of it?

https://github.com/electroglyph/unsloth_QAT_results

@Datta0

Datta0 commented Jan 9, 2026

Collaborator

Ok great, I apologise for my confusion thus far. This seems good and clear now. Thanks :)

@Datta0

Datta0 commented Jan 9, 2026

Collaborator

If you can address the other comments, then I'll be able to approve the PR :)

@electroglyph

Contributor Author

> If you can address the other comments, then I'll be able to approve the PR :)

okay, sorry for disappearing and thanks for your time!

@Datta0 Datta0 left a comment

Collaborator

LGTM

@mmathew23

Contributor

Yes looks good to me as well.

@Datta0 Datta0 merged commit ab4061e into unslothai:main Jan 16, 2026
1 check passed
abiswas-realadvice pushed a commit to abiswas-realadvice/unsloth that referenced this pull request May 14, 2026
…nslothai#3859)

* add int8 weight-only QAT scheme, add test, fix tests for current torchao version

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change quantization to PerAxis

* lambda =/

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add torchao messages, remove group_size from int8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* raise exception on missing torchao

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* touch up the torchao imports

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Development

Successfully merging this pull request may close these issues.

[Feature] test_qat.py fails in torchao 0.15.0
[Feature] QAT scheme: A16W8 Int8 WeightOnly Quantization

3 participants