Skip to content

Commit 6232481

Browse files
jerryzh168facebook-github-bot
authored and committed
[quant][graphmode] Add RemoveRedundantDequantize pass (#38434)
Summary: Pull Request resolved: #38434 We insert dequantize for each use in order to produce quantization patterns that will later be fused, after that we should also remove extra dequantize node produced by this operation. Test Plan: Imported from OSS Differential Revision: D21597834 fbshipit-source-id: 18dfb2760bbb08932aa4e1d06f96cfc5fb37ed88
1 parent dd7eed5 commit 6232481

5 files changed

Lines changed: 57 additions & 5 deletions

File tree

test/quantization/test_quantize_script.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,10 @@
2727
from torch.testing._internal.common_quantization import test_only_eval_fn as _test_only_eval_fn
2828
from torch.testing._internal.common_quantized import override_qengines
2929

30+
from torch.testing._internal.common_quantization import QuantizationTestCase
31+
3032
from torch.testing import FileCheck
3133
from torch.testing._internal.jit_utils import attrs_with_prefix
32-
from torch.testing._internal.jit_utils import JitTestCase
3334
from torch.testing._internal.jit_utils import get_forward
3435
from torch.testing._internal.jit_utils import get_forward_graph
3536
from torch.testing._internal.jit_utils import get_module_method
@@ -40,7 +41,7 @@
4041
import itertools
4142
import unittest
4243

43-
class TestQuantizeScriptJitPasses(JitTestCase):
44+
class TestQuantizeScriptJitPasses(QuantizationTestCase):
4445
""" Test graph mode quantization passes used by quantize_script
4546
"""
4647
def test_foldbn_trivial(self):
@@ -1015,6 +1016,21 @@ def forward(self, x):
10151016
.check("aten::dequantize") \
10161017
.run(model.graph)
10171018

1019+
def test_finalize_no_extra_dequantize(self):
1020+
class M(torch.nn.Module):
1021+
def __init__(self):
1022+
super(M, self).__init__()
1023+
self.conv = torch.nn.Conv2d(3, 3, 3).float()
1024+
1025+
def forward(self, x):
1026+
x = self.conv(x)
1027+
return x.size(0) * x
1028+
1029+
model = torch.jit.script(M()).eval()
1030+
model = quantize_script(model, {'': default_qconfig}, _test_only_eval_fn, [self.img_data])
1031+
FileCheck().check_not("aten::dequantize(") \
1032+
.run(model.graph)
1033+
10181034
def test_module_list(self):
10191035
class SimpleLinearLayer(torch.nn.Module):
10201036
def __init__(self):
@@ -1096,7 +1112,7 @@ def forward(self, x):
10961112
.check_not("aten::mul") \
10971113
.run(m.graph)
10981114

1099-
class TestQuantizeScriptPTSQOps(JitTestCase):
1115+
class TestQuantizeScriptPTSQOps(QuantizationTestCase):
11001116
""" Test graph mode post training static quantization works
11011117
for individual ops end to end.
11021118
"""
@@ -1737,7 +1753,7 @@ def forward(self, x):
17371753
.check("aten::dequantize(") \
17381754
.run(m2.graph)
17391755

1740-
class TestQuantizeDynamicScript(JitTestCase):
1756+
class TestQuantizeDynamicScript(QuantizationTestCase):
17411757
def test_prepare_dynamic(self):
17421758
class M(torch.nn.Module):
17431759
def __init__(self):

torch/csrc/jit/passes/quantization/helper.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,9 @@ std::unordered_map<NodeKind, std::tuple<c10::QScheme, QParamVector>>
179179
AtenFuncArgs _observe_inputs_aten_func = {};
180180
CallFuncArgs _observe_inputs_call_func = {{"batch_norm", 1}};
181181

182+
// Aten functions for getting tensor information
183+
std::vector<std::string> _tensor_info_funcs = {"size"};
184+
182185
// Check if `use` is an aten function of name `func_name` and if value
183186
// `v` is the nth argument (if provided) of the function.
184187
bool matchAtenFuncToUse(
@@ -347,6 +350,10 @@ bool isSingleInputGeneralAtenFunction(Node* n) {
347350
isAtenFunc(n, fixed_qparams_aten_funcs);
348351
}
349352

353+
bool isTensorInfoNode(Node* n) {
354+
return isAtenFunc(n, _tensor_info_funcs);
355+
}
356+
350357
c10::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(Node* n) {
351358
static std::vector<NodeKind> fixed_qparam_funcs;
352359
std::transform(

torch/csrc/jit/passes/quantization/helper.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ TORCH_API bool isSingleInputGeneralCallFunction(Node* n);
3939

4040
TORCH_API bool isSingleInputGeneralAtenFunction(Node* n);
4141

42+
// Check if the node will produce the same result regardless of whether
43+
// the input tensor is quantized or not, example: aten::size
44+
TORCH_API bool isTensorInfoNode(Node* n);
45+
4246
TORCH_API c10::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(
4347
Node* n);
4448

torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,30 @@ void ReplicateChooseQParamsQuantDequant(std::shared_ptr<Graph>& graph) {
289289
}
290290
}
291291

292+
void RemoveRedundantDequantize(std::shared_ptr<Graph>& graph) {
293+
const std::string dequantize = R"(
294+
graph(%a_quant):
295+
%a_dequant = aten::dequantize(%a_quant)
296+
return (%a_dequant) )";
297+
const std::string dequantize_replacement = R"(
298+
graph(%a):
299+
return (%a) )";
300+
auto filter = [&](const Match& match,
301+
const std::unordered_map<std::string, Value*>& vmap) {
302+
const auto& match_vmap = match.values_map;
303+
auto dequant_node = match_vmap.at(vmap.at("a_dequant"))->node();
304+
Value* dequant_out = dequant_node->output();
305+
TORCH_CHECK(
306+
dequant_out->uses().size() == 1,
307+
"Expect dequant output to have single use");
308+
Node* user = dequant_out->uses()[0].user;
309+
return isTensorInfoNode(user);
310+
};
311+
SubgraphRewriter rewriter;
312+
rewriter.RegisterRewritePattern(dequantize, dequantize_replacement);
313+
rewriter.runOnGraph(graph, filter);
314+
}
315+
292316
void RemoveRedundantQuantizationOps(std::shared_ptr<Graph>& graph) {
293317
const std::string dynamic_quant_ops = R"(
294318
graph(%a, %reduce_range, %a_dtype):
@@ -812,6 +836,7 @@ void InsertQuantDeQuantHelper::propagateQuantizationOps(Module& module) {
812836
RemoveRedundantQuantizationOps(graph);
813837
ReplicateQuant(graph);
814838
ReplicateDeQuant(graph);
839+
RemoveRedundantDequantize(graph);
815840
PropagateQuantizationOps(graph);
816841
}
817842

torch/testing/_internal/common_quantization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def setUp(self):
137137
super().setUp()
138138
self.calib_data = [(torch.rand(2, 5, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long)) for _ in range(2)]
139139
self.train_data = [(torch.rand(2, 5, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long)) for _ in range(2)]
140-
self.img_data = [(torch.rand(2, 3, 10, 10, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long))
140+
self.img_data = [(torch.rand(1, 3, 10, 10, dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long))
141141
for _ in range(2)]
142142
self.img_data_1d = [(torch.rand(2, 3, 10, dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long))
143143
for _ in range(2)]

0 commit comments

Comments
 (0)