Skip to content

Commit 53668f8

Browse files
navahgarfacebook-github-bot
authored andcommitted
[jit] Added an API to remove list mutations and replace with variadic cat until fixed point (#60776)
Summary: Pull Request resolved: #60776 Test Plan: Imported from OSS Reviewed By: eellison Differential Revision: D29406099 Pulled By: navahgar fbshipit-source-id: e2e69eb6ebff3bc6e25d80f46ce118e52f557fb6
1 parent 0cfcf68 commit 53668f8

3 files changed

Lines changed: 105 additions & 8 deletions

File tree

test/cpp/jit/test_concat_opt.cpp

Lines changed: 90 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,7 @@ TEST(OptimizeConcatTest, UseVariadicCatWithListMutationAfterCat) {
743743
%2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
744744
%10 : int = prim::Constant[value=0]()
745745
%input : Tensor[] = prim::ListConstruct(%0, %1)
746-
%concat : Float(256, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
746+
%concat : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
747747
%11 : Tensor = aten::append(%input, %2)
748748
return (%concat, %input)
749749
)IR";
@@ -789,7 +789,7 @@ TEST(OptimizeConcatTest, UseVariadicCatWithListMutationBeforeCat) {
789789
%10 : int = prim::Constant[value=0]()
790790
%input : Tensor[] = prim::ListConstruct(%0, %1)
791791
%11 : Tensor = aten::append(%input, %2)
792-
%concat : Float(256, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
792+
%concat : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
793793
return (%concat)
794794
)IR";
795795
parseIR(input, graph.get());
@@ -799,17 +799,99 @@ TEST(OptimizeConcatTest, UseVariadicCatWithListMutationBeforeCat) {
799799
at::rand({32, 56, 56}, at::kCPU)};
800800
auto orig_outputs = runGraph(graph, inputs);
801801

802-
ASSERT_FALSE(UseVariadicCat(graph));
802+
{
803+
ASSERT_FALSE(UseVariadicCat(graph));
804+
graph->lint();
805+
auto opt_outputs = runGraph(graph, inputs);
806+
checkOutputs(orig_outputs, opt_outputs);
807+
808+
// No transformation should have happened since the `prim::ListConstruct` is
809+
// mutated before `aten::cat`.
810+
testing::FileCheck()
811+
.check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
812+
->check_count("= aten::cat(", 1, /*exactly*/ true)
813+
->check_count("= prim::Concat(", 0, /*exactly*/ true)
814+
->run(*graph);
815+
}
816+
817+
{
818+
ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph));
819+
graph->lint();
820+
auto opt_outputs = runGraph(graph, inputs);
821+
checkOutputs(orig_outputs, opt_outputs);
822+
823+
// The mutation of the list must be removed and the `aten::cat` op must
824+
// be replaced with the `prim::Concat` op in the graph. The transformed
825+
// graph should look like the following:
826+
//
827+
// graph(%0 : ...,
828+
// %1 : ...,
829+
// %2 : ...):
830+
// %3 : int = prim::Constant[value=0]()
831+
// %7 : Tensor = prim::Concat(%0, %1, %2, %3)
832+
// return (%7)
833+
testing::FileCheck()
834+
.check_count("= prim::Concat(", 1, /*exactly*/ true)
835+
->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
836+
->check_count("= aten::cat(", 0, /*exactly*/ true)
837+
->run(*graph);
838+
}
839+
}
840+
841+
TEST(OptimizeConcatTest, UseVariadicCatWithMultipleListMutations) {
842+
auto graph = std::make_shared<Graph>();
843+
844+
const std::string input =
845+
R"IR(
846+
graph(%0: Float(64, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
847+
%1: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
848+
%2: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
849+
%3: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu),
850+
%4: Float(32, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu)):
851+
%10 : int = prim::Constant[value=0]()
852+
%input : Tensor[] = prim::ListConstruct(%0, %1)
853+
%concat.1 : Float(96, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
854+
%11 : Tensor = aten::append(%input, %2)
855+
%concat.2 : Float(128, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
856+
%12 : Tensor = aten::append(%input, %3)
857+
%concat.3 : Float(160, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
858+
%13 : Tensor = aten::append(%input, %4)
859+
%concat.4 : Float(192, 56, 56, strides=[3136, 56, 1], requires_grad=0, device=cpu) = aten::cat(%input, %10)
860+
return (%concat.1, %concat.2, %concat.3, %concat.4)
861+
)IR";
862+
parseIR(input, graph.get());
863+
std::vector<at::Tensor> inputs = {
864+
at::rand({64, 56, 56}, at::kCPU),
865+
at::rand({32, 56, 56}, at::kCPU),
866+
at::rand({32, 56, 56}, at::kCPU),
867+
at::rand({32, 56, 56}, at::kCPU),
868+
at::rand({32, 56, 56}, at::kCPU)};
869+
auto orig_outputs = runGraph(graph, inputs);
870+
871+
ASSERT_TRUE(RemoveListMutationAndUseVariadicCat(graph));
803872
graph->lint();
804873
auto opt_outputs = runGraph(graph, inputs);
805874
checkOutputs(orig_outputs, opt_outputs);
806875

807-
// No transformation should have happened since the `prim::ListConstruct` is
808-
// mutated before `aten::cat`.
876+
// All the mutations of the list must be removed and the `aten::cat` ops must
877+
// be replaced with `prim::Concat` ops in the graph. The transformed graph
878+
// should look like the following:
879+
//
880+
// graph(%0 : ...,
881+
// %1 : ...,
882+
// %2 : ...,
883+
// %3 : ...,
884+
// %4 : ...):
885+
// %10 : int = prim::Constant[value=0]()
886+
// %5 : Tensor = prim::Concat(%0, %1, %10)
887+
// %6 : Tensor = prim::Concat(%0, %1, %2, %10)
888+
// %7 : Tensor = prim::Concat(%0, %1, %2, %3, %10)
889+
// %8 : Tensor = prim::Concat(%0, %1, %2, %3, %4, %10)
890+
// return (%5, %6, %7, %8)
809891
testing::FileCheck()
810-
.check_count("= prim::ListConstruct(", 1, /*exactly*/ true)
811-
->check_count("= aten::cat(", 1, /*exactly*/ true)
812-
->check_count("= prim::Concat(", 0, /*exactly*/ true)
892+
.check_count("= prim::Concat(", 4, /*exactly*/ true)
893+
->check_count("= prim::ListConstruct(", 0, /*exactly*/ true)
894+
->check_count("= aten::cat(", 0, /*exactly*/ true)
813895
->run(*graph);
814896
}
815897

torch/csrc/jit/passes/concat_opt.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <torch/csrc/jit/jit_log.h>
1111
#include <torch/csrc/jit/passes/constant_pooling.h>
1212
#include <torch/csrc/jit/passes/dead_code_elimination.h>
13+
#include <torch/csrc/jit/passes/remove_mutation.h>
1314

1415
namespace torch {
1516
namespace jit {
@@ -590,5 +591,16 @@ bool UseVariadicCat(const std::shared_ptr<Graph>& graph) {
590591
return changed;
591592
}
592593

594+
bool RemoveListMutationAndUseVariadicCat(const std::shared_ptr<Graph>& graph) {
595+
bool changed_in_last_iter = true;
596+
bool changed = false;
597+
while (changed_in_last_iter) {
598+
changed_in_last_iter = RemoveListMutation(graph);
599+
changed_in_last_iter = changed_in_last_iter || UseVariadicCat(graph);
600+
changed = changed || changed_in_last_iter;
601+
}
602+
return changed;
603+
}
604+
593605
} // namespace jit
594606
} // namespace torch

torch/csrc/jit/passes/concat_opt.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,8 @@ TORCH_API void OptimizeConcat(const std::shared_ptr<Graph>& graph);
2222
// Returns true if the graph is modified.
2323
TORCH_API bool UseVariadicCat(const std::shared_ptr<Graph>& graph);
2424

25+
// Repeatedly runs `RemoveListMutation` and `UseVariadicCat` on the graph until
// neither of them changes it any further.
// Returns true if the graph is modified.
TORCH_API bool RemoveListMutationAndUseVariadicCat(
26+
const std::shared_ptr<Graph>& graph);
27+
2528
} // namespace jit
2629
} // namespace torch

0 commit comments

Comments
 (0)