Skip to content

Add new QAT API through quantize_#1415

Closed
andrewor14 wants to merge 1 commit into
mainfrom
new-qat-api
Closed

Add new QAT API through quantize_#1415
andrewor14 wants to merge 1 commit into
mainfrom
new-qat-api

Conversation

@andrewor14

Copy link
Copy Markdown
Contributor

Summary: This commit adds a new QAT API that can be used with the existing quantize_. This is an alternative to the old QAT *Quantizer APIs, which are much less flexible. The new API can be used as follows:

from torchao import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    intx_quantization_aware_training,
)
my_model = ...
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)

Test Plan:
python test/quantization/test_qat.py -k test_quantize_api

@pytorch-bot

pytorch-bot Bot commented Dec 13, 2024

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1415

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit eb0d868 with merge base ebc4303 (image):

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 13, 2024
@andrewor14 andrewor14 force-pushed the new-qat-api branch 2 times, most recently from 6118d76 to 6190af7 Compare December 13, 2024 17:32
@andrewor14 andrewor14 added the topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) label Dec 13, 2024
Comment thread torchao/quantization/qat/embedding.py
elif isinstance(mod, torch.nn.Embedding):
if activation_config is not None:
raise ValueError("Embedding does not support QAT for activations")
return FakeQuantizedEmbedding.from_embedding(mod, weight_config)

@jerryzh168 jerryzh168 Dec 13, 2024

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the function takes both activation_config and weight_config, but in this case we only use weight_config, this might be confusing

should we be separating these two branches to 2 APIs?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel it's better to only have one API, otherwise we'll need a separate API for each type of layer we support. Also it's more consistent with PTQ where we only have one int8_weight_only and let the filter_fn decide which layers to apply the transformation on (instead of having separate int8_weight_only_linear and int8_weight_only_embedding)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh the separation I'm referring to is to separate int8_act_int4_weight and int8_weight_only here, not linear v.s. embedding

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh I see, you mean separate this into the following?

intx_activation_intx_weight_quantization_aware_training
intx_weight_only_quantization_aware_training

I can do that, but the naming seems a bit long... any suggestions?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, separated into these two for now. Please let me know if you have better suggestions for the naming:

intx_quantization_aware_training (for both act + weight)
intx_weight_only_quantization_aware_training

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe intx_quantization_aware_training can be intx_dynamic_quantization_aware_training?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or just expand to match ptq and then add a qat as suffix, that might be clearer?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed this offline: we settled on having a single intx_quantization_aware_training (not separating into activation+weight and weight-only), which is the most general and can be extended to support output activation in the future. Having "dynamic" in the name is also not great because we also support FakeQuantizeConfig(is_dynamic=False).

Since QAT only supports linear and embedding, and embedding only supports weight-only QAT, we will support the following:

# ok
quantize_(model, intx_quantization_aware_training(act=config1, weight=config2), is_linear)
quantize_(model, intx_quantization_aware_training(weight=config3), is_embedding)

# throws an exception
quantize_(model, intx_quantization_aware_training(act=config1, weight=config2), is_embedding)
quantize_(model, intx_quantization_aware_training(act=config1, weight=config2), is_conv)

@andrewor14 andrewor14 force-pushed the new-qat-api branch 2 times, most recently from 8293026 to c887099 Compare December 13, 2024 21:23
@andrewor14 andrewor14 force-pushed the new-qat-api branch 3 times, most recently from afbac81 to 2666b80 Compare December 13, 2024 22:41
Summary: This commit adds a new QAT API that can be used with
the existing `quantize_`. This is an alternative to the old
QAT *Quantizer APIs, which are much less flexible. The new API
can be used as follows:

```
from torchao import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    intx_quantization_aware_training,
)
my_model = ...
activation_config = FakeQuantizeConfig(
    torch.int8, "per_token", is_symmetric=False,
)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)
```

Test Plan:
python test/quantization/test_qat.py -k test_quantize_api
python test/quantization/test_qat.py -k test_quantize_api_errors
@andrewor14

Copy link
Copy Markdown
Contributor Author

Thanks, merging this!

@andrewor14

Copy link
Copy Markdown
Contributor Author

@pytorchbot merge

@pytorchmergebot

Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

andrewor14 added a commit that referenced this pull request Jan 10, 2025
Summary: #1415 added a quantize_
QAT API for the prepare path. This commit adds the remaining
convert path for users to actually perform end-to-end QAT using
the quantize_ API. The new flow will look like:

```
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)

quantize_(my_model, from_intx_quantization_aware_training())
quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
```

Test Plan:
python test/quantization/test_qat.py -k test_quantize_api_convert_path
andrewor14 added a commit that referenced this pull request Jan 10, 2025
Summary: #1415 added a quantize_
QAT API for the prepare path. This commit adds the remaining
convert path for users to actually perform end-to-end QAT using
the quantize_ API. The new flow will look like:

```
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)

quantize_(my_model, from_intx_quantization_aware_training())
quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
```

Test Plan:
python test/quantization/test_qat.py -k test_quantize_api_convert_path

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Jan 10, 2025
Summary: #1415 added a quantize_
QAT API for the prepare path. This commit adds the remaining
convert path for users to actually perform end-to-end QAT using
the quantize_ API. The new flow will look like:

```
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)

quantize_(my_model, from_intx_quantization_aware_training())
quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
```

