@@ -339,17 +339,17 @@ void RemoveRedundantQuantizationOps(std::shared_ptr<Graph>& graph) {
339339 rewriter.runOnGraph (graph, filter);
340340}
341341
342- void checkGetQParamsResult (const IValue& qparams) {
342+ void checkCalculateQParamsResult (const IValue& qparams) {
343343 TORCH_CHECK (
344344 qparams.isTuple (),
345- " `get_qparams ` function is expected to return a "
345+ " `calculate_qparams ` function is expected to return a "
346346 " Tuple, but got:" ,
347347 qparams.tagKind ());
348348 auto tp = qparams.toTuple ();
349349 TORCH_CHECK (
350- tp->elements ().size () == 2 || tp-> elements (). size () == 3 ,
351- " `get_qparams ` function is expected to return a "
352- " Tuple of size 2 or 3 , got Tuple of size " ,
350+ tp->elements ().size () == 2 ,
351+ " `calculate_qparams ` function is expected to return a "
352+ " Tuple of size 2, got Tuple of size " ,
353353 tp->elements ().size ());
354354 // Expect first two elements of the tuple to be Tensor
355355 for (size_t i = 0 ; i < 2 ; ++i) {
@@ -360,15 +360,6 @@ void checkGetQParamsResult(const IValue& qparams) {
360360 " has type: " ,
361361 tp->elements ()[i].tagKind ());
362362 }
363- // Expect the third elements of the tuple to be int
364- if (tp->elements ().size () == 3 ) {
365- TORCH_CHECK (
366- tp->elements ()[2 ].isInt (),
367- " Element of Tuple is expected to be int, but element " ,
368- 2 ,
369- " has type: " ,
370- tp->elements ()[2 ].tagKind ());
371- }
372363}
373364
374365class InsertQuantDeQuantHelper {
@@ -567,9 +558,9 @@ std::tuple<c10::QScheme, QParamVector> InsertQuantDeQuantHelper::
567558 v->debugName (),
568559 " exists." );
569560 auto observer_module = module .attr (observer_name.value ()).toModule ();
570- auto get_qparams = observer_module.get_method (" get_qparams " );
571- IValue result = get_qparams (std::vector<IValue>());
572- checkGetQParamsResult (result);
561+ auto calculate_qparams = observer_module.get_method (" calculate_qparams " );
562+ IValue result = calculate_qparams (std::vector<IValue>());
563+ checkCalculateQParamsResult (result);
573564 auto scalar_type = observer_module.attr (" dtype" );
574565 TORCH_CHECK (
575566 scalar_type.toScalarType () != at::ScalarType::Undefined,
@@ -582,9 +573,10 @@ std::tuple<c10::QScheme, QParamVector> InsertQuantDeQuantHelper::
582573 QParamVector qparams;
583574 auto qscheme = observer_module.attr (" qscheme" ).toQScheme ();
584575 if (isPerChannel (qscheme)) {
576+ auto axis = observer_module.attr (" ch_axis" );
585577 qparams.push_back (std::make_pair (" _scale" , scale));
586578 qparams.push_back (std::make_pair (" _zero_point" , zero_point));
587- qparams.push_back (std::make_pair (" _axis" , tp-> elements ()[ 2 ] .toInt ()));
579+ qparams.push_back (std::make_pair (" _axis" , axis .toInt ()));
588580 } else {
589581 qparams.push_back (std::make_pair (" _scale" , scale.item <double >()));
590582 qparams.push_back (
0 commit comments