Skip to content

enable smoothquant for int8 static tensor#3468

Merged
jcaip merged 41 commits into
mainfrom
jcaip/enable-smoothquant
Jan 21, 2026
Merged

enable smoothquant for int8 static tensor#3468
jcaip merged 41 commits into
mainfrom
jcaip/enable-smoothquant

Conversation

@jcaip

@jcaip jcaip commented Dec 8, 2025

Copy link
Copy Markdown
Contributor

This PR hooks up the static quant workflow added in #3442 to the prototype smoothquant API.

You can use the new flow like follows:

from torchao.quantization.quant_api import (
    Int8StaticActivationInt8WeightConfig,
)
from torchao.prototype.smoothquant import (
    SmoothQuantConfig
)

config = SmoothQuantConfig(
            base_config=Int8StaticActivationInt8Weight(granularity=PerRow()),
            step=SmoothQuantStep.PREPARE,
            alpha=0.5,
        )

quantize_(model, config)

# Perform calibration with test data
model(*x)

config.step = SmoothQuantStep.CONVERT
quantize_(model, config)

# model will now be statically quantized with the inputs used in smoothquant observer. 
model(*x)

Test Plan:

pytest test/quantization/quantize_/workflows/int8/test_int8_tensor.py
pytest test/prototype/test_smoothquant.py 

Comment thread test/quantization/quantize_/workflows/int8/test_int8_tensor.py Outdated
Comment thread torchao/prototype/smoothquant/api.py Outdated
else:
raise ValueError(f"Unexpected step: {step}")

if isinstance(base_config, Int8StaticActivationInt8WeightConfig):

@jerryzh168 jerryzh168 Dec 18, 2025

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we shouldn't have specific config here, maybe change this to a similar protocol like SupportsActivationPreScaling for config?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think figuring out how to do this generally will need a bit more design, we'd need to figure out how to map to the appropriate QuantizeTensorToInt/FloatXKwargs object. Agree we should be able to do this though, but can I address in a later PR?

@jerryzh168 jerryzh168 Jan 20, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can get this info from a callback I think, like base_config.get_activation_quant_kwargs(), I feel you can iterating from that, but not from the current state

@jcaip jcaip force-pushed the jcaip/enable-smoothquant branch from f389a94 to 2586ab6 Compare December 18, 2025 00:02
block_size,
self.dtype,
act_quant_kwargs=self.act_quant_kwargs,
act_scale=self.act_scale,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess slice doesn't work for static quant int8 before, can you add a test for that?

old_int8_tensor.scale[index],
old_int8_tensor.block_size[1:],
old_int8_tensor.dtype,
old_int8_tensor.act_scale,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same for this one, seems like select op breaks before with static quant

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For slice and select, they work when granularity is PerDim(=-1), but not otherwise. we now throw an exception when passing PerDim != -1 and I added a test for both.

Comment thread torchao/quantization/quantize_/workflows/int8/int8_tensor.py Outdated
@jcaip jcaip force-pushed the jcaip/enable-smoothquant branch from ea9b8e2 to c3abc46 Compare January 20, 2026 20:45
This PR hooks up the static quant workflow added in #3442 to the prototype smoothquant API.

You can use the new flow like follows:

```python
from torchao.quantization.quant_api import (
    Int8StaticActivationInt8WeightConfig,
)
from torchao.prototype.smoothquant import (
    SmoothQuantConfig
)

config = SmoothQuantConfig(
            base_config=Int8StaticActivationInt8Weight(granularity=PerRow()),
            step=SmoothQuantStep.PREPARE,
            alpha=0.5,
        )

quantize_(model, config)

# Perform calibration with test data
model(*x)

config.step = SmoothQuantStep.CONVERT
quantize_(model, config)

# model will now be statically quantized with the inputs used in smoothquant observer.
model(*x)
```
@jcaip jcaip force-pushed the jcaip/enable-smoothquant branch from c3abc46 to 7b45e3e Compare January 20, 2026 20:53
@jcaip jcaip requested a review from jerryzh168 January 20, 2026 21:06
Comment thread test/quantization/quantize_/workflows/int8/test_int8_tensor.py
Comment thread torchao/prototype/smoothquant/api.py Outdated
Comment on lines +126 to +127
if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
base_config.static_scale = activation_scale

@jerryzh168 jerryzh168 Jan 20, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can do the same as SupportsActivationPreScaling, like IsStaticQuantizationConfig

Comment thread torchao/quantization/quant_api.py Outdated
Comment thread torchao/quantization/quant_api.py Outdated
@jcaip jcaip force-pushed the jcaip/enable-smoothquant branch from 8764cbe to 2f05d2c Compare January 20, 2026 22:23

@jerryzh168 jerryzh168 left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@jerryzh168 jerryzh168 left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@jcaip jcaip enabled auto-merge (squash) January 21, 2026 17:30
@jcaip jcaip disabled auto-merge January 21, 2026 17:30
@jcaip jcaip enabled auto-merge (squash) January 21, 2026 17:32
@jcaip jcaip disabled auto-merge January 21, 2026 17:53
@jcaip jcaip closed this Jan 21, 2026
@jcaip jcaip reopened this Jan 21, 2026
@jcaip jcaip merged commit c6bc74c into main Jan 21, 2026
25 of 39 checks passed
jcaip added a commit that referenced this pull request Jan 22, 2026
* Int8Tensor migration

Summary:

This PR creates a new Int8Tensor and updates the configs to use the new
Int8Tensor flow

Test Plan:

To ensure BC:
```
pytest test/quantization/test_quant_api.py
```

To test new Int8Tensor:
```
pytest test/quantization/quantize_/workflows/int8/test_int8_tensor.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

* ruff fixes

* add init

* fix ruff again

* update

* wip

* undo update tests

* fix ruff

* fix varname

* fix typing

* add tests

* fix dtype

* fix ci

* address granularity cr

* update _choose_quant_func_and_quantize_tensor

* make block size required attribute

* made dtype required as well

* address nits

* skip per tensor weight only test for now

* add static quant

* add static quant

* update

* static quant working eager + compile

* remove file

* added asserts

* undo smoothquant change

* fix return

* got smoothquant + int8 static working

* generalized smoothquat code

* free tests

* fix static scale check

* update

* address cr feedback

* Hook up static quant workflow to prototype smoothquant API

This PR hooks up the static quant workflow added in #3442 to the prototype smoothquant API.

You can use the new flow like follows:

```python
from torchao.quantization.quant_api import (
    Int8StaticActivationInt8WeightConfig,
)
from torchao.prototype.smoothquant import (
    SmoothQuantConfig
)

config = SmoothQuantConfig(
            base_config=Int8StaticActivationInt8Weight(granularity=PerRow()),
            step=SmoothQuantStep.PREPARE,
            alpha=0.5,
        )

quantize_(model, config)

# Perform calibration with test data
model(*x)

config.step = SmoothQuantStep.CONVERT
quantize_(model, config)

# model will now be statically quantized with the inputs used in smoothquant observer.
model(*x)
```

* fix ruff

* fix test to use threshold for sqnr
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants