Skip to content

Commit a3c9145

Browse files
authored
Bump NMS version for avoiding regression in existing models (#2348)
* Bump NMS version for avoiding regression in existing models * Bring old logic back
1 parent ab6b942 commit a3c9145

5 files changed

Lines changed: 125 additions & 2 deletions

File tree

docs/Changelog.md

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11916,6 +11916,52 @@ This version of the operator has been available since version 11 of the default
1191611916
<dd>Constrain index tensor to int64</dd>
1191711917
</dl>
1191811918

11919+
### <a name="NonMaxSuppression-11"></a>**NonMaxSuppression-11**
11920+
11921+
Filter out boxes that have high intersection-over-union (IOU) overlap with previously selected boxes.
11922+
Bounding boxes with score less than score_threshold are removed. Bounding box format is indicated by attribute center_point_box.
11923+
Note that this algorithm is agnostic to where the origin is in the coordinate system and more generally is invariant to
11924+
orthogonal transformations and translations of the coordinate system; thus translations or reflections of the coordinate system
11925+
result in the same boxes being selected by the algorithm.
11926+
The selected_indices output is a set of integers indexing into the input collection of bounding boxes representing the selected boxes.
11927+
The bounding box coordinates corresponding to the selected indices can then be obtained using the Gather or GatherND operation.
11928+
11929+
#### Version
11930+
11931+
This version of the operator has been available since version 11 of the default ONNX operator set.
11932+
11933+
#### Attributes
11934+
11935+
<dl>
11936+
<dt><tt>center_point_box</tt> : int (default is 0)</dt>
11937+
<dd>Integer indicating the format of the box data. The default is 0. 0 - the box data is supplied as [y1, x1, y2, x2] where (y1, x1) and (y2, x2) are the coordinates of any diagonal pair of box corners and the coordinates can be provided as normalized (i.e., lying in the interval [0, 1]) or absolute. Mostly used for TF models. 1 - the box data is supplied as [x_center, y_center, width, height]. Mostly used for PyTorch models.</dd>
11938+
</dl>
11939+
11940+
#### Inputs (2 - 5)
11941+
11942+
<dl>
11943+
<dt><tt>boxes</tt> : tensor(float)</dt>
11944+
<dd>An input tensor with shape [num_batches, spatial_dimension, 4]. The single box data format is indicated by center_point_box.</dd>
11945+
<dt><tt>scores</tt> : tensor(float)</dt>
11946+
<dd>An input tensor with shape [num_batches, num_classes, spatial_dimension]</dd>
11947+
<dt><tt>max_output_boxes_per_class</tt> (optional) : tensor(int64)</dt>
11948+
<dd>Integer representing the maximum number of boxes to be selected per batch per class. It is a scalar. Defaults to 0, which means no output.</dd>
11949+
<dt><tt>iou_threshold</tt> (optional) : tensor(float)</dt>
11950+
<dd>Float representing the threshold for deciding whether boxes overlap too much with respect to IOU. It is a scalar. Value range [0, 1]. Defaults to 0.</dd>
11951+
<dt><tt>score_threshold</tt> (optional) : tensor(float)</dt>
11952+
<dd>Float representing the threshold for deciding when to remove boxes based on score. It is a scalar.</dd>
11953+
</dl>
11954+
11955+
#### Outputs
11956+
11957+
<dl>
11958+
<dt><tt>selected_indices</tt> : tensor(int64)</dt>
11959+
<dd>Selected indices from the boxes tensor, with shape [num_selected_indices, 3]. Each selected index has the format [batch_index, class_index, box_index].</dd>
11960+
</dl>
11961+
11962+
#### Type Constraints
11963+
11964+
1191911965
### <a name="OneHot-11"></a>**OneHot-11**
1192011966

1192111967
Produces a one-hot tensor based on inputs.

docs/Operators.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9378,7 +9378,9 @@ expect(node, inputs=[x], outputs=[y],
93789378

93799379
#### Version
93809380

9381-
This version of the operator has been available since version 10 of the default ONNX operator set.
9381+
This version of the operator has been available since version 11 of the default ONNX operator set.
9382+
9383+
Other versions of this operator: <a href="Changelog.md#NonMaxSuppression-10">NonMaxSuppression-10</a>
93829384

93839385
#### Attributes
93849386

onnx/defs/object_detection/defs.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ The bounding box coordinates corresponding to the selected indices can then be o
138138

139139
ONNX_OPERATOR_SET_SCHEMA(
140140
NonMaxSuppression,
141-
10,
141+
11,
142142
OpSchema()
143143
.Input(
144144
0,

onnx/defs/object_detection/old.cc

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// Copyright (c) Facebook Inc. and Microsoft Corporation.
2+
// Licensed under the MIT license.
3+
4+
#include "onnx/defs/schema.h"
5+
using namespace ONNX_NAMESPACE;
6+
7+
namespace ONNX_NAMESPACE {
8+
9+
// Operator description text for NonMaxSuppression; it is attached to the
// schema in this file through .SetDoc(NonMaxSuppression_doc).
static const char* NonMaxSuppression_doc = R"DOC(
10+
Filter out boxes that have high intersection-over-union (IOU) overlap with previously selected boxes.
11+
Bounding boxes with score less than score_threshold are removed. Bounding box format is indicated by attribute center_point_box.
12+
Note that this algorithm is agnostic to where the origin is in the coordinate system and more generally is invariant to
13+
orthogonal transformations and translations of the coordinate system; thus translating or reflections of the coordinate system
14+
result in the same boxes being selected by the algorithm.
15+
The selected_indices output is a set of integers indexing into the input collection of bounding boxes representing the selected boxes.
16+
The bounding box coordinates corresponding to the selected indices can then be obtained using the Gather or GatherND operation.
17+
)DOC";
18+
19+
// Schema for NonMaxSuppression at operator-set version 10 (the operator name
// and the version number are the first two macro arguments).
//   Inputs : 0 "boxes" and 1 "scores" are tensor(float) and carry no Optional
//            flag; 2 "max_output_boxes_per_class" (tensor(int64)),
//            3 "iou_threshold" and 4 "score_threshold" (tensor(float)) are
//            marked OpSchema::Optional.
//   Output : 0 "selected_indices", tensor(int64).
//   Attr   : "center_point_box", AttributeProto::INT, default value 0.
ONNX_OPERATOR_SET_SCHEMA(
20+
NonMaxSuppression,
21+
10,
22+
OpSchema()
23+
.Input(
24+
0,
25+
"boxes",
26+
"An input tensor with shape [num_batches, spatial_dimension, 4]. The single box data format is indicated by center_point_box.",
27+
"tensor(float)")
28+
.Input(
29+
1,
30+
"scores",
31+
"An input tensor with shape [num_batches, num_classes, spatial_dimension]",
32+
"tensor(float)")
33+
.Input(
34+
2,
35+
"max_output_boxes_per_class",
36+
"Integer representing the maximum number of boxes to be selected per batch per class. It is a scalar. Default to 0, which means no output.",
37+
"tensor(int64)",
38+
OpSchema::Optional)
39+
.Input(
40+
3,
41+
"iou_threshold",
42+
"Float representing the threshold for deciding whether boxes overlap too much with respect to IOU. It is scalar. Value range [0, 1]. Default to 0.",
43+
"tensor(float)",
44+
OpSchema::Optional)
45+
.Input(
46+
4,
47+
"score_threshold",
48+
"Float representing the threshold for deciding when to remove boxes based on score. It is a scalar.",
49+
"tensor(float)",
50+
OpSchema::Optional)
51+
.Output(
52+
0,
53+
"selected_indices",
54+
"selected indices from the boxes tensor. [num_selected_indices, 3], the selected index format is [batch_index, class_index, box_index].",
55+
"tensor(int64)")
56+
.Attr(
57+
"center_point_box",
58+
"Integer indicate the format of the box data. The default is 0. "
59+
"0 - the box data is supplied as [y1, x1, y2, x2] where (y1, x1) and (y2, x2) are the coordinates of any diagonal pair of box corners "
60+
"and the coordinates can be provided as normalized (i.e., lying in the interval [0, 1]) or absolute. Mostly used for TF models. "
61+
"1 - the box data is supplied as [x_center, y_center, width, height]. Mostly used for Pytorch models.",
62+
AttributeProto::INT,
63+
static_cast<int64_t>(0))
64+
.SetDoc(NonMaxSuppression_doc)
65+
// Inference callback: only the element type of output 0 is set (INT64);
// no output shape is populated in this version of the schema.
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
66+
auto selected_indices_type =
67+
ctx.getOutputType(0)->mutable_tensor_type();
68+
selected_indices_type->set_elem_type(
69+
::ONNX_NAMESPACE::TensorProto_DataType::
70+
TensorProto_DataType_INT64);
71+
}));
72+
73+
} // namespace ONNX_NAMESPACE

onnx/defs/operator_sets.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 11, ConcatFromSequence);
634634
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 11, Pad);
635635
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 11, Gemm);
636636
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 11, If);
637+
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 11, NonMaxSuppression);
637638

638639
// Iterate over schema from ai.onnx version 11
639640
class OpSet_Onnx_ver11 {
@@ -704,6 +705,7 @@ class OpSet_Onnx_ver11 {
704705
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 11, Pad)>());
705706
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 11, Gemm)>());
706707
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 11, If)>());
708+
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 11, NonMaxSuppression)>());
707709
}
708710
};
709711

0 commit comments

Comments
 (0)