
Commit 530d48e

supriyar authored and facebook-github-bot committed
[quant] Support for fused ConvBn1d and ConvBnRelu1d modules (#38452) (#38749)
Summary: Pull Request resolved: #38749
Test Plan: python test/test_quantization.py TestFused
Differential Revision: D21654659
Pulled By: supriyar
fbshipit-source-id: 301be24083e794f4e71ff1d6d842e1aaefa640f0
1 parent 7587188 commit 530d48e

8 files changed: 75 additions and 17 deletions


docs/source/quantization.rst

Lines changed: 12 additions & 0 deletions
@@ -208,7 +208,9 @@ accuracy
 * ``torch.nn.intrinsic`` — float versions of the modules, can be swapped with
   quantized version 1 to 1:

+  * :class:`~torch.nn.intrinsic.ConvBn1d` — Conv1d + BatchNorm1d
   * :class:`~torch.nn.intrinsic.ConvBn2d` — Conv2d + BatchNorm
+  * :class:`~torch.nn.intrinsic.ConvBnReLU1d` — Conv1d + BatchNorm1d + ReLU
   * :class:`~torch.nn.intrinsic.ConvBnReLU2d` — Conv2d + BatchNorm + ReLU
   * :class:`~torch.nn.intrinsic.ConvReLU1d` — Conv1d + ReLU
   * :class:`~torch.nn.intrinsic.ConvReLU2d` — Conv2d + ReLU

@@ -584,11 +586,21 @@ then quantized.

 .. automodule:: torch.nn.intrinsic

+ConvBn1d
+~~~~~~~~~~~~~~~
+.. autoclass:: ConvBn1d
+    :members:
+
 ConvBn2d
 ~~~~~~~~~~~~~~~
 .. autoclass:: ConvBn2d
     :members:

+ConvBnReLU1d
+~~~~~~~~~~~~~~~
+.. autoclass:: ConvBnReLU1d
+    :members:
+
 ConvBnReLU2d
 ~~~~~~~~~~~~~~~
 .. autoclass:: ConvBnReLU2d
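
The "1 to 1" swap the docs describe works because the intrinsic containers are ordinary float modules until conversion. A minimal sketch with the new 1d container (sizes are arbitrary illustrative choices):

import torch
import torch.nn as nn
import torch.nn.intrinsic as nni

conv, bn, relu = nn.Conv1d(3, 3, 2), nn.BatchNorm1d(3), nn.ReLU()
fused = nni.ConvBnReLU1d(conv, bn, relu)  # still a float nn.Sequential
fused.eval()

x = torch.randn(4, 3, 8)
# In float, the container computes exactly relu(bn(conv(x))):
assert torch.allclose(fused(x), relu(bn(conv(x))))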

test/quantization/test_quantize.py

Lines changed: 11 additions & 8 deletions
@@ -1303,24 +1303,25 @@ def checkQuantized(model):
         self.assertEqual(type(model.sub2.conv), nn.Conv2d)
         self.assertEqual(type(model.sub2.relu), nn.ReLU)
         test_only_eval_fn(model, self.img_data_1d)
-        checkQuantized(model)
+        with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
+            checkQuantized(model)

         model = ModelForFusion(default_qat_qconfig).train()
         model = fuse_modules(model, [['conv1', 'bn1', 'relu1'],
                                      ['sub1.conv', 'sub1.bn']])
         model = quantize_qat(model, test_only_train_fn, self.img_data_1d)
-        checkQuantized(model)
+        with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
+            checkQuantized(model)


     def test_fuse_module_eval(self):
         model = ModelForFusion(default_qconfig)
         model.eval()
-        model = fuse_modules(model, [['conv3', 'relu4'],
+        model = fuse_modules(model, [['conv3', 'bn3', 'relu4'],
                                      ['conv1', 'bn1', 'relu1'],
                                      ['conv2', 'relu2'],
                                      ['bn2', 'relu3'],
                                      ['sub1.conv', 'sub1.bn']])
-
         self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                          "Fused Conv + BN + Relu first layer (BN is folded)")
         self.assertEqual(type(model.conv1[0]), nn.Conv2d,

@@ -1345,11 +1346,13 @@ def test_fuse_module_eval(self):
                          "Fused Conv + BN + Relu second layer (Skipped Relu)")

         self.assertEqual(type(model.conv3), nni.ConvReLU1d,
-                         "Fused Conv + Relu for conv1d")
+                         "Fused Conv + Relu for Conv1d (folded BN)")
         self.assertEqual(type(model.conv3[0]), nn.Conv1d,
-                         "Fused Conv + Relu for conv1d ")
+                         "Fused Conv + Relu for Conv1d ")
         self.assertEqual(type(model.conv3[1]), nn.ReLU,
-                         "Fused Conv + Relu for conv1d")
+                         "Fused Conv + Relu for Conv1d")
+        self.assertEqual(type(model.bn3), nn.Identity,
+                         "Fused Conv + BN + Relu for Conv1d (Skipped BN)")

         self.assertEqual(type(model.sub1.conv), nn.Conv2d,
                          "Fused submodule Conv + folded BN")

@@ -1383,7 +1386,7 @@ def checkQuantized(model):
                                      ['conv2', 'relu2'],
                                      ['bn2', 'relu3'],
                                      ['sub1.conv', 'sub1.bn'],
-                                     ['conv3', 'relu4']])
+                                     ['conv3', 'bn3', 'relu4']])
         model = quantize(model, test_only_eval_fn, self.img_data_1d)
         checkQuantized(model)

torch/nn/intrinsic/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,6 +1,8 @@

+from .modules import ConvBn1d
 from .modules import ConvBn2d
 from .modules import ConvBn3d
+from .modules import ConvBnReLU1d
 from .modules import ConvBnReLU2d
 from .modules import ConvBnReLU3d
 from .modules import ConvReLU1d
@@ -11,9 +13,11 @@
 from .modules import BNReLU3d

 __all__ = [
+    'ConvBn1d',
     'ConvBn2d',
     'ConvBn3d',
     'ConvBnReLU2d',
+    'ConvBnReLU1d',
     'ConvBnReLU3d',
     'ConvReLU1d',
     'ConvReLU2d',
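
These re-exports make the new containers resolvable from the package root, which is the spelling the fuser, convert(), and the tests all use. A quick sanity check, assuming a build that includes this commit:

import torch.nn.intrinsic as nni

# Both new containers should be importable from the package root.
assert hasattr(nni, 'ConvBn1d') and hasattr(nni, 'ConvBnReLU1d')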

torch/nn/intrinsic/modules/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,6 +1,8 @@

+from .fused import ConvBn1d
 from .fused import ConvBn2d
 from .fused import ConvBn3d
+from .fused import ConvBnReLU1d
 from .fused import ConvBnReLU2d
 from .fused import ConvBnReLU3d
 from .fused import ConvReLU1d
@@ -12,8 +14,10 @@


 __all__ = [
+    'ConvBn1d',
     'ConvBn2d',
     'ConvBn3d',
+    'ConvBnReLU1d',
     'ConvBnReLU2d',
     'ConvBnReLU3d',
     'ConvReLU1d',

torch/nn/intrinsic/modules/fused.py

Lines changed: 19 additions & 1 deletion
@@ -1,6 +1,6 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 import torch
-from torch.nn import Conv1d, Conv2d, Conv3d, ReLU, Linear, BatchNorm2d, BatchNorm3d
+from torch.nn import Conv1d, Conv2d, Conv3d, ReLU, Linear, BatchNorm1d, BatchNorm2d, BatchNorm3d

 class ConvReLU1d(torch.nn.Sequential):
     r"""This is a sequential container which calls the Conv 1d and ReLU modules.
@@ -38,6 +38,15 @@ def __init__(self, linear, relu):
                 type(linear), type(relu))
         super(LinearReLU, self).__init__(linear, relu)

+class ConvBn1d(torch.nn.Sequential):
+    r"""This is a sequential container which calls the Conv 1d and Batch Norm 1d modules.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, conv, bn):
+        assert type(conv) == Conv1d and type(bn) == BatchNorm1d, \
+            'Incorrect types for input modules{}{}'.format(
+                type(conv), type(bn))
+        super(ConvBn1d, self).__init__(conv, bn)
+
 class ConvBn2d(torch.nn.Sequential):
     r"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules.
     During quantization this will be replaced with the corresponding fused module."""
@@ -47,6 +56,15 @@ def __init__(self, conv, bn):
                 type(conv), type(bn))
         super(ConvBn2d, self).__init__(conv, bn)

+class ConvBnReLU1d(torch.nn.Sequential):
+    r"""This is a sequential container which calls the Conv 1d, Batch Norm 1d, and ReLU modules.
+    During quantization this will be replaced with the corresponding fused module."""
+    def __init__(self, conv, bn, relu):
+        assert type(conv) == Conv1d and type(bn) == BatchNorm1d and \
+            type(relu) == ReLU, 'Incorrect types for input modules{}{}{}' \
+            .format(type(conv), type(bn), type(relu))
+        super(ConvBnReLU1d, self).__init__(conv, bn, relu)
+
 class ConvBnReLU2d(torch.nn.Sequential):
     r"""This is a sequential container which calls the Conv 2d, Batch Norm 2d, and ReLU modules.
     During quantization this will be replaced with the corresponding fused module."""
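
Each container validates its children's exact types up front, which keeps the later fuser logic simple. A small sketch of that behavior (sizes are arbitrary):

import torch.nn as nn
from torch.nn.intrinsic import ConvBn1d

# Correctly typed children pass the assert; the container is a Sequential.
pair = ConvBn1d(nn.Conv1d(3, 3, 2), nn.BatchNorm1d(3))
assert isinstance(pair, nn.Sequential)

# A dimensionality mismatch is rejected at construction time.
try:
    ConvBn1d(nn.Conv1d(3, 3, 2), nn.BatchNorm2d(3))
except AssertionError as e:
    print(e)  # Incorrect types for input modules...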

torch/quantization/fuse_modules.py

Lines changed: 21 additions & 8 deletions
@@ -47,19 +47,30 @@ def fuse_conv_bn_relu(conv, bn, relu):
     """
     assert(conv.training == bn.training == relu.training),\
         "Conv and BN both must be in the same mode (train or eval)."
-    is_3d = isinstance(conv, torch.nn.Conv3d)
     if conv.training:
+        map_to_fused_module_train = {
+            torch.nn.Conv2d: torch_fused.ConvBnReLU2d,
+            torch.nn.Conv3d: torch_fused.ConvBnReLU3d,
+        }
         assert bn.num_features == conv.out_channels, 'Output channel of Conv must match num_features of BatchNorm'
         assert bn.affine, 'Only support fusing BatchNorm with affine set to True'
         assert bn.track_running_stats, 'Only support fusing BatchNorm with tracking_running_stats set to True'
-
-        return torch_fused.ConvBnReLU3d(conv, bn, relu) if is_3d \
-            else torch_fused.ConvBnReLU2d(conv, bn, relu)
+        fused_module = map_to_fused_module_train.get(type(conv))
+        if fused_module is not None:
+            return fused_module(conv, bn, relu)
+        else:
+            raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, relu)))
     else:
-        return torch_fused.ConvReLU3d(
-            torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu) if is_3d \
-            else torch_fused.ConvReLU2d(
-                torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu)
+        map_to_fused_module_eval = {
+            torch.nn.Conv1d: torch_fused.ConvReLU1d,
+            torch.nn.Conv2d: torch_fused.ConvReLU2d,
+            torch.nn.Conv3d: torch_fused.ConvReLU3d,
+        }
+        fused_module = map_to_fused_module_eval[type(conv)]
+        if fused_module is not None:
+            return fused_module(torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu)
+        else:
+            raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu)))

 # Generalization of getattr
 def _get_module(model, submodule_key):

@@ -93,6 +104,8 @@ def fuse_known_modules(mod_list):
     """

     OP_LIST_TO_FUSER_METHOD = {
+        (torch.nn.Conv1d, torch.nn.BatchNorm1d): fuse_conv_bn,
+        (torch.nn.Conv1d, torch.nn.BatchNorm1d, torch.nn.ReLU): fuse_conv_bn_relu,
         (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn,
         (torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU): fuse_conv_bn_relu,
         (torch.nn.Conv3d, torch.nn.BatchNorm3d): fuse_conv_bn,
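
The dict dispatch replaces the old is_3d ternary, so supporting a new dimensionality becomes a one-line map entry. (Note the asymmetry: the train branch looks up with .get() and falls through to NotImplementedError, while the eval branch indexes the dict directly, so an unsupported conv type would raise KeyError there first.) With the new OP_LIST_TO_FUSER_METHOD entries, fuse_modules handles the Conv1d patterns; a minimal sketch of the eval path, with module names and sizes chosen for illustration:

import torch
import torch.nn as nn
import torch.nn.intrinsic as nni
from torch.quantization import fuse_modules

class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.conv = nn.Conv1d(3, 3, 2)
        self.bn = nn.BatchNorm1d(3)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

m = M().eval()
m = fuse_modules(m, [['conv', 'bn', 'relu']])
assert type(m.conv) == nni.ConvReLU1d  # BN folded into the conv weights
assert type(m.bn) == nn.Identity       # the fused-away slot becomes Identity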

torch/quantization/quantize.py

Lines changed: 2 additions & 0 deletions
@@ -321,7 +321,9 @@ def convert(module, mapping=None, inplace=False):
                              nni.LinearReLU,
                              nni.BNReLU2d,
                              nni.BNReLU3d,
+                             nni.ConvBn1d,
                              nni.ConvReLU1d,
+                             nni.ConvBnReLU1d,
                              nni.ConvReLU2d,
                              nni.ConvReLU3d)

torch/testing/_internal/common_quantization.py

Lines changed: 2 additions & 0 deletions
@@ -632,6 +632,7 @@ def __init__(self, qconfig):
         self.bn2 = nn.BatchNorm3d(2).to(dtype=torch.float)
         self.relu3 = nn.ReLU(inplace=True).to(dtype=torch.float)
         self.conv3 = nn.Conv1d(3, 3, 2).to(dtype=torch.float)
+        self.bn3 = nn.BatchNorm1d(3).to(dtype=torch.float)
         self.relu4 = nn.ReLU(inplace=True).to(dtype=torch.float)
         # don't quantize sub2
         self.sub2.qconfig = None
@@ -641,6 +642,7 @@ def forward(self, x):
         x = x.squeeze(2)
         x = self.quant(x)
         x = self.conv3(x)
+        x = self.bn3(x)
         x = self.relu4(x)
         x = x.unsqueeze(2)
         y = x.unsqueeze(2)
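
The forward change threads the test data through the new 1d branch by collapsing a singleton spatial dimension before the Conv1d stack and restoring it afterwards. A quick shape walkthrough, with sizes assumed for illustration:

import torch

x = torch.randn(4, 3, 1, 8)   # assumed 4d activation with a singleton dim 2
x = x.squeeze(2)              # (4, 3, 8): the (N, C, L) layout Conv1d expects
# ... quant -> conv3 -> bn3 -> relu4 run in 1d here ...
x = x.unsqueeze(2)            # back to (4, 3, 1, 8) for the rest of the model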
