
Add torchao int8 da8w8 sym act sym wgt linear pattern for CPU #141851

Closed
sanchitintel wants to merge 13 commits into gh/sanchitintel/2/base from gh/sanchitintel/2/head

Conversation


@sanchitintel sanchitintel commented Dec 2, 2024

Summary

Extends #139595 by adding an Inductor pattern-matching pattern for torchao covering this (inference-only) scenario:

  • Activations: int8 quantized symmetrically, per token (dynamically, at run-time).
  • Weights: int8 quantized symmetrically, per channel, statically (so their scales are constant during inference).

The pattern that's matched is torch._int_mm -> convert to FP32/BF16 -> [optional expand of the activation scale] -> mul -> mul.

In practice, it also matches the SmoothQuant int8 quantized linear pattern when its output is not reshaped.
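For reference, the end-to-end math this matched pattern computes can be sketched in plain Python. The helper names below are illustrative only (not torchao or ATen APIs), and torch._int_mm is emulated with a plain integer matmul:

```python
# Pure-Python sketch of the da8w8 sym-act/sym-wgt linear computation.
# All helper names are hypothetical; torch._int_mm is emulated below.

def quant_rows_int8(x):
    # Symmetric per-row (per-token) int8 quantization; returns (q, scales).
    qs, scales = [], []
    for row in x:
        s = max(abs(v) for v in row) / 127.0 or 1e-5  # guard all-zero rows
        qs.append([max(-127, min(127, round(v / s))) for v in row])
        scales.append(s)
    return qs, scales

def quant_cols_int8(w):
    # Symmetric per-column (per-output-channel) int8 quantization.
    qcols, scales = [], []
    for col in zip(*w):
        s = max(abs(v) for v in col) / 127.0 or 1e-5
        qcols.append([max(-127, min(127, round(v / s))) for v in col])
        scales.append(s)
    return [list(r) for r in zip(*qcols)], scales

def int_mm(a, b):
    # int8 x int8 -> int32 matmul, standing in for torch._int_mm.
    return [[sum(x * y for x, y in zip(ra, cb)) for cb in zip(*b)] for ra in a]

def da8w8_linear(x, w):
    xq, xs = quant_rows_int8(x)  # dynamic activation quantization
    wq, ws = quant_cols_int8(w)  # weight quantization (static in practice)
    acc = int_mm(xq, wq)         # the _int_mm node of the pattern
    # convert to FP32 -> mul by (expanded) activation scale -> mul by weight scale
    return [[float(acc[i][j]) * xs[i] * ws[j] for j in range(len(ws))]
            for i in range(len(xs))]
```

With per-token activation scales xs and per-channel weight scales ws, the two trailing muls of the pattern correspond to the xs[i] and ws[j] factors, and the optional expand broadcasts xs across the output columns.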

More details

oneDNN's int8 matmul supports applying a per-channel weight scale, but not a per-token (vector) activation scale; the latter could be applied as a post-op, but that is currently unsupported in ATen. Bias addition (which could be fused with an add post-op) is also left unfused.

The fusion pattern used in this PR is torch._int_mm -> convert to FP32/BF16 -> mul, which is replaced by the oneDNN qlinear op.

The speedup over eager mode comes from two sources:

  1. Fusion of the int8 x int8 -> int32 GEMM, the conversion to FP32/BF16, and the application of the weight scale (in the BF16 case, many intermediate conversions are also avoided).
  2. The weight is pre-packed and cached by Inductor, so a reorder is avoided at run-time.

In the future, the whole pattern (including the application of the activation scale, which would be a mul post-op) plus the bias could be fused, once the corresponding support is enabled in ATen.
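This split between the fused and unfused parts is sound because the per-channel weight scale is a column-wise multiply on the int32 GEMM output while the per-token activation scale is a row-wise multiply, and the two commute. A minimal sketch in plain Python (all values illustrative):

```python
# Row scales (per-token activation) and column scales (per-channel weight)
# commute when applied to the int32 accumulator, so fusing only the weight
# scale into qlinear and leaving the activation-scale mul outside is valid.

def scale_rows(m, row_scales):
    return [[v * s for v in row] for row, s in zip(m, row_scales)]

def scale_cols(m, col_scales):
    return [[v * s for v, s in zip(row, col_scales)] for row in m]

acc = [[10, -20], [30, 40]]  # stand-in for the int32 output of _int_mm
a_scale = [0.5, 0.25]        # per-token (row) activation scales
w_scale = [0.125, 2.0]       # per-channel (column) weight scales

fused_then_mul = scale_rows(scale_cols(acc, w_scale), a_scale)  # qlinear, then mul
mul_then_fused = scale_cols(scale_rows(acc, a_scale), w_scale)  # other order
assert fused_then_mul == mul_then_fused
```

The scales here are powers of two so the equality is exact; with arbitrary floats the two orders agree up to rounding.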

Verification

Added a UT in this PR:

python test/inductor/test_mkldnn_pattern_matcher.py -v -k test_da8w8_sym_act_sym_wgt_with_int_mm

Corresponding torchao UTs

  1. int8 Smoothquant legacy API - TORCHINDUCTOR_FREEZING=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor" python test/integration/test_integration.py -v -k test_non_dynamically_quantizable_linear.
    The difference from [Inductor][CPU] Fuse SmoothQuant int8 linear pattern #139595 is that there are no reshapes of the linear output in this pattern.

  2. int8 da8w8 - symmetrically quantized activation (dynamically) & statically quantized weights - TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor" TORCHINDUCTOR_FREEZING=1 python test/integration/test_integration.py -v -k test_int8_dynamic_quant_subclass_api_0_cpu

ghstack info

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

[ghstack-poisoned]

pytorch-bot bot commented Dec 2, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141851

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit b210432 with merge base 5c2584a:

UNSTABLE - The following jobs failed, likely due to flakiness present on trunk, and have been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@sanchitintel sanchitintel changed the title Add torchao da8w8 sym act sym wgt pattern for CPU Add torchao int8 da8w8 sym act sym wgt pattern for CPU Dec 2, 2024
sanchitintel added a commit that referenced this pull request Dec 2, 2024
ghstack-source-id: d0317b8
Pull Request resolved: #141851
@sanchitintel sanchitintel changed the title Add torchao int8 da8w8 sym act sym wgt pattern for CPU Add torchao int8 da8w8 sym act sym wgt linear pattern for CPU Dec 2, 2024
sanchitintel added a commit that referenced this pull request Dec 2, 2024
ghstack-source-id: e579714
Pull Request resolved: #141851
@sanchitintel sanchitintel requested a review from jgong5 December 2, 2024 23:45
@sanchitintel (Collaborator Author)

Hi @leslie-fang-intel, can you please review this PR? Thanks!

Review comment thread on an excerpt of the pattern definition:

    ),
    KeywordArg("b"),
    aten.reshape.default,
    KeywordArg("a"),
Collaborator

May I know why this pattern must have a reshape node at the activation? It seems we assume the input is 3D.

Collaborator Author
@sanchitintel sanchitintel Dec 3, 2024

It's not 3D in torchao, at least for da8w8.

convert_element_type_2: "i8[32, 32][32, 1]cpu" = torch.ops.prims.convert_element_type.default(clamp_max_1, torch.int8);  clamp_max_1 = None
view_15: "i8[32, 32][32, 1]cpu" = torch.ops.aten.reshape.default(convert_element_type_2, [-1, 32]);  convert_element_type_2 = None
permute_1: "i8[32, 32][1, 32]cpu" = torch.ops.aten.permute.default(arg5_1, [1, 0]);  arg5_1 = None
_int_mm_1: "i32[32, 32][32, 1]cpu" = torch.ops.aten._int_mm.default(view_15, permute_1);  view_15 = permute_1 = None
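The reshape in this excerpt applies [-1, 32] to a tensor that is already [32, 32], i.e. it is a no-op on 2D input, which is why Inductor can later optimize it away. A tiny illustrative sketch (hypothetical helper, not torch code):

```python
# Illustrative: reshaping a 2D matrix to [-1, ncols] when it already has
# ncols columns leaves both data and shape unchanged.

def reshape_neg1_by(m, ncols):
    flat = [v for row in m for v in row]       # row-major flatten
    assert len(flat) % ncols == 0              # shape must be divisible
    return [flat[i:i + ncols] for i in range(0, len(flat), ncols)]

m = [[1, 2], [3, 4]]
assert reshape_neg1_by(m, 2) == m  # [-1, 2] reshape of a 2x2 is the identity
```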

Collaborator

Then why do we register the pattern with a reshape at the activation input but without a reshape at the output?

Collaborator Author
@sanchitintel sanchitintel Dec 3, 2024

That's just how the torchao pattern is with the simple int8_dynamic_activation_int8_weight UTs I tested - they don't have a reshape at the output.

def forward(self, arg0_1: "i8[32, 64][64, 1]cpu", arg1_1: "f32[32][1]cpu", arg2_1: "i64[32][1]cpu", arg3_1: "f32[32][1]cpu", arg4_1: "f32[32, 64][64, 1]cpu", arg5_1: "i8[32, 32][32, 1]cpu", arg6_1: "f32[32][1]cpu", arg7_1: "i64[32][1]cpu", arg8_1: "f32[32][1]cpu"):
     # File: /localdisk/sanchitj/smoothquant/torch/_dynamo/external_utils.py:31 in inner, code: return fn(*args, **kwargs)
    full_2: "f32[32][1]cpu" = torch.ops.aten.full.default([32], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False);  full_2 = None
    full_default_2: "f32[32, 1][1, 1]cpu" = torch.ops.aten.full.default([32, 1], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False);  full_default_2 = None
    full_5: "f32[32][1]cpu" = torch.ops.aten.full.default([32], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False);  full_5 = None
    full_default_5: "f32[32, 1][1, 1]cpu" = torch.ops.aten.full.default([32, 1], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False);  full_default_5 = None
    amin: "f32[32][1]cpu" = torch.ops.aten.amin.default(arg4_1, [1])
    full_default: "f32[32][1]cpu" = torch.ops.aten.full.default([32], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    minimum: "f32[32][1]cpu" = torch.ops.aten.minimum.default(amin, full_default);  amin = full_default = None
    neg: "f32[32][1]cpu" = torch.ops.aten.neg.default(minimum);  minimum = None
    amax: "f32[32][1]cpu" = torch.ops.aten.amax.default(arg4_1, [1])
    full_default_1: "f32[32][1]cpu" = torch.ops.aten.full.default([32], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    maximum: "f32[32][1]cpu" = torch.ops.aten.maximum.default(amax, full_default_1);  amax = full_default_1 = None
    maximum_1: "f32[32][1]cpu" = torch.ops.aten.maximum.default(neg, maximum);  neg = maximum = None
    div: "f32[32][1]cpu" = torch.ops.aten.div.Tensor(maximum_1, 127.0);  maximum_1 = None
    clamp_min: "f32[32][1]cpu" = torch.ops.aten.clamp_min.default(div, 1e-05);  div = None
    view_2: "f32[32, 1][1, 1]cpu" = torch.ops.aten.reshape.default(clamp_min, [32, 1])
    reciprocal: "f32[32, 1][1, 1]cpu" = torch.ops.aten.reciprocal.default(view_2);  view_2 = None
    mul: "f32[32, 1][1, 1]cpu" = torch.ops.aten.mul.Tensor(reciprocal, 1.0);  reciprocal = None
    mul_1: "f32[32, 64][64, 1]cpu" = torch.ops.aten.mul.Tensor(arg4_1, mul);  arg4_1 = mul = None
    round_1: "f32[32, 64][64, 1]cpu" = torch.ops.aten.round.default(mul_1);  mul_1 = None
    clamp_min_1: "f32[32, 64][64, 1]cpu" = torch.ops.aten.clamp_min.default(round_1, -127);  round_1 = None
    clamp_max: "f32[32, 64][64, 1]cpu" = torch.ops.aten.clamp_max.default(clamp_min_1, 127);  clamp_min_1 = None
    convert_element_type: "i8[32, 64][64, 1]cpu" = torch.ops.prims.convert_element_type.default(clamp_max, torch.int8);  clamp_max = None
    view_5: "i8[32, 64][64, 1]cpu" = torch.ops.aten.reshape.default(convert_element_type, [-1, 64]);  convert_element_type = None
    permute: "i8[64, 32][1, 64]cpu" = torch.ops.aten.permute.default(arg0_1, [1, 0]);  arg0_1 = None
    _int_mm: "i32[32, 32][32, 1]cpu" = torch.ops.aten._int_mm.default(view_5, permute);  view_5 = permute = None
    convert_element_type_1: "f32[32, 32][32, 1]cpu" = torch.ops.prims.convert_element_type.default(_int_mm, torch.float32);  _int_mm = None
    view_6: "f32[32, 1][1, 1]cpu" = torch.ops.aten.reshape.default(clamp_min, [-1, 1]);  clamp_min = None
    expand: "f32[32, 32][1, 0]cpu" = torch.ops.aten.expand.default(view_6, [32, 32]);  view_6 = None
    mul_2: "f32[32, 32][32, 1]cpu" = torch.ops.aten.mul.Tensor(convert_element_type_1, expand);  convert_element_type_1 = expand = None
    mul_3: "f32[32, 32][32, 1]cpu" = torch.ops.aten.mul.Tensor(mul_2, arg1_1);  mul_2 = arg1_1 = None
    add_1: "f32[32, 32][32, 1]cpu" = torch.ops.aten.add.Tensor(mul_3, arg3_1);  mul_3 = arg3_1 = None
    relu: "f32[32, 32][32, 1]cpu" = torch.ops.aten.relu.default(add_1);  add_1 = None
    amin_1: "f32[32][1]cpu" = torch.ops.aten.amin.default(relu, [1])
    full_default_3: "f32[32][1]cpu" = torch.ops.aten.full.default([32], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    minimum_1: "f32[32][1]cpu" = torch.ops.aten.minimum.default(amin_1, full_default_3);  amin_1 = full_default_3 = None
    neg_1: "f32[32][1]cpu" = torch.ops.aten.neg.default(minimum_1);  minimum_1 = None
    amax_1: "f32[32][1]cpu" = torch.ops.aten.amax.default(relu, [1])
    full_default_4: "f32[32][1]cpu" = torch.ops.aten.full.default([32], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    maximum_2: "f32[32][1]cpu" = torch.ops.aten.maximum.default(amax_1, full_default_4);  amax_1 = full_default_4 = None
    maximum_3: "f32[32][1]cpu" = torch.ops.aten.maximum.default(neg_1, maximum_2);  neg_1 = maximum_2 = None
    div_1: "f32[32][1]cpu" = torch.ops.aten.div.Tensor(maximum_3, 127.0);  maximum_3 = None
    clamp_min_2: "f32[32][1]cpu" = torch.ops.aten.clamp_min.default(div_1, 1e-05);  div_1 = None
    view_12: "f32[32, 1][1, 1]cpu" = torch.ops.aten.reshape.default(clamp_min_2, [32, 1])
    reciprocal_1: "f32[32, 1][1, 1]cpu" = torch.ops.aten.reciprocal.default(view_12);  view_12 = None
    mul_4: "f32[32, 1][1, 1]cpu" = torch.ops.aten.mul.Tensor(reciprocal_1, 1.0);  reciprocal_1 = None
    mul_5: "f32[32, 32][32, 1]cpu" = torch.ops.aten.mul.Tensor(relu, mul_4);  relu = mul_4 = None
    round_2: "f32[32, 32][32, 1]cpu" = torch.ops.aten.round.default(mul_5);  mul_5 = None
    clamp_min_3: "f32[32, 32][32, 1]cpu" = torch.ops.aten.clamp_min.default(round_2, -127);  round_2 = None
    clamp_max_1: "f32[32, 32][32, 1]cpu" = torch.ops.aten.clamp_max.default(clamp_min_3, 127);  clamp_min_3 = None
    convert_element_type_2: "i8[32, 32][32, 1]cpu" = torch.ops.prims.convert_element_type.default(clamp_max_1, torch.int8);  clamp_max_1 = None
    view_15: "i8[32, 32][32, 1]cpu" = torch.ops.aten.reshape.default(convert_element_type_2, [-1, 32]);  convert_element_type_2 = None
    permute_1: "i8[32, 32][1, 32]cpu" = torch.ops.aten.permute.default(arg5_1, [1, 0]);  arg5_1 = None
    _int_mm_1: "i32[32, 32][32, 1]cpu" = torch.ops.aten._int_mm.default(view_15, permute_1);  view_15 = permute_1 = None
    convert_element_type_3: "f32[32, 32][32, 1]cpu" = torch.ops.prims.convert_element_type.default(_int_mm_1, torch.float32);  _int_mm_1 = None
    view_16: "f32[32, 1][1, 1]cpu" = torch.ops.aten.reshape.default(clamp_min_2, [-1, 1]);  clamp_min_2 = None
    expand_1: "f32[32, 32][1, 0]cpu" = torch.ops.aten.expand.default(view_16, [32, 32]);  view_16 = None
    mul_6: "f32[32, 32][32, 1]cpu" = torch.ops.aten.mul.Tensor(convert_element_type_3, expand_1);  convert_element_type_3 = expand_1 = None
    mul_7: "f32[32, 32][32, 1]cpu" = torch.ops.aten.mul.Tensor(mul_6, arg6_1);  mul_6 = arg6_1 = None
    add_3: "f32[32, 32][32, 1]cpu" = torch.ops.aten.add.Tensor(mul_7, arg8_1);  mul_7 = arg8_1 = None
    return (add_3,)

Collaborator

Do you know why this API, int8_dynamic_activation_int8_weight, only inserts a reshape at the input but not at the output? This pattern looks strange.

Collaborator Author
@sanchitintel sanchitintel Dec 3, 2024

I feel this behavior is strange, too.

@leslie-fang-intel, yes, there's a redundant reshape in the torchao pattern, but shouldn't we simply match this pattern in stock PyTorch for now? Can you please elaborate on why the origin of the redundant reshape matters in the context of this PR?
It's possible that some change in torchao may eliminate the redundant reshape, in which case the torchao int8_dynamic_activation_int8_weight GEMM pattern would change in the future, but why is that a blocker at this point, given that it depends on torchao behavior?

Collaborator

Does it work if we match the sub-graph without the reshape on the activation, assuming the activation is 2D?

Collaborator Author
@sanchitintel sanchitintel Dec 3, 2024

> Does it work if we match the sub-graph without the reshape on the activation, assuming the activation is 2D?

@jgong5, no, it does not, because this pattern is specifically being added for torchao's int8_dynamic_activation_int8_weight linear pattern.

In fact, torchao's SmoothQuant exhibits the same behavior. I have listed the torchao UTs in the PR description.

Collaborator Author
@sanchitintel sanchitintel Dec 3, 2024

@jgong5, please clarify if you'd like the pattern without the reshape on the activation to be supported as well (it does not currently correspond to the torchao pattern being added, but that may change in the future). Thanks!

Collaborator Author
@sanchitintel sanchitintel Dec 3, 2024

@jgong5 @leslie-fang-intel, as you advised offline, I added support for the smaller pattern that doesn't contain reshape for activation. Thanks!

As you suspected, the redundant reshape was optimized out, and is not present in the codegened output (screenshot omitted).

sanchitintel added a commit that referenced this pull request Dec 3, 2024
ghstack-source-id: 18d7bdf
Pull Request resolved: #141851
sanchitintel added a commit that referenced this pull request Dec 3, 2024
ghstack-source-id: 6c2ae4b
Pull Request resolved: #141851

sanchitintel commented Dec 3, 2024

Closing in favor of #142015, which uses the latest base commit of #139595.

@github-actions github-actions bot deleted the gh/sanchitintel/2/head branch January 3, 2025 02:07