🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2576
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEVs
There are 1 currently active SEVs. If your PR is affected, please view them below:
✅ No Failures
As of commit 6bdd3f6 with merge base 2eb4f97.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
| quant_min=quant_min, | ||
| quant_max=quant_max, | ||
| ) | ||
| data, scale, zero_point = _layout.post_process( |
IIUC, the zero_point isn't some fixed fp value, but can vary slightly based on the ranges.
So what the CPU kernel can support is an fp32 scale times an int8 LUT value.
So if we want the grid {-3.5, -1.5, 1.5, 3.5}, we instead use the LUT = {-7, -3, 3, 7} with s_table = 0.5.
In addition to having this LUT for the whole tensor, we can have an FP32 scale s at a per-group granularity. Dequantization is then:
w_dequantized = s * s_table * LUT[idx]
It's not clear to me that the affine scheme here is representable in that way. IIUC, you have:
w_dequantized = s * (qval - z)
So qval - z could define the LUT, but it looks like we'd need a different LUT for each group, because z changes every group_size values? Is that right?
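The LUT dequantization described in this comment can be sketched as follows. This is a plain-Python illustration, not the CPU kernel or any torchao API; the names `dequantize_lut`, `group_scales`, and `group_size` are hypothetical, and the numbers reuse the example grid from the comment.

```python
# Sketch only: plain-Python stand-in for w_dequantized = s * s_table * LUT[idx].
# `dequantize_lut`, `group_scales`, and `group_size` are illustrative names,
# not part of torchao or the CPU kernel interface.

def dequantize_lut(idx, lut, s_table, group_scales, group_size):
    """Dequantize LUT indices using one fp scale per group of weights."""
    out = []
    for i, k in enumerate(idx):
        s = group_scales[i // group_size]  # per-group fp32 scale
        out.append(s * s_table * lut[k])   # shared table, shared s_table
    return out

lut = [-7, -3, 3, 7]        # integer table shared by the whole tensor
s_table = 0.5               # maps the table onto {-3.5, -1.5, 1.5, 3.5}
idx = [0, 1, 2, 3, 3, 0]    # 2-bit indices into the table
group_scales = [1.0, 0.25]  # one scale per group of 3 weights

w = dequantize_lut(idx, lut, s_table, group_scales, group_size=3)
print(w)
```

The point of the sketch is that only `group_scales` varies per group; the table and `s_table` are fixed for the tensor, which is why a per-group zero point would not fit this format.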
Thanks for the comments @metascroy!
Just to clarify from our chat earlier, zero_point=-0.5 is the same across all groups. (I flipped the sign since it's standard to add zero_point during quantization.)
w_quantized = torch.round(x / s + zero_point)
w_dequantized = s * (w_quantized - zero_point)
For the 2-bit case, we set s so that x / s is restricted to range [-1.5, 1.5]. Since zero_point=-0.5, w_quantized lies in the grid {-2, -1, 0, 1}.
It seems like we don't need to use LUT format since this is well-supported by an affine scheme. Maybe it would be worth supporting for latency comparisons though (and to avoid float zero_point).
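As a rough check of the 2-bit case described above, here is a scalar sketch in plain Python rather than torch. The explicit clamp to [-1.5, 1.5] is an assumption standing in for "we set s so that x / s is restricted" to that range, and the function names are illustrative. Python's `round` rounds halves to even, as `torch.round` does.

```python
# Sketch only: scalar version of the affine scheme with a float zero_point.
# The clamp is an assumption that models x / s being restricted to
# [quant_min, quant_max]; it is not taken from the PR's code.

def quantize(x, s, zero_point=-0.5, quant_min=-1.5, quant_max=1.5):
    ratio = min(max(x / s, quant_min), quant_max)
    return round(ratio + zero_point)  # half-to-even, like torch.round

def dequantize(q, s, zero_point=-0.5):
    return s * (q - zero_point)

s = 0.5
xs = [-1.0, -0.3, 0.1, 0.6, 2.0]
qs = [quantize(x, s) for x in xs]
ws = [dequantize(q, s) for q in qs]
print(qs)
print(ws)
```

Every quantized value lands in {-2, -1, 0, 1}, and the dequantized values sit on the half-integer grid scaled by `s`.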
It is well supported by an affine scheme where zero_point is a float, but we do not have CPU kernel support for this.
But if zero_point is always -0.5, then w_quantized - zero_point is just some value in [-1.5, 1.5], and this could define an LUT, so I think we can hook into the kernel in that way.
We just need the LUT to be integer, so we can define the LUT as [-3, -1, 1, 3] and then divide the scales in half.
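The translation from the affine scheme to an integer LUT can be checked with a small sketch. It assumes the 2-bit grid {-2, -1, 0, 1} and zero_point=-0.5 from the discussion above; the helper names are hypothetical and this is not the kernel's data-preparation code.

```python
# Sketch only: an integer LUT with halved scales reproduces the affine
# dequantization s * (q - zero_point) when zero_point is fixed at -0.5.

zero_point = -0.5
qvals = [-2, -1, 0, 1]                            # 2-bit quantized grid
lut = [int(2 * (q - zero_point)) for q in qvals]  # doubled so entries are ints

def dequantize_affine(q, s):
    return s * (q - zero_point)

def dequantize_via_lut(q, s):
    idx = q - qvals[0]        # shift q into a 0-based table index
    return (s / 2) * lut[idx] # halved scale compensates for the doubled table

print(lut)
for s in (0.5, 0.3, 1.7):
    for q in qvals:
        assert dequantize_affine(q, s) == dequantize_via_lut(q, s)
```

Because the same LUT works for every group, only the per-group scales need to be halved.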
| compare_parq_convert(model, m_ref, optimizer, config) | ||
| class TestStretchedUnifTorchaoQuantizer(common_utils.TestCase): |
New test case that ensures equivalence between PARQ's original UnifQuantizer implementation and the new StretchedUnifTorchaoQuantizer.
| q_abs = input_float.abs() | ||
| max_val = torch.minimum( | ||
| b * q_abs.mean(dim=reduction_dims, keepdim=True), | ||
| torch.amax(q_abs, dim=reduction_dims, keepdim=True), | ||
| ).clamp_(min=eps) | ||
| scale = max_val / quant_max | ||
| scale = scale.to(dtype=scale_dtype, device=input_float.device) | ||
| zero_point = torch.full_like(scale, -0.5, dtype=zero_point_dtype) |
Here's the logic for initializing the scale based on multiples of per-group absolute value means. I also manually set the zero point to be the same across groups.
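The scale initialization can be sketched for a single group in plain Python. The diff calls the multiplier `b`; it is named `multiplier` here because its definition is not shown in the snippet, and `choose_qparams_stretched` and the default `eps` are hypothetical, not the PR's function signature.

```python
# Sketch only: per-group scale and zero point, mirroring the diff's logic
# max_val = min(b * mean|x|, max|x|).clamp(min=eps); scale = max_val / quant_max.
# `multiplier` stands in for the diff's `b`; the default `eps` is an assumption.

def choose_qparams_stretched(group, multiplier, quant_max, eps=1e-9):
    abs_vals = [abs(v) for v in group]
    mean_abs = sum(abs_vals) / len(abs_vals)
    # Cap the mean-based range at the group's absmax, then guard against zero.
    max_val = max(min(multiplier * mean_abs, max(abs_vals)), eps)
    scale = max_val / quant_max
    zero_point = -0.5  # same for every group
    return scale, zero_point

group = [0.5, -1.0, 0.25, -0.25]
scale, zero_point = choose_qparams_stretched(group, multiplier=3.0, quant_max=1.5)
print(scale, zero_point)
```

In this example the mean-based range (3.0 * 0.5 = 1.5) exceeds the absmax (1.0), so the absmax is used; with a smaller multiplier the mean-based range would set the scale instead.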
metascroy
left a comment
Looks good to me. We can translate the affine scheme to LUT when we prepare the data for the kernels.
* Add StretchedUnifTorchaoQuantizer
* Fix tinygemm test case
* Test equivalence to PARQ UnifQuantizer; custom choose_qparams, quantize, dequantize
* Remove dequantize_stretched_affine
This PR adds a new stretched uniform quantizer for PARQ, which empirically performs well for 2- and 3-bit QAT. Main differences:
* quant_min=-2**(b - 1) + 0.5 and quant_max=2**(b - 1) - 0.5 values
* min_val, max_val are computed by taking a multiple of the mean over absolute values (instead of absmax)

As in #2091, I also compare the resulting PARQ quantized weights with those quantized with torchao's module swap + quantize_ API. To support this, I created a new tensor subclass StretchedAffineQuantizedTensor and config StretchedIntxWeightOnlyConfig to handle floating point quant_min, quant_max, and zero_point values.
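For concreteness, the floating point quant_min and quant_max values for a given bit width `b` can be computed as below. This is a sketch of the formulas in the description, with `stretched_range` and `integer_grid` as hypothetical helper names; the integer grid assumes the zero_point=-0.5 convention from the review discussion.

```python
# Sketch only: the stretched range for b bits and the integer grid it maps to
# after adding zero_point = -0.5 and rounding.

def stretched_range(b):
    quant_min = -(2 ** (b - 1)) + 0.5
    quant_max = 2 ** (b - 1) - 0.5
    return quant_min, quant_max

def integer_grid(b, zero_point=-0.5):
    quant_min, quant_max = stretched_range(b)
    lo = round(quant_min + zero_point)
    hi = round(quant_max + zero_point)
    return list(range(lo, hi + 1))

for b in (2, 3):
    print(b, stretched_range(b), integer_grid(b))
```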