|
27 | 27 | from torch.testing._internal.common_quantization import test_only_eval_fn as _test_only_eval_fn |
28 | 28 | from torch.testing._internal.common_quantized import override_qengines |
29 | 29 |
|
| 30 | +from torch.testing._internal.common_quantization import QuantizationTestCase |
| 31 | + |
30 | 32 | from torch.testing import FileCheck |
31 | 33 | from torch.testing._internal.jit_utils import attrs_with_prefix |
32 | | -from torch.testing._internal.jit_utils import JitTestCase |
33 | 34 | from torch.testing._internal.jit_utils import get_forward |
34 | 35 | from torch.testing._internal.jit_utils import get_forward_graph |
35 | 36 | from torch.testing._internal.jit_utils import get_module_method |
|
40 | 41 | import itertools |
41 | 42 | import unittest |
42 | 43 |
|
43 | | -class TestQuantizeScriptJitPasses(JitTestCase): |
| 44 | +class TestQuantizeScriptJitPasses(QuantizationTestCase): |
44 | 45 | """ Test graph mode quantization passes used by quantize_script |
45 | 46 | """ |
46 | 47 | def test_foldbn_trivial(self): |
@@ -1015,6 +1016,21 @@ def forward(self, x): |
1015 | 1016 | .check("aten::dequantize") \ |
1016 | 1017 | .run(model.graph) |
1017 | 1018 |
|
| 1019 | + def test_finalize_no_extra_dequantize(self):  # verifies the finalized quantized graph contains no aten::dequantize( call |
| 1020 | + class M(torch.nn.Module):  # toy module: conv output multiplied by a non-tensor scalar |
| 1021 | + def __init__(self): |
| 1022 | + super(M, self).__init__() |
| 1023 | + self.conv = torch.nn.Conv2d(3, 3, 3).float() |
| 1024 | + |
| 1025 | + def forward(self, x): |
| 1026 | + x = self.conv(x) |
| 1027 | + return x.size(0) * x  # mixes a plain int (batch size) with the conv output tensor |
| 1028 | + |
| 1029 | + model = torch.jit.script(M()).eval()  # script + eval mode before graph-mode quantization |
| 1030 | + model = quantize_script(model, {'': default_qconfig}, _test_only_eval_fn, [self.img_data]) |
| 1031 | + FileCheck().check_not("aten::dequantize(") \ |
| 1032 | + .run(model.graph)  # fail if any dequantize node survives in the final graph |
| 1033 | + |
1018 | 1034 | def test_module_list(self): |
1019 | 1035 | class SimpleLinearLayer(torch.nn.Module): |
1020 | 1036 | def __init__(self): |
@@ -1096,7 +1112,7 @@ def forward(self, x): |
1096 | 1112 | .check_not("aten::mul") \ |
1097 | 1113 | .run(m.graph) |
1098 | 1114 |
|
1099 | | -class TestQuantizeScriptPTSQOps(JitTestCase): |
| 1115 | +class TestQuantizeScriptPTSQOps(QuantizationTestCase): |
1100 | 1116 | """ Test graph mode post training static quantization works |
1101 | 1117 | for individual ops end to end. |
1102 | 1118 | """ |
@@ -1737,7 +1753,7 @@ def forward(self, x): |
1737 | 1753 | .check("aten::dequantize(") \ |
1738 | 1754 | .run(m2.graph) |
1739 | 1755 |
|
1740 | | -class TestQuantizeDynamicScript(JitTestCase): |
| 1756 | +class TestQuantizeDynamicScript(QuantizationTestCase): |
1741 | 1757 | def test_prepare_dynamic(self): |
1742 | 1758 | class M(torch.nn.Module): |
1743 | 1759 | def __init__(self): |
|
0 commit comments