Skip to content

Commit 1ef77f9

Browse files
jerryzh168 and facebook-github-bot
authored and committed
[quant][graphmode] Different rule for handling aten::cat (#38570)
Summary: Pull Request resolved: #38570 We changed the rule for quantizing `aten::cat`. Previously `aten::cat` was considered to be an op that should always be quantized, like `aten::conv2d`, but this is not ideal. A better way is to quantize the output of `aten::cat` depending on whether the input is quantized: if it is, then we'll quantize the output; if not, then we will not quantize the output, since `aten::cat` works on both quantized and non-quantized tensors. Test Plan: Imported from OSS Differential Revision: D21600160 fbshipit-source-id: efa957e0eaa608fffefcdfefa7f442fab45605eb
1 parent dfbf9f3 commit 1ef77f9

4 files changed

Lines changed: 54 additions & 10 deletions

File tree

test/quantization/test_quantize_script.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,13 +1383,12 @@ def forward(self, x):
13831383
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
13841384
" with instruction set support avx2 or newer.")
13851385
def test_quantized_cat(self):
1386-
""" Note that we to support the case that torch.cat is quantized
1387-
indepdently, we need to have an observer that works
1388-
for list of Tensors.
1386+
""" quantization of the output of cat will be depend on the
1387+
input of cat. we only quantize the output of cat when its inputs are quantized.
13891388
"""
1390-
class M(torch.nn.Module):
1389+
class QuantizedCat(torch.nn.Module):
13911390
def __init__(self):
1392-
super(M, self).__init__()
1391+
super(QuantizedCat, self).__init__()
13931392
self.conv1 = torch.nn.Conv2d(1, 1, 1).float()
13941393
self.conv2 = torch.nn.Conv2d(1, 1, 1).float()
13951394

@@ -1398,7 +1397,15 @@ def forward(self, x, y):
13981397
y = self.conv2(y)
13991398
return torch.cat([x, y], 1)
14001399

1401-
m = torch.jit.script(M().eval())
1400+
class NonQuantizedCat(torch.nn.Module):
1401+
def __init__(self):
1402+
super(NonQuantizedCat, self).__init__()
1403+
1404+
def forward(self, x, y):
1405+
return torch.cat([x, y], 1)
1406+
1407+
# quantized cat
1408+
m = torch.jit.script(QuantizedCat()).eval()
14021409
m = prepare_script(m, {'': default_qconfig}, True)
14031410
# four for input and output of conv and one for output of cat
14041411
# this also tests the ListConstruct can preserve the observed property so that
@@ -1410,7 +1417,20 @@ def forward(self, x, y):
14101417

14111418
FileCheck().check_not("aten::cat") \
14121419
.check("quantized::cat") \
1413-
.run(m.graph_for(data, data))
1420+
.run(m.graph)
1421+
1422+
# non quantized cat
1423+
m = torch.jit.script(NonQuantizedCat()).eval()
1424+
m = prepare_script(m, {'': default_qconfig}, True)
1425+
assert len(attrs_with_prefix(m, '_observer_')) == 0
1426+
data = torch.randn(1, 1, 10, 10, dtype=torch.float)
1427+
m(data, data)
1428+
m = convert_script(m, True)
1429+
1430+
FileCheck().check_not("quantized::cat") \
1431+
.check("aten::cat") \
1432+
.run(m.graph)
1433+
14141434

14151435
def test_qbatch_norm(self):
14161436
class M(torch.nn.Module):

torch/csrc/jit/passes/quantization/helper.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ std::vector<std::string> _static_quantizable_aten_funcs = {
3232
"matmul",
3333
"add_",
3434
"add",
35-
"cat",
3635
"mul",
3736
"mul_",
3837
"hardswish",
@@ -181,6 +180,10 @@ CallFuncArgs _observe_inputs_call_func = {{"batch_norm", 1}};
181180
// Aten functions for getting tensor information
182181
std::vector<std::string> _tensor_info_funcs = {"size"};
183182

183+
// Aten functions whose output will be quantized or not quantized depending
184+
// on input tensor
185+
std::vector<std::string> _propagate_quant_ops = {"cat"};
186+
184187
// Check if `use` is an aten function of name `func_name` and if value
185188
// `v` is the nth argument (if provided) of the function.
186189
bool matchAtenFuncToUse(
@@ -350,6 +353,10 @@ bool isTensorInfoNode(Node* n) {
350353
return isAtenFunc(n, _tensor_info_funcs);
351354
}
352355

356+
bool isPropagateQuantNode(Node* n) {
357+
return isAtenFunc(n, _propagate_quant_ops);
358+
}
359+
353360
c10::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(Node* n) {
354361
static std::vector<NodeKind> fixed_qparam_funcs;
355362
std::transform(

torch/csrc/jit/passes/quantization/helper.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ TORCH_API bool isSingleInputGeneralAtenFunction(Node* n);
4343
// the input tensor is quantized or not, example: aten::size
4444
TORCH_API bool isTensorInfoNode(Node* n);
4545

46+
// Check if this is the node that we'll quantize or not quantize depending on
47+
// whether the input of the node is quantized, example: aten::cat
48+
TORCH_API bool isPropagateQuantNode(Node* n);
49+
4650
TORCH_API c10::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(
4751
Node* n);
4852

torch/csrc/jit/passes/quantization/insert_observers.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,12 @@ class InsertObserversHelper {
336336
Value* output,
337337
std::unordered_set<Value*>& block_observed_values);
338338

339+
bool shouldPropagateQuant(
340+
Node* n, const std::unordered_set<Value*>& block_observed_values) {
341+
return isObserved(n->input(0), block_observed_values);
342+
}
343+
344+
339345
void delayObservingValuesInPattern(Graph& graph, const PatternInfo& pattern);
340346

341347
void addValuesToDelayObservation(
@@ -732,7 +738,8 @@ bool InsertObserversHelper::valueNeedsToBeQuantized(Value* v) {
732738
// of the quantizable function.
733739
if (!is_dynamic_) {
734740
// Check whether producer is quantizable
735-
if (mayRequireObservation(v) && nodeQuantizable(v->node())) {
741+
if ((mayRequireObservation(v) && nodeQuantizable(v->node())) ||
742+
isPropagateQuantNode(v->node())) {
736743
return true;
737744
}
738745
}
@@ -1026,7 +1033,13 @@ InsertObserversHelper::insertObserversFor(
10261033
propagateObservedProperty(v, block_observed_values);
10271034
if (!inputs_outputs.count(v) &&
10281035
!isObserved(v, block_observed_values)) {
1029-
if (auto observer_opt = getObserverFor(v)) {
1036+
auto observer_opt = getObserverFor(v);
1037+
// If the node is one of the propagate quant node, e.g.
1038+
// aten::cat, we should observe its output only
1039+
// if the input of the node is observed
1040+
if (observer_opt &&
1041+
(!isPropagateQuantNode(n) ||
1042+
shouldPropagateQuant(n, block_observed_values))) {
10301043
recordObserved(
10311044
v, *observer_opt, values_to_observe, block_observed_values);
10321045
}

0 commit comments

Comments
 (0)