@@ -653,6 +653,116 @@ def test_fuse_add_bias_into_conv_squeeze_4d_bias_no_fuse(self): # type: () -> N
653653 assert optimized_model .graph .node [0 ].op_type == 'Conv'
654654 assert optimized_model .graph .node [1 ].op_type == 'Add'
655655
def test_fuse_matmul_add_bias_into_gemm(self):  # type: () -> None
    """MatMul followed by Add with a 1-D bias of shape (16,) must become one Gemm."""
    graph_inputs = [
        helper.make_tensor_value_info("X", TensorProto.FLOAT, (32, 10)),
        helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)),
        helper.make_tensor_value_info("B", TensorProto.FLOAT, (16,)),
    ]
    graph_outputs = [
        helper.make_tensor_value_info("A", TensorProto.FLOAT, (32, 16)),
    ]
    matmul_node = helper.make_node("MatMul", ["X", "Y"], ["Z"])
    add_node = helper.make_node("Add", ["Z", "B"], ["A"])
    graph = helper.make_graph(
        [matmul_node, add_node], "test", graph_inputs, graph_outputs)

    optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])

    # Exactly one node may remain after the pass, and it has to be the Gemm.
    remaining_ops = [node.op_type for node in optimized_model.graph.node]
    assert remaining_ops == ["Gemm"]
671+
def test_fuse_matmul_add_bias_into_gemm_2d_bias(self):  # type: () -> None
    """A 2-D bias of shape (1, 16) added to the MatMul output must still fuse to Gemm."""
    pass_names = ["fuse_matmul_add_bias_into_gemm"]
    value_info = helper.make_tensor_value_info

    product = helper.make_node("MatMul", ["X", "Y"], ["Z"])
    biased = helper.make_node("Add", ["Z", "B"], ["A"])
    graph = helper.make_graph(
        [product, biased],
        "test",
        [
            value_info("X", TensorProto.FLOAT, (32, 10)),
            value_info("Y", TensorProto.FLOAT, (10, 16)),
            value_info("B", TensorProto.FLOAT, (1, 16)),
        ],
        [value_info("A", TensorProto.FLOAT, (32, 16))],
    )
    optimized_model = self._optimized(graph, pass_names)

    # The two original nodes are expected to be replaced by a single Gemm.
    op_types = [node.op_type for node in optimized_model.graph.node]
    assert op_types == ["Gemm"]
687+
def test_fuse_matmul_add_bias_into_gemm_2d_bias_same_shape(self):  # type: () -> None
    """A bias declared with the full output shape (32, 16) must also fuse into Gemm."""
    float_type = TensorProto.FLOAT
    inputs = [
        helper.make_tensor_value_info(name, float_type, shape)
        for name, shape in (("X", (32, 10)), ("Y", (10, 16)), ("B", (32, 16)))
    ]
    outputs = [helper.make_tensor_value_info("A", float_type, (32, 16))]
    nodes = [
        helper.make_node("MatMul", ["X", "Y"], ["Z"]),
        helper.make_node("Add", ["Z", "B"], ["A"]),
    ]
    graph = helper.make_graph(nodes, "test", inputs, outputs)

    optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])

    # Only the fused Gemm node should be left in the optimized graph.
    fused_ops = [node.op_type for node in optimized_model.graph.node]
    assert fused_ops == ["Gemm"]
703+
def test_fuse_matmul_add_bias_into_gemm_2d_bias_bcast_no_fuse(self):  # type: () -> None
    """The pass must leave the graph untouched for these shapes.

    The declared shapes give a (1, 16) MatMul result and a (16, 16) bias,
    so the bias is larger than the MatMul output it is added to.
    """
    graph_inputs = [
        helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 10)),
        helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)),
        helper.make_tensor_value_info("B", TensorProto.FLOAT, (16, 16)),
    ]
    graph_outputs = [
        helper.make_tensor_value_info("A", TensorProto.FLOAT, (16, 16)),
    ]
    matmul_node = helper.make_node("MatMul", ["X", "Y"], ["Z"])
    add_node = helper.make_node("Add", ["Z", "B"], ["A"])
    graph = helper.make_graph(
        [matmul_node, add_node], "test", graph_inputs, graph_outputs)

    optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])

    # No fusion expected: the optimized graph must equal the input graph.
    assert optimized_model.graph == graph
