Skip to content

Commit 9e32a1f

Browse files
suo and facebook-github-bot
authored and committed
[wip] update graph fuser aliasdb in-place (#37106)
Summary: Pull Request resolved: #37106

Recomputing the aliasdb on every fusion iteration + in every subblock is hugely expensive. Instead, update it in-place when doing fusion.

The graph fuser pass operates by pushing nodes into a fusion group. So we start with
```
x, y = f(a, b, c)
```
and end with:
```
x_out, y_out = prim::fusionGroup(a, b, c)
    x_in, y_in = f(a_in, b_in, c_in)
    -> x_in, y_in
```
We destroy the `x` and `y` `Value*`s in the process. This operation is easy to express as an update to the aliasDb--`x_out` just takes on all the aliasing information `x` used to have. In particular, since we know `f` and `prim::fusionGroup` are purely functional, we don't have to mess with any write information.

This PR is the bare minimum to get this working, in the interest of unscrewing the compilation times ASAP.

Followups I want to do:
- We don't have a way of expressing deletion of values in AliasDb. In `graph_fuser.cpp` we sometimes construct nodes that we end up throwing away, and we are littering `MemoryDAG` with references to dangling pointers. Because of the way the pass works, it's fine, but this is fragile so I want to fix it.
- We should decouple alias analysis from write tracking, to simplify the job of keeping the write caches consistent as we mutate the aliasing information.
- The tensorexpr fuser doesn't do this and thus is incorrect today; we need to update it to work.

Test Plan: Imported from OSS

Differential Revision: D21219179

Pulled By: suo

fbshipit-source-id: 8ae5397b3a0ad90edec2fbc555647091f1ad5284
1 parent 0692804 commit 9e32a1f

9 files changed

Lines changed: 202 additions & 54 deletions

File tree

test/cpp/jit/test_fuser.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,31 @@ void testFusion() {
176176
};
177177
}
178178

