Skip to content

Commit eb4b7c2

Browse files
authored
allow variadic parameters of different types (#1615)
* allow variadic parameters of different types * fix formatting issues * add isHomogeneous to defs.pyi
1 parent 4166246 commit eb4b7c2

File tree

8 files changed

+100
-42
lines changed

8 files changed

+100
-42
lines changed

docs/Changelog.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1839,7 +1839,7 @@ This version of the operator has been available since version 1 of the default O
18391839
#### Outputs (1 - ∞)
18401840

18411841
<dl>
1842-
<dt><tt>outputs</tt> (variadic) : V</dt>
1842+
<dt><tt>outputs</tt> (variadic, heterogeneous) : V</dt>
18431843
<dd>Values that are live-out to the enclosing scope. The return values in the `then_branch` and `else_branch` must be of the same shape and same data type.</dd>
18441844
</dl>
18451845

@@ -2445,14 +2445,14 @@ This version of the operator has been available since version 1 of the default O
24452445
<dd>A maximum trip-count for the loop specified at runtime. Optional. pass empty string to skip.</dd>
24462446
<dt><tt>cond</tt> : B</dt>
24472447
<dd>A boolean termination condition. Pass empty string to skip.</dd>
2448-
<dt><tt>v_initial</tt> (variadic) : V</dt>
2448+
<dt><tt>v_initial</tt> (variadic, heterogeneous) : V</dt>
24492449
<dd>The initial values of any loop-carried dependencies (values that change across loop iterations)</dd>
24502450
</dl>
24512451

24522452
#### Outputs (1 - &#8734;)
24532453

24542454
<dl>
2455-
<dt><tt>v_final_and_scan_outputs</tt> (variadic) : V</dt>
2455+
<dt><tt>v_final_and_scan_outputs</tt> (variadic, heterogeneous) : V</dt>
24562456
<dd>Final N loop carried dependency values then K scan_outputs</dd>
24572457
</dl>
24582458

@@ -8288,14 +8288,14 @@ This version of the operator has been available since version 8 of the default O
82888288
<dl>
82898289
<dt><tt>sequence_lens</tt> (optional) : I</dt>
82908290
<dd>Optional tensor specifying lengths of the sequences in a batch. If this input is not specified, all sequences are assumed to be of the maximum sequence length (the dimension of the sequence axis of the scan_input tensors).</dd>
8291-
<dt><tt>initial_state_and_scan_inputs</tt> (variadic) : V</dt>
8291+
<dt><tt>initial_state_and_scan_inputs</tt> (variadic, heterogeneous) : V</dt>
82928292
<dd>Initial values of the loop's N state variables followed by M scan_inputs</dd>
82938293
</dl>
82948294

82958295
#### Outputs (1 - &#8734;)
82968296

82978297
<dl>
8298-
<dt><tt>final_state_and_scan_outputs</tt> (variadic) : V</dt>
8298+
<dt><tt>final_state_and_scan_outputs</tt> (variadic, heterogeneous) : V</dt>
82998299
<dd>Final values of the loop's N state variables followed by K scan_outputs</dd>
83008300
</dl>
83018301

docs/Operators.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4711,7 +4711,7 @@ This version of the operator has been available since version 1 of the default O
47114711
#### Outputs (1 - &#8734;)
47124712

47134713
<dl>
4714-
<dt><tt>outputs</tt> (variadic) : V</dt>
4714+
<dt><tt>outputs</tt> (variadic, heterogeneous) : V</dt>
47154715
<dd>Values that are live-out to the enclosing scope. The return values in the `then_branch` and `else_branch` must be of the same shape and same data type.</dd>
47164716
</dl>
47174717

@@ -5695,14 +5695,14 @@ This version of the operator has been available since version 1 of the default O
56955695
<dd>A maximum trip-count for the loop specified at runtime. Optional. pass empty string to skip.</dd>
56965696
<dt><tt>cond</tt> : B</dt>
56975697
<dd>A boolean termination condition. Pass empty string to skip.</dd>
5698-
<dt><tt>v_initial</tt> (variadic) : V</dt>
5698+
<dt><tt>v_initial</tt> (variadic, heterogeneous) : V</dt>
56995699
<dd>The initial values of any loop-carried dependencies (values that change across loop iterations)</dd>
57005700
</dl>
57015701

57025702
#### Outputs (1 - &#8734;)
57035703

57045704
<dl>
5705-
<dt><tt>v_final_and_scan_outputs</tt> (variadic) : V</dt>
5705+
<dt><tt>v_final_and_scan_outputs</tt> (variadic, heterogeneous) : V</dt>
57065706
<dd>Final N loop carried dependency values then K scan_outputs</dd>
57075707
</dl>
57085708

@@ -9705,14 +9705,14 @@ This version of the operator has been available since version 8 of the default O
97059705
<dl>
97069706
<dt><tt>sequence_lens</tt> (optional) : I</dt>
97079707
<dd>Optional tensor specifying lengths of the sequences in a batch. If this input is not specified, all sequences are assumed to be of the maximum sequence length (the dimension of the sequence axis of the scan_input tensors).</dd>
9708-
<dt><tt>initial_state_and_scan_inputs</tt> (variadic) : V</dt>
9708+
<dt><tt>initial_state_and_scan_inputs</tt> (variadic, heterogeneous) : V</dt>
97099709
<dd>Initial values of the loop's N state variables followed by M scan_inputs</dd>
97109710
</dl>
97119711

97129712
#### Outputs (1 - &#8734;)
97139713

97149714
<dl>
9715-
<dt><tt>final_state_and_scan_outputs</tt> (variadic) : V</dt>
9715+
<dt><tt>final_state_and_scan_outputs</tt> (variadic, heterogeneous) : V</dt>
97169716
<dd>Final values of the loop's N state variables followed by K scan_outputs</dd>
97179717
</dl>
97189718

onnx/cpp2py_export.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
8181
.def_property_readonly("typeStr", &OpSchema::FormalParameter::GetTypeStr)
8282
.def_property_readonly(
8383
"description", &OpSchema::FormalParameter::GetDescription)
84-
.def_property_readonly("option", &OpSchema::FormalParameter::GetOption);
84+
.def_property_readonly("option", &OpSchema::FormalParameter::GetOption)
85+
.def_property_readonly(
86+
"isHomogeneous", &OpSchema::FormalParameter::GetIsHomogeneous);
8587

8688
py::enum_<AttributeProto::AttributeType>(op_schema, "AttrType")
8789
.value("FLOAT", AttributeProto::FLOAT)

onnx/defs/controlflow/defs.cc

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,8 @@ ONNX_OPERATOR_SET_SCHEMA(
396396
"the `then_branch` and `else_branch` must be of the same shape and same "
397397
"data type.",
398398
"V",
399-
OpSchema::Variadic)
399+
OpSchema::Variadic,
400+
false)
400401
.Attr(
401402
"then_branch",
402403
"Graph to run if condition is true. Has N outputs: values you wish to "
@@ -551,13 +552,15 @@ ONNX_OPERATOR_SET_SCHEMA(
551552
"The initial values of any loop-carried dependencies (values that "
552553
"change across loop iterations)",
553554
"V",
554-
OpSchema::Variadic)
555+
OpSchema::Variadic,
556+
false)
555557
.Output(
556558
0,
557559
"v_final_and_scan_outputs",
558560
"Final N loop carried dependency values then K scan_outputs",
559561
"V",
560-
OpSchema::Variadic)
562+
OpSchema::Variadic,
563+
false)
561564
.Attr(
562565
"body",
563566
"The graph run each iteration. It has 2+N inputs: (iteration_num, "
@@ -720,13 +723,15 @@ ONNX_OPERATOR_SET_SCHEMA(
720723
"initial_state_and_scan_inputs",
721724
"Initial values of the loop's N state variables followed by M scan_inputs",
722725
"V",
723-
OpSchema::Variadic)
726+
OpSchema::Variadic,
727+
false)
724728
.Output(
725729
0,
726730
"final_state_and_scan_outputs",
727731
"Final values of the loop's N state variables followed by K scan_outputs",
728732
"V",
729-
OpSchema::Variadic)
733+
OpSchema::Variadic,
734+
false)
730735
.Attr(
731736
"body",
732737
"The graph run each iteration. It has N+M inputs: "

onnx/defs/gen_doc.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,10 @@ def format_value(value): # type: (Any) -> Text
152152
if OpSchema.FormalParameterOption.Optional == input.option:
153153
option_str = " (optional)"
154154
elif OpSchema.FormalParameterOption.Variadic == input.option:
155-
option_str = " (variadic)"
155+
if input.isHomogeneous:
156+
option_str = " (variadic)"
157+
else:
158+
option_str = " (variadic, heterogeneous)"
156159
s += '<dt><tt>{}</tt>{} : {}</dt>\n'.format(input.name, option_str, input.typeStr)
157160
s += '<dd>{}</dd>\n'.format(input.description)
158161
s += '</dl>\n'
@@ -171,7 +174,10 @@ def format_value(value): # type: (Any) -> Text
171174
if OpSchema.FormalParameterOption.Optional == output.option:
172175
option_str = " (optional)"
173176
elif OpSchema.FormalParameterOption.Variadic == output.option:
174-
option_str = " (variadic)"
177+
if output.isHomogeneous:
178+
option_str = " (variadic)"
179+
else:
180+
option_str = " (variadic, heterogeneous)"
175181
s += '<dt><tt>{}</tt>{} : {}</dt>\n'.format(output.name, option_str, output.typeStr)
176182
s += '<dd>{}</dd>\n'.format(output.description)
177183
s += '</dl>\n'

onnx/defs/schema.cc

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,22 +32,26 @@ OpSchema::FormalParameter::FormalParameter(
3232
DataTypeSet allowed_type_set,
3333
std::string type_str,
3434
std::string description,
35-
FormalParameterOption param_option)
35+
FormalParameterOption param_option,
36+
bool is_homogeneous)
3637
: name_(std::move(name)),
3738
type_set_(std::move(allowed_type_set)),
3839
type_str_(std::move(type_str)),
3940
description_(std::move(description)),
40-
param_option_(param_option) {}
41+
param_option_(param_option),
42+
is_homogeneous_(is_homogeneous) {}
4143

