Commit ab6b942

Relax IF's shape inference rule (#2345)
* Relax If's shape inference rule
* Make shape inference tests ok and move code to the right place
* Add document changes
* Update onnx/defs/controlflow/defs.cc
* Update onnx/defs/controlflow/defs.cc
* Address comments
* Address comments
* Address comments
* Fix shape inference test
* Disable a type check
* Address a comment
* Update defs.cc
* Update Changelog.md
* Update Operators.md
1 parent c5af774 commit ab6b942

8 files changed: 301 additions & 64 deletions

docs/Changelog.md (40 additions & 0 deletions)

```diff
@@ -11480,6 +11480,46 @@ This version of the operator has been available since version 11 of the default
 <dd>Constrain input and output types to float tensors.</dd>
 </dl>
 
+### <a name="If-11"></a>**If-11**</a>
+
+  If conditional
+
+#### Version
+
+This version of the operator has been available since version 11 of the default ONNX operator set.
+
+#### Attributes
+
+<dl>
+<dt><tt>else_branch</tt> : graph (required)</dt>
+<dd>Graph to run if condition is false. Has N outputs: values you wish to be live-out to the enclosing scope. The number of outputs must match the number of outputs in the then_branch.</dd>
+<dt><tt>then_branch</tt> : graph (required)</dt>
+<dd>Graph to run if condition is true. Has N outputs: values you wish to be live-out to the enclosing scope. The number of outputs must match the number of outputs in the else_branch.</dd>
+</dl>
+
+#### Inputs
+
+<dl>
+<dt><tt>cond</tt> : B</dt>
+<dd>Condition for the if</dd>
+</dl>
+
+#### Outputs (1 - &#8734;)
+
+<dl>
+<dt><tt>outputs</tt> (variadic, heterogeneous) : V</dt>
+<dd>Values that are live-out to the enclosing scope. The return values in the `then_branch` and `else_branch` must be of the same data type. The `then_branch` and `else_branch` may produce tensors with the same element type and different shapes. If corresponding outputs from the then-branch and the else-branch have static shapes S1 and S2, then the shape of the corresponding output variable of the if-node (if present) must be compatible with both S1 and S2 as it represents the union of both possible shapes. For example, if in a model file the first output of `then_branch` is a float tensor with shape [2] and the first output of `else_branch` is another float tensor with shape [3], If's first output should have (a) no shape set, or (b) a shape of rank 1 with neither `dim_value` nor `dim_param` set, or (c) a shape of rank 1 with a unique `dim_param`. In contrast, the first output cannot have the shape [2] since [2] and [3] are not compatible.</dd>
+</dl>
+
+#### Type Constraints
+
+<dl>
+<dt><tt>V</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128)</dt>
+<dd>All Tensor types</dd>
+<dt><tt>B</tt> : tensor(bool)</dt>
+<dd>Only bool</dd>
+</dl>
+
 ### <a name="LogSoftmax-11"></a>**LogSoftmax-11**</a>
 
   The operator computes the logsoftmax (log of softmax) values for each layer in the batch
```
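The compatibility rule in the new `If-11` output description can be sketched in a few lines of Python. The `is_compatible` helper below is hypothetical (not part of ONNX); it models a shape as a list where an integer is a `dim_value`, a string is a `dim_param`, a `None` entry is a dim with neither set, and `None` itself means "no shape set":

```python
def is_compatible(declared, branch_shape):
    """Can `declared` (the If node's declared output shape) represent
    `branch_shape` (a concrete branch output shape)?"""
    if declared is None:                     # option (a): no shape set
        return True
    if len(declared) != len(branch_shape):   # rank must match
        return False
    # each declared dim must be unknown, symbolic, or exactly equal
    return all(d is None or isinstance(d, str) or d == b
               for d, b in zip(declared, branch_shape))

# then_branch produces [2], else_branch produces [3]:
assert is_compatible(None, [2])                               # (a) no shape
assert is_compatible([None], [2]) and is_compatible([None], [3])  # (b)
assert is_compatible(["N"], [2]) and is_compatible(["N"], [3])    # (c) dim_param
assert not is_compatible([2], [3])   # [2] is rejected: 2 != 3
```

This mirrors why the declared shape `[2]` is illegal when the branches yield `[2]` and `[3]`: it matches one branch but not the other.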

docs/Operators.md (4 additions & 2 deletions)

```diff
@@ -6415,7 +6415,9 @@ expect(node, inputs=[data], outputs=[data],
 
 #### Version
 
-This version of the operator has been available since version 1 of the default ONNX operator set.
+This version of the operator has been available since version 11 of the default ONNX operator set.
+
+Other versions of this operator: <a href="Changelog.md#If-1">If-1</a>
 
 #### Attributes
 
@@ -6437,7 +6439,7 @@ This version of the operator has been available since version 1 of the default O
 
 <dl>
 <dt><tt>outputs</tt> (variadic, heterogeneous) : V</dt>
-<dd>Values that are live-out to the enclosing scope. The return values in the `then_branch` and `else_branch` must be of the same shape and same data type.</dd>
+<dd>Values that are live-out to the enclosing scope. The return values in the `then_branch` and `else_branch` must be of the same data type. The `then_branch` and `else_branch` may produce tensors with the same element type and different shapes. If corresponding outputs from the then-branch and the else-branch have static shapes S1 and S2, then the shape of the corresponding output variable of the if-node (if present) must be compatible with both S1 and S2 as it represents the union of both possible shapes. For example, if in a model file the first output of `then_branch` is a float tensor with shape [2] and the first output of `else_branch` is another float tensor with shape [3], If's first output should have (a) no shape set, or (b) a shape of rank 1 with neither `dim_value` nor `dim_param` set, or (c) a shape of rank 1 with a unique `dim_param`. In contrast, the first output cannot have the shape [2] since [2] and [3] are not compatible.</dd>
 </dl>
 
 #### Type Constraints
```

onnx/defs/controlflow/defs.cc (18 additions & 7 deletions)

```diff
@@ -261,10 +261,8 @@ void IfInferenceFunction(InferenceContext& ctx) {
             else_elem_type);
       }
 
-      // merge the 'else' shape information to check it's consistent and
-      // augment the 'if' output if possible
-      mergeInShapeInfo(
-          else_output->tensor_type(), *if_output->mutable_tensor_type());
+      UnionShapeInfo(
+          else_output->tensor_type().shape(), *if_output->mutable_tensor_type());
     }
   }
 }
@@ -382,16 +380,29 @@ void LoopInferenceFunction(InferenceContext& ctx) {
 
 ONNX_OPERATOR_SET_SCHEMA(
     If,
-    1,
+    11,
     OpSchema()
         .SetDoc("If conditional")
         .Input(0, "cond", "Condition for the if", "B")
         .Output(
             0,
             "outputs",
             "Values that are live-out to the enclosing scope. The return values in "
-            "the `then_branch` and `else_branch` must be of the same shape and same "
-            "data type.",
+            "the `then_branch` and `else_branch` must be of the same data type. "
+            "The `then_branch` and `else_branch` may produce tensors with the same "
+            "element type and different shapes. "
+            "If corresponding outputs from the then-branch and the else-branch have "
+            "static shapes S1 and S2, then the shape of the corresponding output "
+            "variable of the if-node (if present) must be compatible with both S1 "
+            "and S2 as it represents the union of both possible shapes. "
+            "For example, if in a model file the first "
+            "output of `then_branch` is a float tensor with shape [2] and the "
+            "first output of `else_branch` is another float tensor with shape [3], "
+            "If's first output should have (a) no shape set, or (b) "
+            "a shape of rank 1 with neither `dim_value` nor `dim_param` set, or (c) "
+            "a shape of rank 1 with a unique `dim_param`. "
+            "In contrast, the first output cannot have the shape [2] since [2] and "
+            "[3] are not compatible.",
             "V",
             OpSchema::Variadic,
             false)
```

onnx/defs/controlflow/old.cc (115 additions & 0 deletions)

```diff
@@ -989,4 +989,119 @@ ONNX_OPERATOR_SET_SCHEMA(
         .TypeConstraint("I", {"tensor(int64)"}, "Int64 tensor")
         .TypeConstraint("V", OpSchema::all_tensor_types(), "All Tensor types")
         .TypeAndShapeInferenceFunction(ScanInferenceFunctionOpset9));
+
+void IfInferenceFunction1(InferenceContext& ctx) {
+  // there are no inputs so we just need to run the subgraph inferencing for
+  // then/else subgraphs and apply those to the outputs.
+  std::vector<const TypeProto*> subgraph_input_types; // none
+  std::vector<const TensorProto*> input_data; // none
+
+  std::vector<const TypeProto*> then_output_types;
+  std::vector<const TypeProto*> else_output_types;
+
+  // Run inferencing on the subgraph
+  GraphInferencer* graphInferencer =
+      ctx.getGraphAttributeInferencer("then_branch");
+  if (graphInferencer) {
+    then_output_types =
+        graphInferencer->doInferencing(subgraph_input_types, input_data);
+  }
+
+  graphInferencer = ctx.getGraphAttributeInferencer("else_branch");
+  if (graphInferencer) {
+    else_output_types =
+        graphInferencer->doInferencing(subgraph_input_types, input_data);
+  }
+
+  auto num_outputs = ctx.getNumOutputs();
+  auto num_then_outputs = then_output_types.size();
+  auto num_else_outputs = else_output_types.size();
+
+  // the output types for then and else should be the same
+  if (num_then_outputs != num_else_outputs) {
+    fail_type_inference(
+        "then_branch and else_branch produce different number of outputs. ",
+        num_then_outputs,
+        " != ",
+        num_else_outputs);
+  }
+
+  if (num_then_outputs != num_outputs) {
+    fail_type_inference(
+        "If node has ",
+        num_outputs,
+        " but subgraphs produce ",
+        num_then_outputs);
+  }
+
+  for (size_t i = 0, end = then_output_types.size(); i < end; ++i) {
+    auto then_output = then_output_types[i];
+    auto else_output = else_output_types[i];
+
+    if (then_output->value_case() != else_output->value_case()) {
+      fail_type_inference(
+          "Mismatched type for output ",
+          i,
+          " then=",
+          then_output->value_case(),
+          " else=",
+          else_output->value_case());
+    }
+
+    auto* if_output = ctx.getOutputType(i);
+    *if_output = *then_output;
+
+    if (then_output->has_tensor_type()) {
+      auto then_elem_type = then_output->tensor_type().elem_type();
+      auto else_elem_type = else_output->tensor_type().elem_type();
+
+      if (then_elem_type != else_elem_type) {
+        fail_type_inference(
+            "Mismatched tensor element type for output ",
+            i,
+            " then=",
+            then_elem_type,
+            " else=",
+            else_elem_type);
+      }
+
+      // merge the 'else' shape information to check it's consistent and
+      // augment the 'if' output if possible
+      mergeInShapeInfo(
+          else_output->tensor_type(), *if_output->mutable_tensor_type());
+    }
+  }
+}
+
+ONNX_OPERATOR_SET_SCHEMA(
+    If,
+    1,
+    OpSchema()
+        .SetDoc("If conditional")
+        .Input(0, "cond", "Condition for the if", "B")
+        .Output(
+            0,
+            "outputs",
+            "Values that are live-out to the enclosing scope. The return values in "
+            "the `then_branch` and `else_branch` must be of the same shape and same "
+            "data type.",
+            "V",
+            OpSchema::Variadic,
+            false)
+        .Attr(
+            "then_branch",
+            "Graph to run if condition is true. Has N outputs: values you wish to "
+            "be live-out to the enclosing scope. The number of outputs must match"
+            " the number of outputs in the else_branch.",
+            AttributeProto::GRAPH)
+        .Attr(
+            "else_branch",
+            "Graph to run if condition is false. Has N outputs: values you wish to"
+            " be live-out to the enclosing scope. The number of outputs must match"
+            " the number of outputs in the then_branch.",
+            AttributeProto::GRAPH)
+        .TypeConstraint("V", OpSchema::all_tensor_types(), "All Tensor types")
+        .TypeConstraint("B", {"tensor(bool)"}, "Only bool")
+        .TypeAndShapeInferenceFunction(IfInferenceFunction1));
+
 } // namespace ONNX_NAMESPACE
```
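`IfInferenceFunction1` above preserves the old opset-1 behaviour, where `mergeInShapeInfo` requires the branch output shapes to unify and fails on a concrete conflict. A rough Python sketch of that stricter per-dimension merge (simplified; the real `mergeInShapeInfo` operates on `TypeProto_Tensor`, and `None` here stands for an unknown dim):

```python
def merge_dims(target, source):
    """Old If-1 style merge: dimensions must unify; a concrete mismatch
    is an inference error rather than a cleared dimension."""
    if len(target) != len(source):
        raise ValueError("rank mismatch")
    out = []
    for t, s in zip(target, source):
        if t is None:
            out.append(s)            # fill in the unknown side
        elif s is None or t == s:
            out.append(t)            # compatible: keep the known dim
        else:
            raise ValueError(f"incompatible dims {t} vs {s}")
    return out

merge_dims([2, None], [2, 5])   # merges to [2, 5]
# merge_dims([2], [3])          # would raise: incompatible dims 2 vs 3
```

Under this rule, branch shapes `[2]` and `[3]` are simply an error, which is exactly the restriction this PR relaxes for the new opset-11 `If`.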

onnx/defs/operator_sets.h (2 additions & 0 deletions)

```diff
@@ -633,6 +633,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 11, SplitToSequence);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 11, ConcatFromSequence);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 11, Pad);
 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 11, Gemm);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 11, If);
 
 // Iterate over schema from ai.onnx version 11
 class OpSet_Onnx_ver11 {
@@ -702,6 +703,7 @@ class OpSet_Onnx_ver11 {
     fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 11, ConcatFromSequence)>());
     fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 11, Pad)>());
     fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 11, Gemm)>());
+    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 11, If)>());
   }
 };
```
onnx/defs/sequence/defs.cc (1 addition & 55 deletions)

```diff
@@ -9,60 +9,6 @@
 
 namespace ONNX_NAMESPACE {
 
-// target-shape = Union (target-shape, source_shape)
-// Example 1: same rank, different dimensions
-//   input1 shape: (2, 3, 4, 'x')
-//   input2 shape: (2, 'y', 5, 'x')
-//   output shape: (2, None, None, 'x')
-// Example 2: different rank
-//   input1 shape: (2, 3, 4, 'x')
-//   input2 shape: (2, 3, 4)
-//   output shape: None
-void UnionShapeInfo(
-    const TensorShapeProto& source_shape,
-    TypeProto_Tensor& target_type) {
-  if (target_type.has_shape()) {
-    TensorShapeProto* target_shape = target_type.mutable_shape();
-
-    auto source_rank = source_shape.dim_size();
-    auto target_rank = target_shape->dim_size();
-    if (source_rank != target_rank) {
-      target_type.clear_shape();
-      return;
-    }
-
-    for (int i = 0; i < source_rank; ++i) {
-      const auto source_dim = source_shape.dim(i);
-      const auto target_dim = target_shape->dim(i);
-      bool is_dims_conflict = [&](){
-        if (source_dim.has_dim_value()) {
-          if (target_dim.has_dim_value() &&
-              target_dim.dim_value() == source_dim.dim_value()) {
-            return false;
-          }
-          return true;
-        }
-
-        if (source_dim.has_dim_param()) {
-          if (target_dim.has_dim_param() &&
-              target_dim.dim_param() == source_dim.dim_param()) {
-            return false;
-          }
-          return true;
-        }
-
-        return (target_dim.has_dim_value() || target_dim.has_dim_param());
-      }();
-      if (is_dims_conflict &&
-          (target_dim.has_dim_value() || target_dim.has_dim_param())) {
-        auto dim = target_shape->mutable_dim(i);
-        dim->clear_dim_value();
-        dim->clear_dim_param();
-      }
-    }
-  }
-}
-
 static const char* SequenceEmpty_ver11_doc = R"DOC(
 Construct an empty tensor sequence, with given data type.
 )DOC";
@@ -662,4 +608,4 @@ ONNX_OPERATOR_SET_SCHEMA(
   }
 }));
 
-} // namespace ONNX_NAMESPACE
+} // namespace ONNX_NAMESPACE
```

onnx/defs/shape_inference.h (54 additions & 0 deletions)

```diff
@@ -746,4 +746,58 @@ inline void unifyDim(Dim& dim, int64_t value) {
   dim.set_dim_value(value);
 }
 
+// target-shape = Union (target-shape, source_shape)
+// Example 1: same rank, different dimensions
+//   input1 shape: (2, 3, 4, 'x')
+//   input2 shape: (2, 'y', 5, 'x')
+//   output shape: (2, None, None, 'x')
+// Example 2: different rank
+//   input1 shape: (2, 3, 4, 'x')
+//   input2 shape: (2, 3, 4)
+//   output shape: None
+inline void UnionShapeInfo(
+    const TensorShapeProto& source_shape,
+    TypeProto_Tensor& target_type) {
+  if (target_type.has_shape()) {
+    TensorShapeProto* target_shape = target_type.mutable_shape();
+
+    auto source_rank = source_shape.dim_size();
+    auto target_rank = target_shape->dim_size();
+    if (source_rank != target_rank) {
+      target_type.clear_shape();
+      return;
+    }
+
+    for (int i = 0; i < source_rank; ++i) {
+      const auto source_dim = source_shape.dim(i);
+      const auto target_dim = target_shape->dim(i);
+      bool is_dims_conflict = [&](){
+        if (source_dim.has_dim_value()) {
+          if (target_dim.has_dim_value() &&
+              target_dim.dim_value() == source_dim.dim_value()) {
+            return false;
+          }
+          return true;
+        }
+
+        if (source_dim.has_dim_param()) {
+          if (target_dim.has_dim_param() &&
+              target_dim.dim_param() == source_dim.dim_param()) {
+            return false;
+          }
+          return true;
+        }
+
+        return (target_dim.has_dim_value() || target_dim.has_dim_param());
+      }();
+      if (is_dims_conflict &&
+          (target_dim.has_dim_value() || target_dim.has_dim_param())) {
+        auto dim = target_shape->mutable_dim(i);
+        dim->clear_dim_value();
+        dim->clear_dim_param();
+      }
+    }
+  }
+}
+
 } // namespace ONNX_NAMESPACE
```
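The per-dimension logic of `UnionShapeInfo` can be transliterated into a few lines of Python. This is a sketch using an assumed encoding (an `int` for `dim_value`, a `str` for `dim_param`, `None` for a dim with neither; `None` for the whole shape means no shape set), not the real protobuf-based code:

```python
def union_dim(target, source):
    # keep the dim only when both sides carry the same dim_value or the
    # same dim_param; any disagreement (or an unknown side) clears it
    return target if target == source else None

def union_shape(target, source):
    """target-shape = Union(target-shape, source-shape); mirrors
    UnionShapeInfo: a rank mismatch clears the shape entirely."""
    if target is None:
        return None          # no shape on the target: nothing to union
    if len(target) != len(source):
        return None          # different rank -> clear_shape()
    return [union_dim(t, s) for t, s in zip(target, source)]

# Example 1 from the header comment: same rank, different dimensions
assert union_shape([2, 3, 4, 'x'], [2, 'y', 5, 'x']) == [2, None, None, 'x']
# Example 2: different rank, so the whole shape is dropped
assert union_shape([2, 3, 4, 'x'], [2, 3, 4]) is None
```

Note that, unlike `mergeInShapeInfo`, a conflict here never fails inference; it just widens the output shape, which is exactly the relaxation this PR introduces.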
