|
| 1 | +// Copyright (c) Facebook Inc. and Microsoft Corporation. |
| 2 | +// Licensed under the MIT license. |
| 3 | + |
| 4 | +#include "onnx/defs/schema.h" |
| 5 | +using namespace ONNX_NAMESPACE; |
| 6 | + |
| 7 | +namespace ONNX_NAMESPACE { |
| 8 | + |
| 9 | +static const char* NonMaxSuppression_doc = R"DOC( |
| 10 | +Filter out boxes that have high intersection-over-union (IOU) overlap with previously selected boxes. |
| 11 | +Bounding boxes with score less than score_threshold are removed. Bounding box format is indicated by attribute center_point_box. |
| 12 | +Note that this algorithm is agnostic to where the origin is in the coordinate system and more generally is invariant to |
| 13 | +orthogonal transformations and translations of the coordinate system; thus translating or reflections of the coordinate system |
| 14 | +result in the same boxes being selected by the algorithm. |
| 15 | +The selected_indices output is a set of integers indexing into the input collection of bounding boxes representing the selected boxes. |
| 16 | +The bounding box coordinates corresponding to the selected indices can then be obtained using the Gather or GatherND operation. |
| 17 | +)DOC"; |
| 18 | + |
| 19 | +ONNX_OPERATOR_SET_SCHEMA( |
| 20 | + NonMaxSuppression, |
| 21 | + 10, |
| 22 | + OpSchema() |
| 23 | + .Input( |
| 24 | + 0, |
| 25 | + "boxes", |
| 26 | + "An input tensor with shape [num_batches, spatial_dimension, 4]. The single box data format is indicated by center_point_box.", |
| 27 | + "tensor(float)") |
| 28 | + .Input( |
| 29 | + 1, |
| 30 | + "scores", |
| 31 | + "An input tensor with shape [num_batches, num_classes, spatial_dimension]", |
| 32 | + "tensor(float)") |
| 33 | + .Input( |
| 34 | + 2, |
| 35 | + "max_output_boxes_per_class", |
| 36 | + "Integer representing the maximum number of boxes to be selected per batch per class. It is a scalar. Default to 0, which means no output.", |
| 37 | + "tensor(int64)", |
| 38 | + OpSchema::Optional) |
| 39 | + .Input( |
| 40 | + 3, |
| 41 | + "iou_threshold", |
| 42 | + "Float representing the threshold for deciding whether boxes overlap too much with respect to IOU. It is scalar. Value range [0, 1]. Default to 0.", |
| 43 | + "tensor(float)", |
| 44 | + OpSchema::Optional) |
| 45 | + .Input( |
| 46 | + 4, |
| 47 | + "score_threshold", |
| 48 | + "Float representing the threshold for deciding when to remove boxes based on score. It is a scalar.", |
| 49 | + "tensor(float)", |
| 50 | + OpSchema::Optional) |
| 51 | + .Output( |
| 52 | + 0, |
| 53 | + "selected_indices", |
| 54 | + "selected indices from the boxes tensor. [num_selected_indices, 3], the selected index format is [batch_index, class_index, box_index].", |
| 55 | + "tensor(int64)") |
| 56 | + .Attr( |
| 57 | + "center_point_box", |
| 58 | + "Integer indicate the format of the box data. The default is 0. " |
| 59 | + "0 - the box data is supplied as [y1, x1, y2, x2] where (y1, x1) and (y2, x2) are the coordinates of any diagonal pair of box corners " |
| 60 | + "and the coordinates can be provided as normalized (i.e., lying in the interval [0, 1]) or absolute. Mostly used for TF models. " |
| 61 | + "1 - the box data is supplied as [x_center, y_center, width, height]. Mostly used for Pytorch models.", |
| 62 | + AttributeProto::INT, |
| 63 | + static_cast<int64_t>(0)) |
| 64 | + .SetDoc(NonMaxSuppression_doc) |
| 65 | + .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { |
| 66 | + auto selected_indices_type = |
| 67 | + ctx.getOutputType(0)->mutable_tensor_type(); |
| 68 | + selected_indices_type->set_elem_type( |
| 69 | + ::ONNX_NAMESPACE::TensorProto_DataType:: |
| 70 | + TensorProto_DataType_INT64); |
| 71 | + })); |
| 72 | + |
| 73 | +} // namespace ONNX_NAMESPACE |
0 commit comments