4244
OpSchema::FormalParameter::FormalParameter(
4345
std::string name,
4446
std::string description,
4547
std::string type_str,
46-
FormalParameterOption param_option)
48+
FormalParameterOption param_option,
49+
bool is_homogeneous)
4750
: name_(std::move(name)),
4851
type_str_(std::move(type_str)),
4952
description_(std::move(description)),
50-
param_option_(param_option) {}
53+
param_option_(param_option),
54+
is_homogeneous_(is_homogeneous) {}
5155

5256
const std::string& OpSchema::FormalParameter::GetName() const {
5357
return name_;
@@ -73,14 +77,22 @@ OpSchema::FormalParameterOption OpSchema::FormalParameter::GetOption() const {
7377
return param_option_;
7478
}
7579

80+
bool OpSchema::FormalParameter::GetIsHomogeneous() const {
81+
return is_homogeneous_;
82+
}
83+
7684
OpSchemaRegistry* OpSchemaRegistry::Instance() {
7785
static OpSchemaRegistry instance;
7886
return &instance;
7987
}
8088

8189
void OpSchema::Verify(const NodeProto& node) const {
8290
if (deprecated_) {
83-
fail_check("Operator '", name_, "' has been deprecated since version ", since_version_);
91+
fail_check(
92+
"Operator '",
93+
name_,
94+
"' has been deprecated since version ",
95+
since_version_);
8496
}
8597

8698
// Check the number of inputs.
@@ -101,7 +113,9 @@ void OpSchema::Verify(const NodeProto& node) const {
101113
fail_check(
102114
"Node (",
103115
node.name(),
104-
") has input size ", node.input_size(), " not in allowed input sizes.");
116+
") has input size ",
117+
node.input_size(),
118+
" not in allowed input sizes.");
105119
}
106120

107121
// Check the number of outputs.
@@ -122,7 +136,9 @@ void OpSchema::Verify(const NodeProto& node) const {
122136
fail_check(
123137
"Node (",
124138
node.name(),
125-
"has output size ", node.output_size(), " not in allowed output sizes.");
139+
"has output size ",
140+
node.output_size(),
141+
" not in allowed output sizes.");
126142
}
127143

128144
// Check the values of inputs / outputs
@@ -203,7 +219,8 @@ void OpSchema::Verify(const NodeProto& node) const {
203219
} else if (allows_unchecked_attributes_ || isInternalSymbol(name)) {
204220
continue;
205221
} else {
206-
fail_check("Unrecognized attribute: ", name, " for operator ", node.op_type());
222+
fail_check(
223+
"Unrecognized attribute: ", name, " for operator ", node.op_type());
207224
}
208225

209226
if (attr_proto.has_ref_attr_name()) {
@@ -491,15 +508,17 @@ OpSchema& OpSchema::Input(
491508
std::string name,
492509
std::string description,
493510
std::string type_str,
494-
OpSchema::FormalParameterOption param_option) {
511+
OpSchema::FormalParameterOption param_option,
512+
bool is_homogeneous) {
495513
if (int(inputs_.size()) <= n) {
496514
inputs_.resize(n + 1);
497515
}
498516
inputs_[n] = FormalParameter(
499517
std::move(name),
500518
std::move(description),
501519
std::move(type_str),
502-
param_option);
520+
param_option,
521+
is_homogeneous);
503522
return *this;
504523
}
505524

@@ -508,29 +527,33 @@ OpSchema& OpSchema::Input(
508527
const char* name,
509528
const char* description,
510529
const char* type_str,
511-
FormalParameterOption param_option) {
530+
FormalParameterOption param_option,
531+
bool is_homogeneous) {
512532
return Input(
513533
n,
514534
std::string(name),
515535
std::string(description),
516536
std::string(type_str),
517-
param_option);
537+
param_option,
538+
is_homogeneous);
518539
}
519540

520541
OpSchema& OpSchema::Output(
521542
int n,
522543
std::string name,
523544
std::string description,
524545
std::string type_str,
525-
OpSchema::FormalParameterOption param_option) {
546+
OpSchema::FormalParameterOption param_option,
547+
bool is_homogeneous) {
526548
if (int(outputs_.size()) <= n) {
527549
outputs_.resize(n + 1);
528550
}
529551
outputs_[n] = FormalParameter(
530552
std::move(name),
531553
std::move(description),
532554
std::move(type_str),
533-
param_option);
555+
param_option,
556+
is_homogeneous);
534557
return *this;
535558
}
536559

@@ -539,13 +562,15 @@ OpSchema& OpSchema::Output(
539562
const char* name,
540563
const char* description,
541564
const char* type_str,
542-
FormalParameterOption param_option) {
565+
FormalParameterOption param_option,
566+
bool is_homogeneous) {
543567
return Output(
544568
n,
545569
std::string(name),
546570
std::string(description),
547571
std::string(type_str),
548-
param_option);
572+
param_option,
573+
is_homogeneous);
549574
}
550575

551576
OpSchema& OpSchema::TypeConstraint(
@@ -722,7 +747,8 @@ std::ostream& operator<<(std::ostream& out, const OpSchema& schema) {
722747
return out;
723748
}
724749

725-
OpSchemaRegistry::DomainToVersionRange& OpSchemaRegistry::DomainToVersionRange::Instance() {
750+
OpSchemaRegistry::DomainToVersionRange&
751+
OpSchemaRegistry::DomainToVersionRange::Instance() {
726752
static DomainToVersionRange domain_to_version_range;
727753
return domain_to_version_range;
728754
};

0 commit comments

Comments
 (0)