Add torchao int8 da8w8 sym act sym wgt linear pattern for CPU #141851
sanchitintel wants to merge 13 commits into gh/sanchitintel/2/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141851
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures)
As of commit b210432 with merge base 5c2584a. UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and have been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @leslie-fang-intel, can you please review this PR? Thanks!
),
KeywordArg("b"),
aten.reshape.default,
KeywordArg("a"),
May I know why this pattern must have a reshape node at the activation? It seems we assume the input is 3D.
It's not 3D in torchao, at least for da8w8.
convert_element_type_2: "i8[32, 32][32, 1]cpu" = torch.ops.prims.convert_element_type.default(clamp_max_1, torch.int8); clamp_max_1 = None
view_15: "i8[32, 32][32, 1]cpu" = torch.ops.aten.reshape.default(convert_element_type_2, [-1, 32]); convert_element_type_2 = None
permute_1: "i8[32, 32][1, 32]cpu" = torch.ops.aten.permute.default(arg5_1, [1, 0]); arg5_1 = None
_int_mm_1: "i32[32, 32][32, 1]cpu" = torch.ops.aten._int_mm.default(view_15, permute_1); view_15 = permute_1 = None
Then why do we register the pattern with a reshape at the activation input but without a reshape at the output?
That's just how the torchao pattern is with the simple int8_dynamic_activation_int8_weight UTs I tested; they don't have a reshape at the output.
def forward(self, arg0_1: "i8[32, 64][64, 1]cpu", arg1_1: "f32[32][1]cpu", arg2_1: "i64[32][1]cpu", arg3_1: "f32[32][1]cpu", arg4_1: "f32[32, 64][64, 1]cpu", arg5_1: "i8[32, 32][32, 1]cpu", arg6_1: "f32[32][1]cpu", arg7_1: "i64[32][1]cpu", arg8_1: "f32[32][1]cpu"):
# File: /localdisk/sanchitj/smoothquant/torch/_dynamo/external_utils.py:31 in inner, code: return fn(*args, **kwargs)
full_2: "f32[32][1]cpu" = torch.ops.aten.full.default([32], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); full_2 = None
full_default_2: "f32[32, 1][1, 1]cpu" = torch.ops.aten.full.default([32, 1], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); full_default_2 = None
full_5: "f32[32][1]cpu" = torch.ops.aten.full.default([32], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); full_5 = None
full_default_5: "f32[32, 1][1, 1]cpu" = torch.ops.aten.full.default([32, 1], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False); full_default_5 = None
amin: "f32[32][1]cpu" = torch.ops.aten.amin.default(arg4_1, [1])
full_default: "f32[32][1]cpu" = torch.ops.aten.full.default([32], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
minimum: "f32[32][1]cpu" = torch.ops.aten.minimum.default(amin, full_default); amin = full_default = None
neg: "f32[32][1]cpu" = torch.ops.aten.neg.default(minimum); minimum = None
amax: "f32[32][1]cpu" = torch.ops.aten.amax.default(arg4_1, [1])
full_default_1: "f32[32][1]cpu" = torch.ops.aten.full.default([32], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
maximum: "f32[32][1]cpu" = torch.ops.aten.maximum.default(amax, full_default_1); amax = full_default_1 = None
maximum_1: "f32[32][1]cpu" = torch.ops.aten.maximum.default(neg, maximum); neg = maximum = None
div: "f32[32][1]cpu" = torch.ops.aten.div.Tensor(maximum_1, 127.0); maximum_1 = None
clamp_min: "f32[32][1]cpu" = torch.ops.aten.clamp_min.default(div, 1e-05); div = None
view_2: "f32[32, 1][1, 1]cpu" = torch.ops.aten.reshape.default(clamp_min, [32, 1])
reciprocal: "f32[32, 1][1, 1]cpu" = torch.ops.aten.reciprocal.default(view_2); view_2 = None
mul: "f32[32, 1][1, 1]cpu" = torch.ops.aten.mul.Tensor(reciprocal, 1.0); reciprocal = None
mul_1: "f32[32, 64][64, 1]cpu" = torch.ops.aten.mul.Tensor(arg4_1, mul); arg4_1 = mul = None
round_1: "f32[32, 64][64, 1]cpu" = torch.ops.aten.round.default(mul_1); mul_1 = None
clamp_min_1: "f32[32, 64][64, 1]cpu" = torch.ops.aten.clamp_min.default(round_1, -127); round_1 = None
clamp_max: "f32[32, 64][64, 1]cpu" = torch.ops.aten.clamp_max.default(clamp_min_1, 127); clamp_min_1 = None
convert_element_type: "i8[32, 64][64, 1]cpu" = torch.ops.prims.convert_element_type.default(clamp_max, torch.int8); clamp_max = None
view_5: "i8[32, 64][64, 1]cpu" = torch.ops.aten.reshape.default(convert_element_type, [-1, 64]); convert_element_type = None
permute: "i8[64, 32][1, 64]cpu" = torch.ops.aten.permute.default(arg0_1, [1, 0]); arg0_1 = None
_int_mm: "i32[32, 32][32, 1]cpu" = torch.ops.aten._int_mm.default(view_5, permute); view_5 = permute = None
convert_element_type_1: "f32[32, 32][32, 1]cpu" = torch.ops.prims.convert_element_type.default(_int_mm, torch.float32); _int_mm = None
view_6: "f32[32, 1][1, 1]cpu" = torch.ops.aten.reshape.default(clamp_min, [-1, 1]); clamp_min = None
expand: "f32[32, 32][1, 0]cpu" = torch.ops.aten.expand.default(view_6, [32, 32]); view_6 = None
mul_2: "f32[32, 32][32, 1]cpu" = torch.ops.aten.mul.Tensor(convert_element_type_1, expand); convert_element_type_1 = expand = None
mul_3: "f32[32, 32][32, 1]cpu" = torch.ops.aten.mul.Tensor(mul_2, arg1_1); mul_2 = arg1_1 = None
add_1: "f32[32, 32][32, 1]cpu" = torch.ops.aten.add.Tensor(mul_3, arg3_1); mul_3 = arg3_1 = None
relu: "f32[32, 32][32, 1]cpu" = torch.ops.aten.relu.default(add_1); add_1 = None
amin_1: "f32[32][1]cpu" = torch.ops.aten.amin.default(relu, [1])
full_default_3: "f32[32][1]cpu" = torch.ops.aten.full.default([32], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
minimum_1: "f32[32][1]cpu" = torch.ops.aten.minimum.default(amin_1, full_default_3); amin_1 = full_default_3 = None
neg_1: "f32[32][1]cpu" = torch.ops.aten.neg.default(minimum_1); minimum_1 = None
amax_1: "f32[32][1]cpu" = torch.ops.aten.amax.default(relu, [1])
full_default_4: "f32[32][1]cpu" = torch.ops.aten.full.default([32], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
maximum_2: "f32[32][1]cpu" = torch.ops.aten.maximum.default(amax_1, full_default_4); amax_1 = full_default_4 = None
maximum_3: "f32[32][1]cpu" = torch.ops.aten.maximum.default(neg_1, maximum_2); neg_1 = maximum_2 = None
div_1: "f32[32][1]cpu" = torch.ops.aten.div.Tensor(maximum_3, 127.0); maximum_3 = None
clamp_min_2: "f32[32][1]cpu" = torch.ops.aten.clamp_min.default(div_1, 1e-05); div_1 = None
view_12: "f32[32, 1][1, 1]cpu" = torch.ops.aten.reshape.default(clamp_min_2, [32, 1])
reciprocal_1: "f32[32, 1][1, 1]cpu" = torch.ops.aten.reciprocal.default(view_12); view_12 = None
mul_4: "f32[32, 1][1, 1]cpu" = torch.ops.aten.mul.Tensor(reciprocal_1, 1.0); reciprocal_1 = None
mul_5: "f32[32, 32][32, 1]cpu" = torch.ops.aten.mul.Tensor(relu, mul_4); relu = mul_4 = None
round_2: "f32[32, 32][32, 1]cpu" = torch.ops.aten.round.default(mul_5); mul_5 = None
clamp_min_3: "f32[32, 32][32, 1]cpu" = torch.ops.aten.clamp_min.default(round_2, -127); round_2 = None
clamp_max_1: "f32[32, 32][32, 1]cpu" = torch.ops.aten.clamp_max.default(clamp_min_3, 127); clamp_min_3 = None
convert_element_type_2: "i8[32, 32][32, 1]cpu" = torch.ops.prims.convert_element_type.default(clamp_max_1, torch.int8); clamp_max_1 = None
view_15: "i8[32, 32][32, 1]cpu" = torch.ops.aten.reshape.default(convert_element_type_2, [-1, 32]); convert_element_type_2 = None
permute_1: "i8[32, 32][1, 32]cpu" = torch.ops.aten.permute.default(arg5_1, [1, 0]); arg5_1 = None
_int_mm_1: "i32[32, 32][32, 1]cpu" = torch.ops.aten._int_mm.default(view_15, permute_1); view_15 = permute_1 = None
convert_element_type_3: "f32[32, 32][32, 1]cpu" = torch.ops.prims.convert_element_type.default(_int_mm_1, torch.float32); _int_mm_1 = None
view_16: "f32[32, 1][1, 1]cpu" = torch.ops.aten.reshape.default(clamp_min_2, [-1, 1]); clamp_min_2 = None
expand_1: "f32[32, 32][1, 0]cpu" = torch.ops.aten.expand.default(view_16, [32, 32]); view_16 = None
mul_6: "f32[32, 32][32, 1]cpu" = torch.ops.aten.mul.Tensor(convert_element_type_3, expand_1); convert_element_type_3 = expand_1 = None
mul_7: "f32[32, 32][32, 1]cpu" = torch.ops.aten.mul.Tensor(mul_6, arg6_1); mul_6 = arg6_1 = None
add_3: "f32[32, 32][32, 1]cpu" = torch.ops.aten.add.Tensor(mul_7, arg8_1); mul_7 = arg8_1 = None
return (add_3,)
Do you know why this API int8_dynamic_activation_int8_weight only inserts a reshape at the input but not at the output? This pattern looks strange.
I feel this behavior is strange.
@leslie-fang-intel, yes, there's a redundant reshape in the torchao pattern, but shouldn't we simply match this pattern in stock PyTorch for now? Can you please elaborate on why the origin of the redundant reshape matters in the context of this PR?
It's possible that some change in torchao may eliminate the redundant reshape, in which case the torchao int8_dynamic_activation_int8_weight GEMM pattern would change in the future, but why is that a blocker at this point, given that it depends on torchao behavior?
Does it work if we match the sub-graph without the reshape on the activation, assuming the activation is 2D?
Does it work if we match the sub-graph without the reshape on the activation, assuming the activation is 2D?
@jgong5, no, it does not, because this pattern is specifically being added for torchao's int8_dynamic_activation_int8_weight linear pattern.
In fact, torchao's SmoothQuant exhibits the same behavior. I have listed the torchao UTs in the PR description.
@jgong5, please clarify whether you'd like the pattern without the reshape on the activation to be supported (it does not currently correspond to the torchao pattern being added, but that may change in the future). Thanks!
@jgong5 @leslie-fang-intel, as you advised offline, I added support for the smaller pattern that doesn't contain a reshape for the activation. Thanks!
As you suspected, the redundant reshape was optimized out and is not present in the generated code:

Summary
Extends #139595 with an Inductor pattern-matching pattern covering torchao in this scenario (inference-only).
The pattern that's matched is
torch._int_mm -> convert to FP32/BF16 -> [optional expand for activation scale] -> mul -> mul. In practice, it also matches the SmoothQuant int8 quantized linear pattern if its output is not reshaped.
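As a toy illustration (a hypothetical NumPy sketch, not the PR's pattern-matcher code), the matched op chain takes the int32 output of the int8 matmul, converts it to float, broadcasts the per-row activation scale over it (the optional expand), and then applies the two scale multiplications:

```python
import numpy as np

# Stand-ins for the tensors flowing through the matched pattern (shapes are illustrative).
acc = np.arange(12, dtype=np.int32).reshape(3, 4)      # aten._int_mm output, i32[M, N]
a_scale = np.array([[0.5], [1.0], [2.0]], np.float32)  # per-row activation scale, f32[M, 1]
w_scale = np.array([0.1, 0.2, 0.3, 0.4], np.float32)   # per-channel weight scale, f32[N]

out = acc.astype(np.float32)                      # convert_element_type: i32 -> f32
out = out * np.broadcast_to(a_scale, acc.shape)   # expand + first mul (activation scale)
out = out * w_scale                               # second mul (weight scale)
print(out[0])
```

The two muls commute, which is why the replacement is free to apply the weight scale inside the fused op and the activation scale outside it.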
More details
oneDNN int8 matmul supports applying a per-channel weight scale but not a vector activation scale; the latter could be applied as a post-op, but that is currently unsupported in ATen. Bias addition (which could be supported with an add post-op) is also unfused.
The fusion pattern used in this PR is
torch._int_mm -> convert to FP32/BF16 -> mul, which will be replaced by the oneDNN qlinear op. The speedup over eager mode is due to two reasons:
In the future, the whole pattern (including application of the activation scale, which would be a mul post-op) plus bias could be fused if corresponding support were enabled in ATen.
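To make the end-to-end computation concrete, the following is a hypothetical NumPy sketch (not the PR's code) of one da8w8 linear layer as it appears in the graph dump above: dynamic per-row symmetric int8 quantization of the activation, an int8 x int8 -> int32 matmul standing in for aten._int_mm, then dequantization with the activation and weight scales plus bias. The function name `da8w8_linear` and the error bound are assumptions for illustration only.

```python
import numpy as np

def da8w8_linear(x, w_q, w_scale, bias, eps=1e-5):
    # Per-row symmetric scale, mirroring the graph: max(-min(amin, 0), max(amax, 0)) / 127,
    # clamped from below (the clamp_min with 1e-05).
    abs_max = np.maximum(-np.minimum(x.min(axis=1), 0.0), np.maximum(x.max(axis=1), 0.0))
    a_scale = np.maximum(abs_max / 127.0, eps)                       # f32[M]
    # Quantize the activation: divide by scale, round, clamp to [-127, 127], cast to int8.
    x_q = np.clip(np.round(x / a_scale[:, None]), -127, 127).astype(np.int8)
    # int8 matmul accumulating in int32 (stands in for aten._int_mm).
    acc = x_q.astype(np.int32) @ w_q.astype(np.int32).T              # i32[M, N]
    # Dequantize: activation scale (the expand + mul), per-channel weight scale, bias.
    return acc.astype(np.float32) * a_scale[:, None] * w_scale[None, :] + bias

rng = np.random.default_rng(0)
x = rng.standard_normal((32, 64)).astype(np.float32)
w = rng.standard_normal((32, 64)).astype(np.float32)
# Statically quantized weights with a per-output-channel symmetric scale.
w_scale = np.abs(w).max(axis=1) / 127.0
w_q = np.clip(np.round(w / w_scale[:, None]), -127, 127).astype(np.int8)
bias = rng.standard_normal(32).astype(np.float32)

y = da8w8_linear(x, w_q, w_scale, bias)
# The quantized result should closely approximate the float linear layer.
print(np.max(np.abs(y - (x @ w.T + bias))))
```

Per the limitation described above, a oneDNN qlinear replacement can absorb the int8 matmul, the float conversion, and the per-channel weight-scale mul, while the vector activation-scale mul and the bias add remain as separate ops for now.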
Verification
Added UT in this PR
Corresponding torchao UTs
int8 SmoothQuant legacy API -
TORCHINDUCTOR_FREEZING=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor" python test/integration/test_integration.py -v -k test_non_dynamically_quantizable_linear
The difference from [Inductor][CPU] Fuse SmoothQuant int8 linear pattern #139595 is that there are no reshapes of the linear output in this pattern.
int8 da8w8 - symmetrically quantized activation (dynamically) & statically quantized weights -
TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor" TORCHINDUCTOR_FREEZING=1 python test/integration/test_integration.py -v -k test_int8_dynamic_quant_subclass_api_0_cpu
ghstack info
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov