Skip to content

Commit 1cbe274

Browse files
authored
fix the optimizer (#1510)
1 parent 481ad99 commit 1cbe274

1 file changed

Lines changed: 17 additions & 0 deletions

File tree

onnx/optimizer/passes/lift_lexical_references.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ namespace ONNX_NAMESPACE { namespace optimization {
6161
// unresolved_references.insert(input)
6262
// if node is a control flow operator:
6363
// for each sub-graph g:
64+
// for each output in g's body:
65+
// if output is defined in current scope:
66+
// control_inputs.insert(output)
6467
// refs = liftreferences(g)
6568
// for each ref in refs:
6669
// if ref is in this frame or any parent frame (control_inputs):
@@ -151,14 +154,28 @@ struct LiftLexicalReferences : public OptimizePass {
151154
}
152155

153156
std::set<std::string> local_unresolved;
157+
158+
//if a graph body output has already already been emitted outside of the
159+
//subgraph scope, then it must be added as an input to the subgraph
160+
auto add_subgraph_outputs = [&](Graph * body_graph) {
161+
for (auto *out: body_graph->outputs()) {
162+
if (environment_stack->findInAnyFrame(out->uniqueName())) {
163+
local_unresolved.insert(out->uniqueName());
164+
}
165+
}
166+
};
167+
154168
if (n->kind() == ONNX_NAMESPACE::kLoop) {
155169
auto *body_graph = n->g(ONNX_NAMESPACE::kbody).get();
156170
local_unresolved = liftReferences(body_graph);
171+
add_subgraph_outputs(body_graph);
157172
} else if (n->kind() == ONNX_NAMESPACE::kIf) {
158173
auto *then_graph = n->g(ONNX_NAMESPACE::kthen_branch).get();
174+
add_subgraph_outputs(then_graph);
159175
auto then_unresolved = liftReferences(then_graph);
160176
local_unresolved.insert(then_unresolved.begin(), then_unresolved.end());
161177
auto *else_graph = n->g(ONNX_NAMESPACE::kelse_branch).get();
178+
add_subgraph_outputs(else_graph);
162179
auto else_unresolved = liftReferences(else_graph);
163180
local_unresolved.insert(else_unresolved.begin(), else_unresolved.end());
164181
}

0 commit comments

Comments
 (0)