@@ -1120,6 +1120,49 @@ def test_eliminate_nop_dropout(self): # type: () -> None
11201120 assert len (optimized_model .graph .node ) == 1
11211121 assert optimized_model .graph .node [0 ].op_type == "Log"
11221122
1123+ def test_fuse_reduction_unsqueeze (self ): # type: () -> None
1124+ def _calculate_post_transform_shape (input_shape , reduction_axes , unsqueeze_axes , keepdim ): # type: (Tuple[int, ...], List[int], List[int], bool) -> Tuple[int, ...]
1125+ post_reduce_shape = None
1126+ if keepdim :
1127+ post_reduce_shape = tuple ([(x if i not in reduction_axes else 1 ) for i , x in enumerate (input_shape )])
1128+ else :
1129+ post_reduce_shape = tuple ([x for i , x in enumerate (input_shape ) if i not in reduction_axes ])
1130+ post_unsqueeze_shape = list (post_reduce_shape )
1131+ for ax in unsqueeze_axes :
1132+ post_unsqueeze_shape .insert (ax , 1 )
1133+ return tuple (post_unsqueeze_shape )
1134+
1135+ for reduction in ["ReduceL1" , "ReduceL2" , "ReduceLogSum" ,
1136+ "ReduceLogSumExp" , "ReduceMax" , "ReduceMean" ,
1137+ "ReduceMin" , "ReduceProd" , "ReduceSum" , "ReduceSumSquare" ]:
1138+ for axes1 in [[1 ], [1 , 2 ], [2 ]]:
1139+ for axes2 in [[1 ], [1 , 2 ], [2 ]]:
1140+ for keepdim in [False , True ]:
1141+ input_shape = (5 , 7 , 9 )
1142+ output_shape = _calculate_post_transform_shape (input_shape , axes1 , axes2 , keepdim ) # type: Tuple[int, ...]
1143+ node = helper .make_node (reduction , ["X" ], ["Y" ], axes = axes1 , keepdims = keepdim )
1144+ node1 = helper .make_node ("Unsqueeze" , ["Y" ], ["Z" ], axes = axes2 )
1145+ graph = helper .make_graph (
1146+ [node , node1 ],
1147+ "test" ,
1148+ [helper .make_tensor_value_info (
1149+ "X" , TensorProto .FLOAT , input_shape )],
1150+ [helper .make_tensor_value_info ("Z" , TensorProto .FLOAT , output_shape )])
1151+ optimized_model = self ._optimized (
1152+ graph , ["fuse_consecutive_reduce_unsqueeze" ], False )
1153+
1154+ if keepdim or axes1 != axes2 :
1155+ assert optimized_model .graph == graph
1156+ else :
1157+ assert len (optimized_model .graph .output ) == 1
1158+ assert len (optimized_model .graph .node ) == 1
1159+ assert optimized_model .graph .output [0 ].type .tensor_type .elem_type == TensorProto .FLOAT
1160+ assert optimized_model .graph .node [- 1 ].op_type == reduction
1161+ assert optimized_model .graph .node [- 1 ].attribute [0 ].name == "axes"
1162+ assert optimized_model .graph .node [- 1 ].attribute [0 ].ints == axes1
1163+ optimized_output_shape = tuple (x .dim_value for x in optimized_model .graph .output [0 ].type .tensor_type .shape .dim )
1164+ assert optimized_output_shape == output_shape
1165+
11231166
# Allow running this test module directly from the command line.
if __name__ == '__main__':
    unittest.main()
0 commit comments