[CPU] Enable DA8W4 on CPU #2128
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2128
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit e3731f7 with merge base 4ebc9c0.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@leslie-fang-intel This PR is updated to use a new layout. Please review again. Thanks. |
Hi @jerryzh168 Could you please review this PR? Thanks. |
leslie-fang-intel
left a comment
Please also describe how we choose different implementations based on the CPU Info.
I have added more details in the description. Thanks. |
Hi @jerryzh168 Could you please review this PR? Thanks. It's changed a lot since your last review. |
Hi @jerryzh168 Could you please review this PR? Thanks. |
@dataclass(frozen=True)
class Int8DynamicActInt4WeightCPULayout(Layout):
it looks like you can just reuse Int4CPULayout
can you move the layout and impl to a separate file?
@register_layout(Int8DynamicActInt4WeightCPULayout)
class DA8W4CPUAQTTensorImpl(Int4CPUAQTTensorImpl):
oh I see, OK if you need a separate Impl then makes sense to have a separate layout
Yes. We need a different impl from A16W4 because the ISAs (AMX and VNNI) require different weight memory formats for computation in BF16 versus INT8. Thanks.
int_data = (int_data + 8).to(torch.uint8)
if scale.dim() == 1:
    scale.unsqueeze_(-1)
scale = scale.to(torch.float)
if zero_point.dim() == 1:
    zero_point.unsqueeze_(-1)
zero_point = zero_point.to(torch.int8) + 8
can you configure dtypes of int_data, scale, zero_point and shapes in the call to to_affine_quantized_intx?
Thanks for the suggestion. I have improved this part.
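The +8 offset in the snippet above shifts signed int4 values from [-8, 7] into the unsigned range [0, 15]. Adding the same offset to the zero point leaves the dequantized value unchanged. A minimal pure-Python sketch of that arithmetic follows; the helper name `dequant` and the sample `scale` and `zero_point` values are illustrative assumptions, not code from this PR.

```python
def dequant(q, zero_point, scale):
    """Affine dequantization: (q - zero_point) * scale."""
    return (q - zero_point) * scale


OFFSET = 8          # maps signed int4 [-8, 7] onto unsigned [0, 15]
scale = 0.25        # illustrative value, not taken from the PR
zero_point = -3     # illustrative value, not taken from the PR

signed_vals = list(range(-8, 8))                # every signed int4 value
shifted_vals = [q + OFFSET for q in signed_vals]
shifted_zp = zero_point + OFFSET

# The difference (q - zero_point) is the same before and after the shift,
# so both representations dequantize to identical values.
original = [dequant(q, zero_point, scale) for q in signed_vals]
shifted = [dequant(q, shifted_zp, scale) for q in shifted_vals]
```

Because only the difference between the quantized value and the zero point enters the formula, the shift changes the storage format without changing the numbers the kernel computes with.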
quant_min = -8
quant_max = 7

if isinstance(layout, Int8DynamicActInt4WeightCPULayout):
can this happen in kernel? we have dtype conversions like this:
ao/torchao/dtypes/uintx/plain_layout.py
Line 260 in 2898903
Thanks for the comment. I have moved this to _linear_int8_act_int4_weight_cpu_impl.
* [CPU] enable int8_dynamic_activation_int4_weight with Int4CPULayout
* Fix format issue
* Add Int8DynamicActInt4WeightCPULayout
* remove dispatch for t()
* Add cpp kernel for weight packing and GEMM
* Register ATQ linear dispatch for da8w4 linear
* Fix issues with torch.compile
* Fix DA8W4CPUAQTTensorImpl.get_plain
* Test DA8W4CPUAQTTensorImpl.get_plain in UT
* Skip UT if CPP kernel not built
* Add AVX512_VNNI implementation for small M
* improve performance
* Support symmetric quantization of activation
* Refine code
* Refine code
* Put in a separate file
* Bug fix
* refine code
Summary
This PR enables DA8W4 on CPU.
* Int8DynamicActInt4WeightCPULayout and its implementation
* da8w4_linear_prepack_cpu for weight packing
* da8w4_linear_cpu for A8W4 GEMM

The feature supports symmetric and asymmetric quantization of activation.
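To illustrate the difference between the two activation modes, here is a minimal pure-Python sketch of symmetric and asymmetric int8 quantization of a single activation row. All function names are hypothetical, and the way the scale and zero point are computed here is a common textbook scheme assumed for illustration; it is not necessarily what the kernels in this PR do.

```python
def quantize_symmetric(row, qmax=127):
    """Symmetric int8: zero point fixed at 0, scale taken from max |x|."""
    amax = max(abs(x) for x in row)
    scale = amax / qmax if amax > 0 else 1.0
    q = [max(-qmax, min(qmax, round(x / scale))) for x in row]
    return q, scale, 0


def quantize_asymmetric(row, qmin=-128, qmax=127):
    """Asymmetric int8: scale and zero point taken from the [min, max] range."""
    # Extend the range to include 0 so that 0.0 is exactly representable.
    lo, hi = min(min(row), 0.0), max(max(row), 0.0)
    scale = (hi - lo) / (qmax - qmin) if hi > lo else 1.0
    zero_point = round(qmin - lo / scale)
    q = [max(qmin, min(qmax, round(x / scale) + zero_point)) for x in row]
    return q, scale, zero_point


def dequantize(q, scale, zero_point):
    """Map quantized integers back to floats."""
    return [(v - zero_point) * scale for v in q]


row = [0.0, 0.5, 1.0, 2.0]  # one token's activations, all non-negative
q_sym, s_sym, zp_sym = quantize_symmetric(row)
q_asym, s_asym, zp_asym = quantize_asymmetric(row)
```

For a non-negative row like this one, the asymmetric scheme spreads all 256 levels over [0, 2] while the symmetric scheme spends half of its range on negative values that never occur, so the asymmetric scale is smaller. The symmetric scheme in turn needs no zero-point correction.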
The ops and kernels won't be available unless USE_CPP_KERNELS=1 is set on Linux with an x86 CPU that supports AVX512. To get the best performance, one needs a CPU with AMX support.
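As a rough way to check which of these hardware tiers a machine falls into, the sketch below parses the `flags` line of Linux `/proc/cpuinfo` text. The flag names (`avx512f`, `avx512_vnni`, `amx_int8`) are the ones Linux reports; the function names and the tier labels are hypothetical, and the real dispatch happens inside the C++ kernels, which this sketch does not reproduce.

```python
def cpu_flags(cpuinfo_text):
    """Return the feature flags from the first 'flags' line of /proc/cpuinfo text."""
    for line in cpuinfo_text.splitlines():
        if line.startswith("flags"):
            return set(line.split(":", 1)[1].split())
    return set()


def describe_support(flags):
    """Classify a flag set into hypothetical tiers (labels are illustrative only)."""
    if "avx512f" not in flags:
        return "unavailable"      # AVX512 is the stated minimum requirement
    if "amx_int8" in flags:
        return "avx512+amx"       # AMX is stated to give the best performance
    if "avx512_vnni" in flags:
        return "avx512+vnni"
    return "avx512"


# A shortened, made-up sample of /proc/cpuinfo content for demonstration.
sample = "processor\t: 0\nflags\t\t: fpu sse2 avx2 avx512f avx512_vnni amx_tile amx_int8\n"
tier = describe_support(cpu_flags(sample))
```

On a real Linux machine the same functions can be applied to the contents of `/proc/cpuinfo`, for example `describe_support(cpu_flags(open("/proc/cpuinfo").read()))`.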
Implementation details
at::cpublas brgemm utilities from PyTorch core if available.
Usage
Test plan