Skip to content

Commit 86397f6

Browse files
jerryzh168 authored and facebook-github-bot committed
[quant] Remove get_qparams in Observers (#38435)
Summary: Pull Request resolved: #38435 Test Plan: Imported from OSS Differential Revision: D21597835 Pulled By: jerryzh168 fbshipit-source-id: 88a8dd110db5586509bf98fa6712290f1756c272
1 parent d5461e7 commit 86397f6

2 files changed

Lines changed: 11 additions & 38 deletions

File tree

torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -339,17 +339,17 @@ void RemoveRedundantQuantizationOps(std::shared_ptr<Graph>& graph) {
339339
rewriter.runOnGraph(graph, filter);
340340
}
341341

342-
void checkGetQParamsResult(const IValue& qparams) {
342+
void checkCalculateQParamsResult(const IValue& qparams) {
343343
TORCH_CHECK(
344344
qparams.isTuple(),
345-
"`get_qparams` function is expected to return a "
345+
"`calculate_qparams` function is expected to return a "
346346
"Tuple, but got:",
347347
qparams.tagKind());
348348
auto tp = qparams.toTuple();
349349
TORCH_CHECK(
350-
tp->elements().size() == 2 || tp->elements().size() == 3,
351-
"`get_qparams` function is expected to return a "
352-
"Tuple of size 2 or 3, got Tuple of size ",
350+
tp->elements().size() == 2,
351+
"`calculate_qparams` function is expected to return a "
352+
"Tuple of size 2, got Tuple of size ",
353353
tp->elements().size());
354354
// Expect first two elements of the tuple to be Tensor
355355
for (size_t i = 0; i < 2; ++i) {
@@ -360,15 +360,6 @@ void checkGetQParamsResult(const IValue& qparams) {
360360
" has type: ",
361361
tp->elements()[i].tagKind());
362362
}
363-
// Expect the third elements of the tuple to be int
364-
if (tp->elements().size() == 3) {
365-
TORCH_CHECK(
366-
tp->elements()[2].isInt(),
367-
"Element of Tuple is expected to be int, but element ",
368-
2,
369-
" has type: ",
370-
tp->elements()[2].tagKind());
371-
}
372363
}
373364

374365
class InsertQuantDeQuantHelper {
@@ -567,9 +558,9 @@ std::tuple<c10::QScheme, QParamVector> InsertQuantDeQuantHelper::
567558
v->debugName(),
568559
" exists.");
569560
auto observer_module = module.attr(observer_name.value()).toModule();
570-
auto get_qparams = observer_module.get_method("get_qparams");
571-
IValue result = get_qparams(std::vector<IValue>());
572-
checkGetQParamsResult(result);
561+
auto calculate_qparams = observer_module.get_method("calculate_qparams");
562+
IValue result = calculate_qparams(std::vector<IValue>());
563+
checkCalculateQParamsResult(result);
573564
auto scalar_type = observer_module.attr("dtype");
574565
TORCH_CHECK(
575566
scalar_type.toScalarType() != at::ScalarType::Undefined,
@@ -582,9 +573,10 @@ std::tuple<c10::QScheme, QParamVector> InsertQuantDeQuantHelper::
582573
QParamVector qparams;
583574
auto qscheme = observer_module.attr("qscheme").toQScheme();
584575
if (isPerChannel(qscheme)) {
576+
auto axis = observer_module.attr("ch_axis");
585577
qparams.push_back(std::make_pair("_scale", scale));
586578
qparams.push_back(std::make_pair("_zero_point", zero_point));
587-
qparams.push_back(std::make_pair("_axis", tp->elements()[2].toInt()));
579+
qparams.push_back(std::make_pair("_axis", axis.toInt()));
588580
} else {
589581
qparams.push_back(std::make_pair("_scale", scale.item<double>()));
590582
qparams.push_back(

torch/quantization/observer.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,6 @@ def forward(self, x):
6666
def calculate_qparams(self, **kwargs):
6767
pass
6868

69-
# Returns all quantization parameters that's needed
70-
# for a quantize function call
71-
# For instance, per channel obsserver will return
72-
# scales, zero_points and axis
73-
@abstractmethod
74-
def get_qparams(self, **kwargs):
75-
pass
76-
7769
with_args = classmethod(_with_args)
7870

7971

@@ -194,10 +186,6 @@ def _calculate_qparams(self, min_val, max_val):
194186

195187
return scale, zero_point
196188

197-
@torch.jit.export
198-
def get_qparams(self):
199-
r"""Get all quantization parameters needed for quantize call"""
200-
return self.calculate_qparams()
201189

202190
class MinMaxObserver(_ObserverBase):
203191
r"""Observer module for computing the quantization parameters based on the
@@ -546,11 +534,6 @@ def _forward(self, x_orig):
546534
def calculate_qparams(self):
547535
return self._calculate_qparams(self.min_vals, self.max_vals)
548536

549-
@torch.jit.export
550-
def get_qparams(self):
551-
scales, zero_points = self.calculate_qparams()
552-
return scales, zero_points, self.ch_axis
553-
554537
def extra_repr(self):
555538
return "min_val={}, max_val={}".format(self.min_vals, self.max_vals)
556539

@@ -966,12 +949,10 @@ def __init__(self, dtype=torch.float16):
966949
def forward(self, x):
967950
return x
968951

952+
@torch.jit.export
969953
def calculate_qparams(self):
970954
raise Exception("calculate_qparams should not be called for NoopObserver")
971955

972-
def get_qparams(self):
973-
return self.calculate_qparams()
974-
975956

976957
# Restrict activations to be in the range (0,127)
977958
default_observer = MinMaxObserver.with_args(reduce_range=True)

0 commit comments

Comments (0)