Commit eff68bc

jerryzh168 authored and facebook-github-bot committed
[quant][graphmode] quantization support for aten::add (#34572)
Summary: Pull Request resolved: #34572

Test Plan: python test/test_jit.py

Imported from OSS

Differential Revision: D20519607

fbshipit-source-id: c57e062cffc24a47a76b73b58aff7f9ef80183fa
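
In short, the commit makes aten::add / aten::add_ quantizable in graph-mode quantization: prepare inserts an observer for the output of the add, and convert rewrites the dequantize -> aten::add -> aten::quantize_per_tensor pattern into quantized::add. A minimal sketch of that flow, mirroring the test added below (prepare_script, convert_script and script_qconfig are the helpers the test uses; their import location is not part of this diff):

    import torch
    from torch.quantization import default_qconfig

    class Add(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    m = torch.jit.script(Add()).eval()
    # prepare_script / convert_script / script_qconfig come from the test harness, not from this diff.
    m = prepare_script(m, {'': script_qconfig(default_qconfig)}, True)
    data = torch.randn(1, 1, 10, 10)
    m(data, data)                   # calibration run so the observers record ranges
    m = convert_script(m, True)
    print(m.graph_for(data, data))  # should show quantized::add instead of aten::add
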
1 parent b2dcedf commit eff68bc

3 files changed: 59 additions & 4 deletions

test/test_jit.py

Lines changed: 30 additions & 3 deletions
@@ -1187,8 +1187,8 @@ def forward(self, x, y, weight):
         m = torch.jit.script(M()).eval()
         qconfig_dict = {'' : script_qconfig(default_qconfig)}
         m = prepare_script(m, qconfig_dict, False)
-        # 3 for x, y, weight, one for output of each F.conv2d
-        assert len(attrs_with_prefix(m, '_observer')) == 5
+        # 3 for x, y, weight, one for output of each F.conv2d and one for output of add
+        assert len(attrs_with_prefix(m, '_observer')) == 6
 
     def test_insert_observers_shared_class_type(self):
         class M(torch.nn.Module):
@@ -1404,7 +1404,7 @@ def forward(self, x, w0, w1, w2):
 
         # we just check we have one dequant on every op input, even input
         # is sharded as multi uses
-        FileCheck().check_count("aten::dequantize", 8, exactly=True) \
+        FileCheck().check_count("aten::dequantize", 9, exactly=True) \
             .run(str(get_forward_graph(m._c)))
 
     def test_insert_quant_dequant_shared_class_type(self):
@@ -1568,6 +1568,33 @@ def forward(self, x):
                 .check("quantized::conv2d_relu") \
                 .run(m.graph_for(data))
 
+    def test_quantized_add_fusion(self):
+        class Add(torch.nn.Module):
+            def __init__(self):
+                super(Add, self).__init__()
+
+            def forward(self, x, y):
+                return x + y
+
+        class InplaceAdd(torch.nn.Module):
+            def __init__(self):
+                super(InplaceAdd, self).__init__()
+
+            def forward(self, x, y):
+                x += y
+                return x
+
+        for M in [Add, InplaceAdd]:
+            m = torch.jit.script(M()).eval()
+            m = prepare_script(m, {'': script_qconfig(default_qconfig)}, True)
+            data = torch.randn(1, 1, 10, 10, dtype=torch.float)
+            m(data, data)
+            m = convert_script(m, True)
+            FileCheck().check_not("aten::add") \
+                .check_not("aten::add_") \
+                .check("quantized::add") \
+                .run(m.graph_for(data, data))
+
     def test_quantized_add_relu_fusion(self):
         class M(torch.nn.Module):
             def __init__(self, inplace):

torch/csrc/jit/passes/quantization.cpp

Lines changed: 2 additions & 1 deletion
@@ -176,7 +176,8 @@ bool nodeQuantizable(Node* n) {
       "relu",
       "addmm",
       "matmul",
-      "add_"
+      "add_",
+      "add",
   });
 }
 
torch/csrc/jit/passes/quantization_patterns.h

Lines changed: 27 additions & 0 deletions
@@ -127,6 +127,31 @@ graph(%packed_params, %a_quant, %r_scale, %r_zero_point, %r_dtype):
   %r = quantized::linear(%a_quant, %packed_params, %r_scale, %r_zero_point)
   return (%r) )";
 
+  std::string add = R"(
+graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
+     %a_dequant = aten::dequantize(%a_quant)
+     %b_dequant = aten::dequantize(%b_quant)
+     %r_add = aten::add(%a_dequant, %b_dequant, %alpha)
+     %r = aten::quantize_per_tensor(%r_add, %scale, %zero_point, %dtype)
+     return (%r) )";
+
+  // TODO: add %dtype after when https://github.com/pytorch/pytorch/issues/34351
+  // is fixed
+  // TODO: add filter for %alpha
+  std::string quantized_add = R"(
+graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
+     %r_add = quantized::add(%a_quant, %b_quant, %scale, %zero_point)
+     return (%r_add) )";
+
+  std::string inplace_add = R"(
+graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
+     %a_dequant = aten::dequantize(%a_quant)
+     %b_dequant = aten::dequantize(%b_quant)
+     %r_add = aten::add_(%a_dequant, %b_dequant, %alpha)
+     %r = aten::quantize_per_tensor(%r_add, %scale, %zero_point, %dtype)
+     return (%r) )";
+  // We don't have quantized inplace add right now
+
   return {
       {conv2d, quantized_conv2d},
       {conv2d_relu, quantized_conv2d_relu},
@@ -137,6 +162,8 @@ graph(%packed_params, %a_quant, %r_scale, %r_zero_point, %r_dtype):
       {aten_linear, quantized_aten_linear},
       {add_relu, quantized_add_relu},
       {add_inplace_relu, quantized_add_relu},
+      {add, quantized_add},
+      {inplace_add, quantized_add},
   };
 
 }
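
For reference, the replacement op in the quantized_add pattern above is the existing quantized::add kernel, which takes two quantized tensors plus the requantization scale and zero point for the output. A minimal sketch of calling it directly from Python, assuming the op is exposed as torch.ops.quantized.add (the scale/zero_point values here are arbitrary illustration choices):

    import torch

    x = torch.randn(1, 1, 10, 10)
    y = torch.randn(1, 1, 10, 10)
    # Quantize the inputs per tensor, as aten::quantize_per_tensor does in the pattern above.
    qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)
    qy = torch.quantize_per_tensor(y, scale=0.1, zero_point=128, dtype=torch.quint8)
    # quantized::add(%a_quant, %b_quant, %scale, %zero_point): the sum is requantized
    # with the given output scale and zero point.
    qz = torch.ops.quantized.add(qx, qy, 0.2, 128)
    print(qz.dequantize())  # approximately x + y, up to quantization error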
