@@ -743,7 +743,7 @@ TEST(OptimizeConcatTest, UseVariadicCatWithListMutationAfterCat) {
743743 %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
744744 %10 : int = prim::Constant[value=0]()
745745 %input : Tensor[] = prim::ListConstruct(%0, %1)
746- %concat : Float(256 , 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
746+ %concat : Float(96 , 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
747747 %11 : Tensor = aten::append(%input, %2)
748748 return (%concat, %input)
749749 )IR" ;
@@ -789,7 +789,7 @@ TEST(OptimizeConcatTest, UseVariadicCatWithListMutationBeforeCat) {
789789 %10 : int = prim::Constant[value=0]()
790790 %input : Tensor[] = prim::ListConstruct(%0, %1)
791791 %11 : Tensor = aten::append(%input, %2)
792- %concat : Float(256 , 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
792+ %concat : Float(128 , 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
793793 return (%concat)
794794 )IR" ;
795795 parseIR (input, graph.get ());
@@ -799,17 +799,99 @@ TEST(OptimizeConcatTest, UseVariadicCatWithListMutationBeforeCat) {
799799 at::rand ({32 , 56 , 56 }, at::kCPU )};
800800 auto orig_outputs = runGraph (graph, inputs);
801801
802- ASSERT_FALSE (UseVariadicCat (graph));
802+ {
803+ ASSERT_FALSE (UseVariadicCat (graph));
804+ graph->lint ();
805+ auto opt_outputs = runGraph (graph, inputs);
806+ checkOutputs (orig_outputs, opt_outputs);
807+
808+ // No transformation should have happened since the `prim::ListConstruct` is
809+ // mutated before `aten::cat`.
810+ testing::FileCheck ()
811+ .check_count (" = prim::ListConstruct(" , 1 , /* exactly*/ true )
812+ ->check_count (" = aten::cat(" , 1 , /* exactly*/ true )
813+ ->check_count (" = prim::Concat(" , 0 , /* exactly*/ true )
814+ ->run (*graph);
815+ }
816+
817+ {
818+ ASSERT_TRUE (RemoveListMutationAndUseVariadicCat (graph));
819+ graph->lint ();
820+ auto opt_outputs = runGraph (graph, inputs);
821+ checkOutputs (orig_outputs, opt_outputs);
822+
823+ // The mutation of the list must be removed and the `aten::cat` op must
824+ // be replaced with the `prim::Concat` op in the graph. The transformed
825+ // graph should look like the following:
826+ //
827+ // graph(%0 : ...,
828+ // %1 : ...,
829+ // %2 : ...):
830+ // %3 : int = prim::Constant[value=0]()
831+ // %7 : Tensor = prim::Concat(%0, %1, %2, %3)
832+ // return (%7)
833+ testing::FileCheck ()
834+ .check_count (" = prim::Concat(" , 1 , /* exactly*/ true )
835+ ->check_count (" = prim::ListConstruct(" , 0 , /* exactly*/ true )
836+ ->check_count (" = aten::cat(" , 0 , /* exactly*/ true )
837+ ->run (*graph);
838+ }
839+ }
840+
841+ TEST (OptimizeConcatTest, UseVariadicCatWithMultipleListMutations) {
842+ auto graph = std::make_shared<Graph>();
843+
844+ const std::string input =
845+ R"IR(
846+ graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
847+ %1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
848+ %2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
849+ %3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
850+ %4: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
851+ %10 : int = prim::Constant[value=0]()
852+ %input : Tensor[] = prim::ListConstruct(%0, %1)
853+ %concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
854+ %11 : Tensor = aten::append(%input, %2)
855+ %concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
856+ %12 : Tensor = aten::append(%input, %3)
857+ %concat.3 : Float(160, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
858+ %13 : Tensor = aten::append(%input, %4)
859+ %concat.4 : Float(192, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
860+ return (%concat.1, %concat.2, %concat.3, %concat.4)
861+ )IR" ;
862+ parseIR (input, graph.get ());
863+ std::vector<at::Tensor> inputs = {
864+ at::rand ({64 , 56 , 56 }, at::kCPU ),
865+ at::rand ({32 , 56 , 56 }, at::kCPU ),
866+ at::rand ({32 , 56 , 56 }, at::kCPU ),
867+ at::rand ({32 , 56 , 56 }, at::kCPU ),
868+ at::rand ({32 , 56 , 56 }, at::kCPU )};
869+ auto orig_outputs = runGraph (graph, inputs);
870+
871+ ASSERT_TRUE (RemoveListMutationAndUseVariadicCat (graph));
803872 graph->lint ();
804873 auto opt_outputs = runGraph (graph, inputs);
805874 checkOutputs (orig_outputs, opt_outputs);
806875
807- // No transformation should have happened since the `prim::ListConstruct` is
808- // mutated before `aten::cat`.
876+ // All the mutations of the list must be removed and the `aten::cat` ops must
877+ // be replaced with `prim::Concat` ops in the graph. The transformed graph
878+ // should look like the following:
879+ //
880+ // graph(%0 : ...,
881+ // %1 : ...,
882+ // %2 : ...,
883+ // %3 : ...,
884+ // %4 : ...):
885+ // %10 : int = prim::Constant[value=0]()
886+ // %5 : Tensor = prim::Concat(%0, %1, %10)
887+ // %6 : Tensor = prim::Concat(%0, %1, %2, %10)
888+ // %7 : Tensor = prim::Concat(%0, %1, %2, %3, %10)
889+ // %8 : Tensor = prim::Concat(%0, %1, %2, %3, %4, %10)
890+ // return (%5, %6, %7, %8)
809891 testing::FileCheck ()
810- .check_count (" = prim::ListConstruct(" , 1 , /* exactly*/ true )
811- ->check_count (" = aten::cat(" , 1 , /* exactly*/ true )
812- ->check_count (" = prim::Concat(" , 0 , /* exactly*/ true )
892+ .check_count (" = prim::Concat(" , 4 , /* exactly*/ true )
893+ ->check_count (" = prim::ListConstruct(" , 0 , /* exactly*/ true )
894+ ->check_count (" = aten::cat(" , 0 , /* exactly*/ true )
813895 ->run (*graph);
814896}
815897
0 commit comments