718+
def test_fuse_matmul_add_bias_into_gemm_3d_matmul_no_fuse(self):  # type: () -> None
    """A MatMul over 3-D operands followed by Add must not be rewritten to Gemm."""
    value_info = helper.make_tensor_value_info
    pass_names = ["fuse_matmul_add_bias_into_gemm"]

    # Both MatMul operands are declared as rank-3 tensors.
    graph = helper.make_graph(
        [
            helper.make_node("MatMul", ["X", "Y"], ["Z"]),
            helper.make_node("Add", ["Z", "B"], ["A"]),
        ],
        "test",
        [
            value_info("X", TensorProto.FLOAT, (2, 3, 4)),
            value_info("Y", TensorProto.FLOAT, (2, 4, 3)),
            value_info("B", TensorProto.FLOAT, (3, 3)),
        ],
        [value_info("A", TensorProto.FLOAT, (2, 3, 3))],
    )
    optimized_model = self._optimized(graph, pass_names)

    # The pass is expected to be a no-op here.
    assert optimized_model.graph == graph
733+
def test_fuse_matmul_add_bias_into_gemm_3d_bias_no_fuse(self):  # type: () -> None
    """A rank-3 bias added to a 2-D MatMul result must not be fused into Gemm.

    The MatMul of (32, 10) by (10, 16) yields (32, 16); adding a
    (4, 1, 16) bias broadcasts the sum to (4, 32, 16), so the graph
    output is declared with that shape to keep the fixture consistent.
    """
    matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
    add = helper.make_node("Add", ["Z", "B"], ["A"])
    graph = helper.make_graph(
        [matmul, add],
        "test",
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, (32, 10)),
         helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)),
         helper.make_tensor_value_info("B", TensorProto.FLOAT, (4, 1, 16))],
        # Broadcast of (32, 16) with (4, 1, 16) gives (4, 32, 16), not (32, 16).
        [helper.make_tensor_value_info("A", TensorProto.FLOAT, (4, 32, 16))]
    )
    optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])

    # No fusion expected: the optimized graph must equal the input graph.
    assert optimized_model.graph == graph
748+
def test_fuse_matmul_add_bias_into_gemm_multiple_use_no_fuse(self):  # type: () -> None
    """When the MatMul output "Z" feeds both an Add and an Identity, nothing is fused."""
    graph_inputs = [
        helper.make_tensor_value_info("X", TensorProto.FLOAT, (32, 10)),
        helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)),
        helper.make_tensor_value_info("B", TensorProto.FLOAT, (1, 16)),
    ]
    graph_outputs = [
        helper.make_tensor_value_info("A1", TensorProto.FLOAT, (32, 16)),
        helper.make_tensor_value_info("A2", TensorProto.FLOAT, (32, 16)),
    ]
    matmul_node = helper.make_node("MatMul", ["X", "Y"], ["Z"])
    passthrough_node = helper.make_node("Identity", ["Z"], ["A1"])
    add_node = helper.make_node("Add", ["Z", "B"], ["A2"])
    # Node order in the graph is MatMul, Add, Identity.
    graph = helper.make_graph(
        [matmul_node, add_node, passthrough_node],
        "test",
        graph_inputs,
        graph_outputs,
    )

    optimized_model = self._optimized(graph, ["fuse_matmul_add_bias_into_gemm"])

    # "Z" has a second consumer, so the graph must come back unchanged.
    assert optimized_model.graph == graph
765+
656766 def test_fuse_pad_into_conv (self ): # type: () -> None
657767 pad = helper .make_node (
658768 "Pad" ,
0 commit comments