|
7 | 7 | #include "caffe2/core/flags.h" |
8 | 8 | #include "caffe2/core/tensor_int8.h" |
9 | 9 | #include "caffe2/operators/fc_inference.h" |
| 10 | +#include "caffe2/quantization/server/int8_gen_quant_params.h" |
10 | 11 | #include "caffe2/utils/cpuid.h" |
11 | 12 | #include "fbgemm_pack_matrix_cache.h" |
12 | 13 | #include "fbgemm_pack_op.h" |
@@ -871,17 +872,16 @@ bool FullyConnectedDNNLowPOp<T, ReluFused>::GetQuantizationParameters_() { |
871 | 872 | #endif |
872 | 873 |
|
873 | 874 | if (!dequantize_output_ && !requantization_param_selected_) { |
874 | | - CAFFE_ENFORCE(InputSize() == 3 || InputSize() == 5); |
875 | | - if (InputSize() == 5) { |
876 | | - CAFFE_ENFORCE(Input(3).template IsType<float>()); |
877 | | - CAFFE_ENFORCE(Input(4).template IsType<int>()); |
878 | | - |
879 | | - const auto& in_3 = Input(3); |
880 | | - CAFFE_ENFORCE_EQ(in_3.numel(), 1); |
881 | | - float in_scale = *(in_3.template data<float>()); |
882 | | - const auto& in_4 = Input(4); |
883 | | - CAFFE_ENFORCE_EQ(in_4.numel(), 1); |
884 | | - int in_zero_point = *(in_4.template data<int>()); |
| 875 | + CAFFE_ENFORCE(InputSize() <= 4); |
| 876 | + if (InputSize() == 4) { |
| 877 | + const auto* input_qparam_blob = |
| 878 | + this->template Input<caffe2::unique_ptr<caffe2::Int8QuantParamsBlob>>( |
| 879 | + 3) |
| 880 | + .get(); |
| 881 | + CAFFE_ENFORCE(input_qparam_blob); |
| 882 | + |
| 883 | + float in_scale = input_qparam_blob->qparam.scale; |
| 884 | + int in_zero_point = input_qparam_blob->qparam.zero_point; |
885 | 885 |
|
886 | 886 | dnnlowp::TensorQuantizationParams out_qparams_overwrite; |
887 | 887 | out_qparams_overwrite.scale = in_scale; |
@@ -964,7 +964,7 @@ REGISTER_CPU_OPERATOR_WITH_ENGINE( |
964 | 964 |
|
965 | 965 | using namespace std::placeholders; |
966 | 966 | OPERATOR_SCHEMA(Int8FCRelu) |
967 | | - .NumInputs(3, 5) |
| 967 | + .NumInputs(3, 4) |
968 | 968 | .NumOutputs(1) |
969 | 969 | .TensorInferenceFunction(std::bind(FCShapeInference, _1, _2, false)) |
970 | 970 | .CostInferenceFunction(std::bind(CostInferenceForFC, _1, _2, false)); |
|
0 commit comments