Test Plan:
python test/quantization/test_qat.py -k test_quantize_api_convert_path

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Jan 10, 2025
Summary: #1415 added a quantize_
QAT API for the prepare path. This commit adds the remaining
convert path for users to actually perform end-to-end QAT using
the quantize_ API. The new flow will look like:

```
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)

quantize_(my_model, from_intx_quantization_aware_training())
quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
```

Test Plan:
python test/quantization/test_qat.py -k test_quantize_api_convert_path

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Jan 10, 2025
Summary: #1415 added a quantize_
QAT API for the prepare path. This commit adds the remaining
convert path for users to actually perform end-to-end QAT using
the quantize_ API. The new flow will look like:

```
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)

quantize_(my_model, from_intx_quantization_aware_training())
quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
```

Test Plan:
python test/quantization/test_qat.py -k test_quantize_api_convert_path

[ghstack-poisoned]
amdfaa pushed a commit that referenced this pull request Jan 10, 2025
**Summary:** This commit adds a new QAT API that can be used with the existing `quantize_`. This is an alternative to the old QAT *Quantizer APIs, which are much less flexible. The new API can be used as follows:

```
from torchao import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    intx_quantization_aware_training,
)
my_model = ...
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)
```

**Test Plan:**
python test/quantization/test_qat.py -k test_quantize_api
Pull Request resolved: #1415
Approved by: https://github.com/jerryzh168
andrewor14 added a commit that referenced this pull request Jan 13, 2025
* Add convert path for quantize_ QAT API

Summary: #1415 added a quantize_
QAT API for the prepare path. This commit adds the remaining
convert path for users to actually perform end-to-end QAT using
the quantize_ API. The new flow will look like:

```
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)

quantize_(my_model, from_intx_quantization_aware_training())
quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
```

Test Plan:
python test/quantization/test_qat.py -k test_quantize_api_convert_path

[ghstack-poisoned]

* Update on "Add convert path for quantize_ QAT API"


Summary: #1415 added a quantize_
QAT API for the prepare path. This commit adds the remaining
convert path for users to actually perform end-to-end QAT using
the quantize_ API. The new flow will look like:

```
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)

quantize_(my_model, from_intx_quantization_aware_training())
quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
```

Test Plan:
python test/quantization/test_qat.py -k test_quantize_api_convert_path

[ghstack-poisoned]

* Update on "Add convert path for quantize_ QAT API"


Summary: #1415 added a quantize_
QAT API for the prepare path. This commit adds the remaining
convert path for users to actually perform end-to-end QAT using
the quantize_ API. The new flow will look like:

```
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)

quantize_(my_model, from_intx_quantization_aware_training())
quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
```

Test Plan:
python test/quantization/test_qat.py -k test_quantize_api_convert_path

[ghstack-poisoned]

* Update on "Add convert path for quantize_ QAT API"


Summary: #1415 added a quantize_
QAT API for the prepare path. This commit adds the remaining
convert path for users to actually perform end-to-end QAT using
the quantize_ API. The new flow will look like:

```
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)

quantize_(my_model, from_intx_quantization_aware_training())
quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
```

Test Plan:
python test/quantization/test_qat.py -k test_quantize_api_convert_path

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Jan 13, 2025
* Add convert path for quantize_ QAT API

Summary: #1415 added a quantize_
QAT API for the prepare path. This commit adds the remaining
convert path for users to actually perform end-to-end QAT using
the quantize_ API. The new flow will look like:

```
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)

quantize_(my_model, from_intx_quantization_aware_training())
quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
```

Test Plan:
python test/quantization/test_qat.py -k test_quantize_api_convert_path

[ghstack-poisoned]

* Update QAT READMEs using new APIs

Add references to new QAT APIs including `quantize_`,
`FakeQuantizedX`, and the new embedding Quantizers and
ComposableQATQuantizer. Also link to new QAT + LoRA recipe
in torchtune.

[ghstack-poisoned]

* Update base for Update on "Update QAT READMEs using new APIs"


Add references to new QAT APIs including `quantize_`,
`FakeQuantizedX`, and the new embedding Quantizers and
ComposableQATQuantizer. Also link to new QAT + LoRA recipe
in torchtune.

[ghstack-poisoned]

* Update base for Update on "Update QAT READMEs using new APIs"


Add references to new QAT APIs including `quantize_`,
`FakeQuantizedX`, and the new embedding Quantizers and
ComposableQATQuantizer. Also link to new QAT + LoRA recipe
in torchtune.

[ghstack-poisoned]

* Update base for Update on "Update QAT READMEs using new APIs"


Add references to new QAT APIs including `quantize_`,
`FakeQuantizedX`, and the new embedding Quantizers and
ComposableQATQuantizer. Also link to new QAT + LoRA recipe
in torchtune.

[ghstack-poisoned]

* Update base for Update on "Update QAT READMEs using new APIs"


Add references to new QAT APIs including `quantize_`,
`FakeQuantizedX`, and the new embedding Quantizers and
ComposableQATQuantizer. Also link to new QAT + LoRA recipe
in torchtune.

[ghstack-poisoned]

* Update base for Update on "Update QAT READMEs using new APIs"


Add references to new QAT APIs including `quantize_`,
`FakeQuantizedX`, and the new embedding Quantizers and
ComposableQATQuantizer. Also link to new QAT + LoRA recipe
in torchtune.

[ghstack-poisoned]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Merged topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants