Fake quantization when a convolution layer's padding is asymmetric or pad_w != pad_h #11525

@fengyuentau

Description

Describe the bug
Consider a very simple model with only one convolution layer, one input, and one output. If the convolution layer has asymmetric padding, or padding where pad_w != pad_h, running onnxruntime.quantization.quantize_static produces a fake-quantized model like the following:
[image: fake-quantized model produced by quantize_static]

Urgency
Probably none.

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 20.04.4 LTS
  • ONNX Runtime installed from (source or binary): binary
  • ONNX Runtime version: 1.11.1
  • Python version: 3.9.7
  • Visual Studio version (if applicable): None
  • GCC/Compiler version (if compiling from source): None
  • CUDA/cuDNN version: None
  • GPU model and memory: None

To Reproduce

 # conv with padding of pad_w != pad_h
 import numpy as np
 import torch
 import torch.nn as nn
 import onnx  # version 1.11.1
 import onnxruntime as rt
 from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType

 # torch.autograd.Variable is deprecated; a plain tensor works for export
 dummy_input = torch.randn(1, 3, 10, 10)
 conv = nn.Conv2d(3, 5, kernel_size=3, stride=2, padding=(2, 1))  # pad_h=2, pad_w=1
 torch.onnx.export(conv, dummy_input, 'conv.onnx', export_params=True, opset_version=12)

 class DataReader(CalibrationDataReader):
     def __init__(self, model_path, batchsize=5):
         sess = rt.InferenceSession(model_path, None)
         input_name = sess.get_inputs()[0].name
         input_shape = sess.get_inputs()[0].shape
         # random calibration data matching the model's input shape
         calib_data = np.random.uniform(-1, 1, size=[batchsize] + input_shape[1:]).astype("float32")
         self.enum_data_dicts = iter([{input_name: np.expand_dims(x, axis=0)} for x in calib_data])

     def get_next(self):
         return next(self.enum_data_dicts, None)

 dr = DataReader('conv.onnx')
 quantize_static('conv.onnx', 'conv-quant.onnx', dr, per_channel=False,
                 activation_type=QuantType.QInt8, weight_type=QuantType.QInt8)

Or you can replace 'conv.onnx' with a model exported with asymmetric padding to try quantizing an asymmetric-padded conv layer.

Expected behavior
A quantized model (as generated by Neural Compressor) with a structure like the following:
[image: expected quantized model structure]

Labels: quantization (issues related to quantization)