Skip to content

Commit b05127e

Browse files
committed
Update on "[quant][graph] Graph mode quantization support for sigmoid"
Summary: Test Plan: python test/quantization/test_quantize_script.py test_swap_dequantize_all_ops Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
2 parents 26a9ca8 + d28b1f9 commit b05127e

1 file changed

Lines changed: 8 additions & 9 deletions

File tree

test/quantization/test_quantize_script.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,29 +1417,28 @@ def forward(self, x):
14171417

14181418
# Note: Fusion for functional ReLU with an inplace argument isn't currently supported in fusion patterns.
14191419
class BNFuncRelu(torch.nn.Module):
1420-
def __init__(self, inplace):
1420+
def __init__(self):
14211421
super(BNFuncRelu, self).__init__()
14221422
self.bn = torch.nn.BatchNorm2d(3).to(torch.float)
14231423

14241424
def forward(self, x):
14251425
return F.relu(self.bn(x), False)
14261426

14271427
class BNFuncInplaceRelu(torch.nn.Module):
1428-
def __init__(self, inplace):
1428+
def __init__(self):
14291429
super(BNFuncInplaceRelu, self).__init__()
14301430
self.bn = torch.nn.BatchNorm2d(3).to(torch.float)
14311431

14321432
def forward(self, x):
14331433
return F.relu(self.bn(x), True)
14341434

14351435
data = [(torch.rand((1, 3, 10, 10), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
1436-
for Model in [BNRelu, BNFuncRelu, BNFuncInplaceRelu]:
1437-
for inplace in [True, False]:
1438-
model = self._test_op_impl(Model(inplace), data, "quantized::batch_norm2d_relu")
1439-
FileCheck().check_not("aten::batch_norm") \
1440-
.check_not("aten::relu") \
1441-
.check_not("aten::relu_") \
1442-
.run(model.graph)
1436+
for instance in [BNRelu(True), BNRelu(False), BNFuncRelu(), BNFuncInplaceRelu()]:
1437+
model = self._test_op_impl(instance, data, "quantized::batch_norm2d_relu")
1438+
FileCheck().check_not("aten::batch_norm") \
1439+
.check_not("aten::relu") \
1440+
.check_not("aten::relu_") \
1441+
.run(model.graph)
14431442

14441443

14451444
def test_swap_dequantize_all_ops(self):

0 commit comments

Comments
 (0)