@@ -61,6 +61,9 @@ namespace ONNX_NAMESPACE { namespace optimization {
6161// unresolved_references.insert(input)
6262// if node is a control flow operator:
6363// for each sub-graph g:
64+ // for each output in g's body:
65+ // if output is defined in current scope:
66+ // control_inputs.insert(output)
6467// refs = liftreferences(g)
6568// for each ref in refs:
6669// if ref is in this frame or any parent frame (control_inputs):
@@ -151,14 +154,28 @@ struct LiftLexicalReferences : public OptimizePass {
151154 }
152155
153156 std::set<std::string> local_unresolved;
157+
158+ // if a graph body output has already already been emitted outside of the
159+ // subgraph scope, then it must be added as an input to the subgraph
160+ auto add_subgraph_outputs = [&](Graph * body_graph) {
161+ for (auto *out: body_graph->outputs ()) {
162+ if (environment_stack->findInAnyFrame (out->uniqueName ())) {
163+ local_unresolved.insert (out->uniqueName ());
164+ }
165+ }
166+ };
167+
154168 if (n->kind () == ONNX_NAMESPACE::kLoop ) {
155169 auto *body_graph = n->g (ONNX_NAMESPACE::kbody).get ();
156170 local_unresolved = liftReferences (body_graph);
171+ add_subgraph_outputs (body_graph);
157172 } else if (n->kind () == ONNX_NAMESPACE::kIf ) {
158173 auto *then_graph = n->g (ONNX_NAMESPACE::kthen_branch).get ();
174+ add_subgraph_outputs (then_graph);
159175 auto then_unresolved = liftReferences (then_graph);
160176 local_unresolved.insert (then_unresolved.begin (), then_unresolved.end ());
161177 auto *else_graph = n->g (ONNX_NAMESPACE::kelse_branch).get ();
178+ add_subgraph_outputs (else_graph);
162179 auto else_unresolved = liftReferences (else_graph);
163180 local_unresolved.insert (else_unresolved.begin (), else_unresolved.end ());
164181 }
0 commit comments