@@ -1417,29 +1417,28 @@ def forward(self, x):
14171417
14181418 # Note Fusion for functional Relu with inplace argument isn't currently supported in fusion patterns.
14191419 class BNFuncRelu (torch .nn .Module ):
1420- def __init__ (self , inplace ):
1420+ def __init__ (self ):
14211421 super (BNFuncRelu , self ).__init__ ()
14221422 self .bn = torch .nn .BatchNorm2d (3 ).to (torch .float )
14231423
14241424 def forward (self , x ):
14251425 return F .relu (self .bn (x ), False )
14261426
14271427 class BNFuncInplaceRelu (torch .nn .Module ):
1428- def __init__ (self , inplace ):
1428+ def __init__ (self ):
14291429 super (BNFuncInplaceRelu , self ).__init__ ()
14301430 self .bn = torch .nn .BatchNorm2d (3 ).to (torch .float )
14311431
14321432 def forward (self , x ):
14331433 return F .relu (self .bn (x ), True )
14341434
14351435 data = [(torch .rand ((1 , 3 , 10 , 10 ), dtype = torch .float ), torch .randint (0 , 1 , (1 ,), dtype = torch .long )) for _ in range (2 )]
1436- for Model in [BNRelu , BNFuncRelu , BNFuncInplaceRelu ]:
1437- for inplace in [True , False ]:
1438- model = self ._test_op_impl (Model (inplace ), data , "quantized::batch_norm2d_relu" )
1439- FileCheck ().check_not ("aten::batch_norm" ) \
1440- .check_not ("aten::relu" ) \
1441- .check_not ("aten::relu_" ) \
1442- .run (model .graph )
1436+ for instance in [BNRelu (True ), BNRelu (False ), BNFuncRelu (), BNFuncInplaceRelu ()]:
1437+ model = self ._test_op_impl (instance , data , "quantized::batch_norm2d_relu" )
1438+ FileCheck ().check_not ("aten::batch_norm" ) \
1439+ .check_not ("aten::relu" ) \
1440+ .check_not ("aten::relu_" ) \
1441+ .run (model .graph )
14431442
14441443
14451444 def test_swap_dequantize_all_ops (self ):
0 commit comments