179+
void testFusionAliasing() {
180+
const auto graph_string = R"IR(
181+
graph(%0 : Tensor,
182+
%1 : Tensor):
183+
%12 : int = prim::Constant[value=1]()
184+
%2.1 : Tensor = aten::mul(%0, %1)
185+
%2 : Tensor = aten::mul(%2.1, %1)
186+
%3 : Tensor = aten::add_(%2, %1, %12)
187+
%4 : Tensor = aten::mul(%2, %1)
188+
%5 : Tensor = aten::add(%2, %4, %12)
189+
return (%5))IR";
190+
auto g = std::make_shared<Graph>();
191+
torch::jit::parseIR(graph_string, g.get());
192+
193+
g->lint();
194+
FuseGraph(g);
195+
196+
// We should not be able to fuse across the in-place operation here.
197+
testing::FileCheck()
198+
.check("prim::FusionGroup_0")
199+
->check("aten::add_")
200+
->check("prim::FusionGroup_1")
201+
->run(*g);
202+
}
203+
179204
void testRegisterFusionCachesKernel() {
180205
// Constructs two functionally equivalent graphs
181206
const auto graph0_string = R"IR(

test/cpp/jit/tests.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ namespace jit {
9090
_(LiteInterpreterParams) \
9191
_(LiteInterpreterSetState) \
9292
_(TorchbindIValueAPI) \
93-
_(LiteInterpreterDict)
93+
_(LiteInterpreterDict) \
94+
_(FusionAliasing)
9495

9596
#if defined(USE_CUDA)
9697
#define TH_FORALL_TESTS_CUDA(_) \

torch/csrc/jit/ir/alias_analysis.cpp

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ MemoryLocations AliasDb::getReads(Node* n) const {
324324
}
325325

326326
std::string AliasDb::getElementName(const Element* e) const {
327-
if (e->value == nullptr) {
327+
if (e->values.empty()) {
328328
// not the most efficient way, but given the fact there are
329329
// not too many types and even fewer of them will end up in
330330
// wildcardIndex_, we should be fine with a linear search
@@ -336,7 +336,17 @@ std::string AliasDb::getElementName(const Element* e) const {
336336
}
337337
return "WILDCARD";
338338
} else {
339-
return e->value->debugName();
339+
std::ostringstream ss;
340+
if (e->values.size() == 1) {
341+
ss << "%" << (*e->values.begin())->debugName();
342+
return ss.str();
343+
}
344+
ss << "(";
345+
for (const Value* v : e->values) {
346+
ss << "%" << v->debugName() << ", ";
347+
}
348+
ss << ")";
349+
return ss.str();
340350
}
341351
}
342352

@@ -1036,10 +1046,8 @@ void AliasDb::mapAliases(at::ArrayRef<Value*> from, at::ArrayRef<Value*> to) {
10361046
// The asserts are to guard against unintentional use.
10371047
// FIXME refactor aliasdb construction to be more robust to mutation so this
10381048
// hack isn't necessary.
1039-
void AliasDb::unsafeGiveFreshAlias(const Value* value) {
1049+
void AliasDb::createValue(const Value* value) {
10401050
TORCH_INTERNAL_ASSERT(isMutableTypeInternal(value->type()));
1041-
TORCH_INTERNAL_ASSERT(value->type()->containedTypes().size() == 0);
1042-
TORCH_INTERNAL_ASSERT(!elementMap_.count(value));
10431051
auto new_elem = memoryDAG_->unsafeMakeFreshValue(value);
10441052
elementMap_[value] = new_elem;
10451053
}
@@ -1068,14 +1076,39 @@ Element* AliasDb::getOrCreateElement(const Value* value) {
10681076
return elementMap_.at(value);
10691077
}
10701078

1071-
void AliasDb::replaceMemoryLocation(Value* existing, Value* new_value) {
1079+
void AliasDb::replaceWithNewValue(Value* existing, Value* new_value) {
1080+
TORCH_INTERNAL_ASSERT(
1081+
*unshapedType(existing->type()) == *unshapedType(new_value->type()),
1082+
"Types must be strictly equal if you are replacing aliasing information. ",
1083+
"Got existing: '",
1084+
existing->type()->python_str(),
1085+
"', new_value: '",
1086+
new_value->type()->python_str(),
1087+
"'");
10721088
if (!isMutableTypeInternal(existing)) {
10731089
return;
10741090
}
10751091
auto existing_elem = elementMap_.at(existing);
10761092
elementMap_[new_value] = existing_elem;
10771093
elementMap_.erase(existing);
1078-
existing_elem->value = new_value;
1094+
existing_elem->values = {new_value};
1095+
}
1096+
1097+
void AliasDb::copyValue(Value* from, Value* to) {
1098+
TORCH_INTERNAL_ASSERT(
1099+
*unshapedType(from->type()) == *unshapedType(to->type()),
1100+
"Types must be strictly equal if you are copying aliasing information. ",
1101+
"Got from: '",
1102+
from->type()->python_str(),
1103+
"', to: '",
1104+
to->type()->python_str(),
1105+
"'");
1106+
if (!isMutableTypeInternal(to)) {
1107+
return;
1108+
}
1109+
auto origElem = elementMap_.at(from);
1110+
elementMap_[to] = origElem;
1111+
origElem->values.insert(to);
10791112
}
10801113

10811114
bool AliasDb::moveAfterTopologicallyValid(Node* n, Node* movePoint) {
@@ -1492,5 +1525,30 @@ MemoryLocations AliasDb::buildWrittenToLocationsIndex() const {
14921525
return ret;
14931526
}
14941527

1528+
void Lint(const AliasDb* db) {
1529+
bool failed = false;
1530+
1531+
std::stringstream ss;
1532+
// Every mutable value in the system has a corresponding element.
1533+
for (const auto& v : db->graph_->all_values) {
1534+
if (!db->isMutableTypeInternal(v)) {
1535+
continue;
1536+
}
1537+
auto it = db->elementMap_.find(v);
1538+
if (it == db->elementMap_.end()) {
1539+
failed = true;
1540+
ss << "Value %" << v->debugName() << " of type "
1541+
<< v->type()->python_str() << " wasn't found in the element map.\n"
1542+
<< "It was defined in " << *v->node();
1543+
}
1544+
}
1545+
TORCH_INTERNAL_ASSERT(!failed, ss.str());
1546+
1547+
// Two checks that we want to add but can't until the mutation API is more
1548+
// fully developed.
1549+
// - Every mutable value in the aliasdb belongs to the graph
1550+
// - All container values have contained elements
1551+
}
1552+
14951553
} // namespace jit
14961554
} // namespace torch

torch/csrc/jit/ir/alias_analysis.h

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,24 @@ class AliasDb {
117117
static bool isMutableType(const Value* v);
118118
static bool isMutableType(const TypePtr& type);
119119

120+
/**
121+
* Mutation API
122+
*
123+
* These methods allow you to update AliasDb in-place if you are performing
124+
* graph mutation.
125+
*
126+
* WARNING: These methods should be considered INTERNAL. They do not perform
127+
* very many correctness checks; the user is responsible for making sure they
128+
* are updating AliasDb correctly. `Lint()`ing the AliasDb can help with
129+
* this.
130+
*/
131+
// Copy `existing`'s aliasing info to `new_value`, and remove `existing`.
132+
void replaceWithNewValue(Value* existing, Value* new_value);
133+
// Copy `from`'s aliasing info to `to`.
134+
void copyValue(Value* from, Value* to);
135+
// Create a new `value` that does not alias anything else.
136+
void createValue(const Value* value);
137+
120138
friend struct MutationRemover;
121139

122140
private:
@@ -187,17 +205,9 @@ class AliasDb {
187205
const Value* element,
188206
const Value* container);
189207
void mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from);
190-
void unsafeGiveFreshAlias(const Value* value);
191208
void giveFreshAlias(const Value* value);
192209
Element* getOrCreateElement(const Value* value);
193210

194-
// In the Value * -> Element * map replaces the mapping
195-
// of Value * existing -> Element * existing_elem with
196-
// Value * new_value -> Element * existing_elem
197-
// Callers are expected to maintain graph invariants & specify
198-
// own correctness conditions
199-
void replaceMemoryLocation(Value* existing, Value* new_value);
200-
201211
c10::optional<TypePtr> getMutableTypePtr(const TypePtr& type) const;
202212

203213
bool isContainerType(const TypePtr& type) const;
@@ -250,7 +260,14 @@ class AliasDb {
250260
std::unordered_set<const Value*> wildcards_;
251261

252262
std::string getElementName(const Element* e) const;
263+
264+
friend void Lint(const AliasDb* db);
253265
};
254266

267+
// Helper check that invariants over AliasDb are maintained.
268+
// Useful if you are using the AliasDb mutation API and want to check you did
269+
// the right thing.
270+
void Lint(const AliasDb* db);
271+
255272
} // namespace jit
256273
} // namespace torch

torch/csrc/jit/ir/ir.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ using pyobj_list = std::vector<THPObjectPtr>;
3232

3333
namespace torch {
3434
namespace jit {
35+
class AliasDb;
3536

3637
using ::c10::Argument;
3738
using ::c10::FunctionSchema;
@@ -1231,6 +1232,7 @@ struct Graph {
12311232
TORCH_API void remapTypes(const std::function<TypePtr(TypePtr)>& type_map);
12321233

12331234
private:
1235+
friend void Lint(const AliasDb* db);
12341236
TORCH_API void freeNode(Node* n);
12351237
TORCH_API void freeValue(Value* v);
12361238
TORCH_API void freeBlock(Block* b);

torch/csrc/jit/passes/create_functional_graphs.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,12 +378,12 @@ struct MutationRemover {
378378
// same aliasing relationships as the original x.
379379
// To avoid rebuilding the entire alias db, we can replace
380380
// the memory dag element of x with x0.
381-
aliasDb_->replaceMemoryLocation(mutated_value, new_node->output());
381+
aliasDb_->replaceWithNewValue(mutated_value, new_node->output());
382382

383383
// it is an invariant that all mutable types have an element in the memory
384384
// dag so we must regive x an alias db element. We have already verified
385385
// that the mutated value is a fresh alias with a single use.
386-
aliasDb_->unsafeGiveFreshAlias(mutated_value);
386+
aliasDb_->createValue(mutated_value);
387387

388388
// We must erase the destroyed node from the AliasDb lists of writes
389389
aliasDb_->writeIndex_->erase(node);

0 commit comments

Comments
 (0)