@@ -1303,24 +1303,25 @@ def checkQuantized(model):
13031303 self .assertEqual (type (model .sub2 .conv ), nn .Conv2d )
13041304 self .assertEqual (type (model .sub2 .relu ), nn .ReLU )
13051305 test_only_eval_fn (model , self .img_data_1d )
1306- checkQuantized (model )
1306+ with self .assertRaisesRegex (RuntimeError , "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'" ):
1307+ checkQuantized (model )
13071308
13081309 model = ModelForFusion (default_qat_qconfig ).train ()
13091310 model = fuse_modules (model , [['conv1' , 'bn1' , 'relu1' ],
13101311 ['sub1.conv' , 'sub1.bn' ]])
13111312 model = quantize_qat (model , test_only_train_fn , self .img_data_1d )
1312- checkQuantized (model )
1313+ with self .assertRaisesRegex (RuntimeError , "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'" ):
1314+ checkQuantized (model )
13131315
13141316
13151317 def test_fuse_module_eval (self ):
13161318 model = ModelForFusion (default_qconfig )
13171319 model .eval ()
1318- model = fuse_modules (model , [['conv3' , 'relu4' ],
1320+ model = fuse_modules (model , [['conv3' , 'bn3' , ' relu4' ],
13191321 ['conv1' , 'bn1' , 'relu1' ],
13201322 ['conv2' , 'relu2' ],
13211323 ['bn2' , 'relu3' ],
13221324 ['sub1.conv' , 'sub1.bn' ]])
1323-
13241325 self .assertEqual (type (model .conv1 ), nni .ConvReLU2d ,
13251326 "Fused Conv + BN + Relu first layer (BN is folded)" )
13261327 self .assertEqual (type (model .conv1 [0 ]), nn .Conv2d ,
@@ -1345,11 +1346,13 @@ def test_fuse_module_eval(self):
13451346 "Fused Conv + BN + Relu second layer (Skipped Relu)" )
13461347
13471348 self .assertEqual (type (model .conv3 ), nni .ConvReLU1d ,
1348- "Fused Conv + Relu for conv1d " )
1349+ "Fused Conv + Relu for Conv1d (folded BN) " )
13491350 self .assertEqual (type (model .conv3 [0 ]), nn .Conv1d ,
1350- "Fused Conv + Relu for conv1d " )
1351+ "Fused Conv + Relu for Conv1d " )
13511352 self .assertEqual (type (model .conv3 [1 ]), nn .ReLU ,
1352- "Fused Conv + Relu for conv1d" )
1353+ "Fused Conv + Relu for Conv1d" )
1354+ self .assertEqual (type (model .bn3 ), nn .Identity ,
1355+ "Fused Conv + BN + Relu for Conv1d (Skipped BN)" )
13531356
13541357 self .assertEqual (type (model .sub1 .conv ), nn .Conv2d ,
13551358 "Fused submodule Conv + folded BN" )
@@ -1383,7 +1386,7 @@ def checkQuantized(model):
13831386 ['conv2' , 'relu2' ],
13841387 ['bn2' , 'relu3' ],
13851388 ['sub1.conv' , 'sub1.bn' ],
1386- ['conv3' , 'relu4' ]])
1389+ ['conv3' , 'bn3' , ' relu4' ]])
13871390 model = quantize (model , test_only_eval_fn , self .img_data_1d )
13881391 checkQuantized (model )
13891392
0 commit comments