Skip to content

Commit 753157b

Browse files
supriyarfacebook-github-bot
authored andcommitted
[quant][graph] Graph mode quantization support for sigmoid (#36622)
Summary: Pull Request resolved: #36622 Test Plan: python test/quantization/test_quantize_script.py test_swap_dequantize_all_ops Imported from OSS Differential Revision: D21075255 fbshipit-source-id: 025f432215eaa8acf34d492e7722102ca053abeb
1 parent 17c268b commit 753157b

2 files changed

Lines changed: 6 additions & 0 deletions

File tree

test/quantization/test_quantize_script.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1459,6 +1459,7 @@ def __init__(self):
14591459
self.avgpool2d = torch.nn.AvgPool2d(3)
14601460
self.avgpool3d = torch.nn.AvgPool3d(3)
14611461
self.conv = torch.nn.Conv2d(3, 3, 3)
1462+
self.sigmoid = torch.nn.Sigmoid()
14621463

14631464
def forward(self, x):
14641465
x = self.conv(x)
@@ -1481,7 +1482,9 @@ def forward(self, x):
14811482
x = torch.max(x)
14821483
x = torch.min(x)
14831484
x = torch.mean(x)
1485+
x = torch.sigmoid(x)
14841486
x = x.reshape([-1])
1487+
x = self.sigmoid(x)
14851488
x = x.view(-1)
14861489
x = x.transpose(1, 2)
14871490
x = x.contiguous()
@@ -1493,6 +1496,7 @@ def forward(self, x):
14931496
x = F.upsample(x, (32, 32))
14941497
x = F.upsample_bilinear(x, (32, 32))
14951498
x = F.upsample_nearest(x, (32, 32))
1499+
x = F.sigmoid(x)
14961500
x = x.permute(0, 2, 3, 1)
14971501
x = torch.repeat_interleave(x, 3, 1)
14981502
x = self.conv(x)

torch/csrc/jit/passes/quantization.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ std::vector<std::string> _single_input_general_call_funcs = {
8686
"upsample_bilinear",
8787
"upsample_nearest",
8888
"relu",
89+
"sigmoid",
8990
};
9091

9192
// Similar to prim::CallFunctions, there are aten ops that doesn't
@@ -121,6 +122,7 @@ std::vector<std::string> _single_input_general_aten_funcs = {
121122
"permute",
122123
"repeat_interleave",
123124
"relu",
125+
"sigmoid",
124126
};
125127

126128
struct FuncArg {

0 commit comments

Comments
 (0)