enable smoothquant for int8 static tensor#3468
Summary: This PR creates a new Int8Tensor and updates the configs to use the new Int8Tensor flow.

Test Plan:

To ensure BC:
```
pytest test/quantization/test_quant_api.py
```
To test new Int8Tensor:
```
pytest test/quantization/quantize_/workflows/int8/test_int8_tensor.py
```
    else:
        raise ValueError(f"Unexpected step: {step}")

    if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
I think we shouldn't have specific config here, maybe change this to a similar protocol like SupportsActivationPreScaling for config?
I think figuring out how to do this generally will need a bit more design; we'd need to figure out how to map to the appropriate QuantizeTensorToInt/FloatXKwargs object. Agree we should be able to do this, but can I address it in a later PR?
We can get this info from a callback I think, like base_config.get_activation_quant_kwargs(). I feel you can iterate from that, but not from the current state.
    block_size,
    self.dtype,
    act_quant_kwargs=self.act_quant_kwargs,
    act_scale=self.act_scale,
I guess slice didn't work for static quant int8 before. Can you add a test for that?
    old_int8_tensor.scale[index],
    old_int8_tensor.block_size[1:],
    old_int8_tensor.dtype,
    old_int8_tensor.act_scale,
Same for this one, seems like the select op broke before with static quant.
Slice and select work when granularity is PerDim(=-1), but not otherwise. We now throw an exception when passing PerDim != -1, and I added a test for both.
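A toy illustration of why slice and select have to carry the scale along with the data, written with plain Python lists rather than torchao's Int8Tensor. It assumes a 2D weight with one scale per row; `ToyInt8` and its methods are hypothetical names for this sketch, not the PR's implementation.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyInt8:
    """Hypothetical stand-in for a per-row quantized 2D tensor."""

    qdata: List[List[int]]  # quantized values, shape [rows][cols]
    scale: List[float]      # one scale per row

    @classmethod
    def quantize(cls, rows: List[List[float]]) -> "ToyInt8":
        qdata, scale = [], []
        for row in rows:
            # Symmetric quantization: the row's max magnitude maps to 127.
            # An all-zero row falls back to a scale of 1.0.
            s = max(abs(v) for v in row) / 127.0 or 1.0
            scale.append(s)
            qdata.append([max(-128, min(127, round(v / s))) for v in row])
        return cls(qdata, scale)

    def dequantize(self) -> List[List[float]]:
        return [[q * s for q in row] for row, s in zip(self.qdata, self.scale)]

    def slice_rows(self, start: int, end: int) -> "ToyInt8":
        # Slicing along the row dim must slice the scales the same way.
        return ToyInt8(self.qdata[start:end], self.scale[start:end])

    def slice_cols(self, start: int, end: int) -> "ToyInt8":
        # Slicing along the column dim leaves the per-row scales untouched.
        return ToyInt8([row[start:end] for row in self.qdata], list(self.scale))

    def select_row(self, index: int) -> "ToyInt8":
        # Simplification: the result stays 2D with a single row.
        return ToyInt8([self.qdata[index]], [self.scale[index]])
```

With this layout, slicing and then dequantizing gives the same values as dequantizing and then slicing. A scale layout that does not line up with the sliced dimension has no such simple rule, which is one plausible reason to reject it with an exception.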
This PR hooks up the static quant workflow added in #3442 to the prototype smoothquant API.

You can use the new flow as follows:

```python
# Import paths for quantize_, PerRow and SmoothQuantStep are assumed
# from the current torchao layout.
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_api import (
    Int8StaticActivationInt8WeightConfig,
)
from torchao.prototype.smoothquant import (
    SmoothQuantConfig,
)
from torchao.prototype.smoothquant.core import SmoothQuantStep

config = SmoothQuantConfig(
    base_config=Int8StaticActivationInt8WeightConfig(granularity=PerRow()),
    step=SmoothQuantStep.PREPARE,
    alpha=0.5,
)
# PREPARE inserts the smoothquant observers into the model.
quantize_(model, config)

# Perform calibration with test data
model(*x)

config.step = SmoothQuantStep.CONVERT
quantize_(model, config)

# model will now be statically quantized with the inputs used in smoothquant observer.
model(*x)
```
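For context on the `alpha` argument: in the SmoothQuant paper, each input channel j gets a smoothing factor s_j = max|X_j|^alpha / max|W_j|^(1 - alpha). Activations are divided by s_j and weights multiplied by it, so the layer output is unchanged while outliers move from activations into weights. A minimal sketch in plain Python follows; `smoothing_factors` and its `eps` clamp are assumptions for illustration, not torchao's observer code.

```python
from typing import List


def smoothing_factors(
    act_absmax: List[float],
    weight_absmax: List[float],
    alpha: float = 0.5,
    eps: float = 1e-5,
) -> List[float]:
    """Per-input-channel factors s_j = act_absmax_j**alpha / weight_absmax_j**(1 - alpha)."""
    if len(act_absmax) != len(weight_absmax):
        raise ValueError("expected one activation and one weight statistic per channel")
    factors = []
    for a, w in zip(act_absmax, weight_absmax):
        # Clamp to eps so an all-zero channel does not produce a zero or infinite factor.
        a = max(a, eps)
        w = max(w, eps)
        factors.append(a**alpha / w ** (1.0 - alpha))
    return factors
```

For example, `smoothing_factors([16.0], [4.0], alpha=0.5)` gives a factor of 2.0 for that channel. With `alpha=1.0` the factor equals the activation maximum, pushing all of the range into the weights.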
    if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
        base_config.static_scale = activation_scale
can do the same as SupportsActivationPreScaling, like IsStaticQuantizationConfig
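One way this suggestion could look, sketched with `typing.Protocol`. `IsStaticQuantizationConfig` is the name floated in the comment above; its shape here (a single `static_scale` attribute), along with `ToyStaticConfig`, `ToyDynamicConfig`, and `apply_activation_scale`, are assumptions for illustration and not torchao's API.

```python
from dataclasses import dataclass
from typing import Any, Optional, Protocol, runtime_checkable


@runtime_checkable
class IsStaticQuantizationConfig(Protocol):
    """Hypothetical protocol: any config that accepts a calibrated activation scale."""

    static_scale: Optional[Any]


@dataclass
class ToyStaticConfig:
    # Satisfies the protocol structurally; no inheritance needed.
    static_scale: Optional[Any] = None


@dataclass
class ToyDynamicConfig:
    # No static_scale attribute, so it does not satisfy the protocol.
    group_size: int = 32


def apply_activation_scale(base_config: Any, activation_scale: Any) -> bool:
    """Store the calibrated scale on configs that support static quantization."""
    if isinstance(base_config, IsStaticQuantizationConfig):
        base_config.static_scale = activation_scale
        return True
    return False
```

The convert step would then check the protocol instead of naming `Int8StaticActivationInt8WeightConfig` directly, so other static configs could opt in by exposing the same attribute.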
* Int8Tensor migration
* ruff fixes
* add init
* fix ruff again
* update
* wip
* undo update tests
* fix ruff
* fix varname
* fix typing
* add tests
* fix dtype
* fix ci
* address granularity cr
* update _choose_quant_func_and_quantize_tensor
* make block size required attribute
* made dtype required as well
* address nits
* skip per tensor weight only test for now
* add static quant
* add static quant
* update
* static quant working eager + compile
* remove file
* added asserts
* undo smoothquant change
* fix return
* got smoothquant + int8 static working
* generalized smoothquat code
* free tests
* fix static scale check
* update
* address cr feedback
* Hook up static quant workflow to prototype smoothquant API
* fix ruff
* fix test to use threshold for sqnr
Test Plan: