Skip to content

Commit e08efaa

Browse files
hariharans29 and linkerzhang
authored and committed
Fix shape inference logic for TopK operator (#2005)
* Initial commit
* Formatting
* Refactor ParseRawData
* Add raw data shape inference test for upsample
1 parent d80ea94 commit e08efaa

File tree

6 files changed

+125
-113
lines changed

6 files changed

+125
-113
lines changed

onnx/defs/math/defs.cc

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <functional>
55
#include "onnx/defs/schema.h"
6+
#include "onnx/defs/tensor_proto_util.h"
67

78
namespace ONNX_NAMESPACE {
89

@@ -129,14 +130,13 @@ ONNX_OPERATOR_SET_SCHEMA(
129130
OpSchema::all_numeric_types(),
130131
"Constrain input and output types to high-precision numeric tensors.")
131132
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
132-
propagateElemTypeFromInputToOutput(ctx, 0, 0);
133-
if (hasNInputShapes(ctx, 2))
134-
bidirectionalBroadcastShapeInference(
135-
ctx.getInputType(0)->tensor_type().shape(),
136-
ctx.getInputType(1)->tensor_type().shape(),
137-
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
138-
})
139-
);
133+
propagateElemTypeFromInputToOutput(ctx, 0, 0);
134+
if (hasNInputShapes(ctx, 2))
135+
bidirectionalBroadcastShapeInference(
136+
ctx.getInputType(0)->tensor_type().shape(),
137+
ctx.getInputType(1)->tensor_type().shape(),
138+
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
139+
}));
140140

141141
ONNX_OPERATOR_SET_SCHEMA(
142142
Mul,
@@ -944,7 +944,6 @@ ONNX_OPERATOR_SET_SCHEMA(
944944
// Type inference:
945945
propagateElemTypeFromInputToOutput(ctx, 0, 0);
946946
updateOutputElemType(ctx, 1, TensorProto::INT64);
947-
948947
// Shape inference:
949948
if (!hasInputShape(ctx, 0))
950949
return;
@@ -955,29 +954,49 @@ ONNX_OPERATOR_SET_SCHEMA(
955954
axis += rank;
956955
if (axis < 0 || axis >= rank)
957956
fail_shape_inference("Invalid value for attribute axis");
958-
// TODO: unclear what results should be if axis has less than k
959-
// elements.
960-
// Infer output shape if 'K' is available
957+
958+
const auto& axis_dim = input_shape.dim(static_cast<int>(axis));
961959
const auto* k = ctx.getInputData(1);
962-
if (nullptr != k) {
963-
if (k->dims_size() != 1 || k->int64_data_size() != 1 ||
964-
k->data_type() != TensorProto::INT64)
960+
961+
// Infer output shape if:
962+
// (1) 'K' is available
963+
// (2) axis_dim has dim value
964+
// Otherwise cannot reliably compute output shape as axis dim value is
965+
// unknown and hence cannot determine if axis dim value >= k (which
966+
// should be enforced)
967+
if (nullptr != k && axis_dim.has_dim_value()) {
968+
int64_t k_value = 0;
969+
if (k->dims_size() != 1 || k->dims(0) != 1)
970+
fail_shape_inference(
971+
"K input must be a one-dimensional tensor of size 1.");
972+
if (k->data_type() == TensorProto::INT64) {
973+
const auto& data = ParseData<int64_t>(k);
974+
k_value = data[0];
975+
} else
976+
fail_shape_inference("K input must be of type int64.");
977+
978+
if (axis_dim.dim_value() < k_value)
965979
fail_shape_inference(
966-
"K input must be a one-dimensional tensor of size 1 and of type int64.");
980+
"Axis has less than the requested k elements.");
981+
967982
TensorShapeProto result_shape = input_shape;
968983
result_shape.mutable_dim(static_cast<int>(axis))
969-
->set_dim_value(k->int64_data(0));
984+
->set_dim_value(k_value);
985+
970986
updateOutputShape(ctx, 0, result_shape);
971987
updateOutputShape(ctx, 1, result_shape);
972-
} else {
973-
// Infer output shapes' rank in any case
974-
auto* output_shape_0 = getOutputShape(ctx, 0);
975-
auto* output_shape_1 = getOutputShape(ctx, 1);
976-
for (int i = 0; i < input_shape.dim_size(); ++i) {
977-
output_shape_0->add_dim();
978-
output_shape_1->add_dim();
979-
}
988+
989+
return;
990+
}
991+
992+
// Infer output shapes' rank in any case
993+
auto* output_shape_0 = getOutputShape(ctx, 0);
994+
auto* output_shape_1 = getOutputShape(ctx, 1);
995+
for (int i = 0; i < input_shape.dim_size(); ++i) {
996+
output_shape_0->add_dim();
997+
output_shape_1->add_dim();
980998
}
999+
9811000
return;
9821001
}));
9831002

onnx/defs/tensor/defs.cc

Lines changed: 10 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -545,20 +545,11 @@ ONNX_OPERATOR_SET_SCHEMA(
545545
auto get_initializer_data =
546546
[](const TensorProto* initializer) -> std::vector<int64_t> {
547547
std::vector<int64_t> vec;
548-
if (initializer->has_raw_data() &&
549-
initializer->data_type() == TensorProto::INT64) {
550-
const auto& data = ParseRawData<int64_t>(initializer);
551-
vec.insert(vec.end(), data.begin(), data.end());
552-
} else if (
553-
initializer->has_raw_data() &&
554-
initializer->data_type() == TensorProto::INT32) {
555-
const auto& data = ParseRawData<int32_t>(initializer);
556-
vec.insert(vec.end(), data.begin(), data.end());
557-
} else if (initializer->data_type() == TensorProto::INT64) {
558-
const auto& data = initializer->int64_data();
548+
if (initializer->data_type() == TensorProto::INT64) {
549+
const auto& data = ParseData<int64_t>(initializer);
559550
vec.insert(vec.end(), data.begin(), data.end());
560551
} else if (initializer->data_type() == TensorProto::INT32) {
561-
const auto& data = initializer->int32_data();
552+
const auto& data = ParseData<int32_t>(initializer);
562553
vec.insert(vec.end(), data.begin(), data.end());
563554
} else {
564555
// unaccepted data type
@@ -1328,38 +1319,21 @@ ONNX_OPERATOR_SET_SCHEMA(
13281319
if (nullptr != scales) {
13291320
// Infer output shape's dimension value if 'scales' is known.
13301321
if (scales->data_type() == TensorProto::FLOAT) {
1331-
bool invalid_scale_shape = false;
1332-
if (scales->has_raw_data()) {
1333-
const auto& data = ParseRawData<float>(scales);
1334-
if (static_cast<int>(data.size()) == input_shape.dim_size()) {
1335-
for (int i = 0; i < input_shape.dim_size(); ++i) {
1336-
float dim_value =
1337-
static_cast<float>(input_shape.dim(i).dim_value());
1338-
output_shape->add_dim()->set_dim_value(static_cast<int64_t>(
1339-
std::floor(dim_value * data[i])));
1340-
}
1341-
} else {
1342-
invalid_scale_shape = true;
1343-
}
1344-
} else if (scales->float_data_size() == input_shape.dim_size()) {
1322+
const auto& data = ParseData<float>(scales);
1323+
if (static_cast<int>(data.size()) == input_shape.dim_size()) {
13451324
for (int i = 0; i < input_shape.dim_size(); ++i) {
13461325
float dim_value =
1347-
static_cast<float>(input_shape.dim(i).dim_value());
1348-
output_shape->add_dim()->set_dim_value(static_cast<int64_t>(
1349-
std::floor(dim_value * scales->float_data(i))));
1326+
static_cast<float>(input_shape.dim(i).dim_value());
1327+
output_shape->add_dim()->set_dim_value(
1328+
static_cast<int64_t>(std::floor(dim_value * data[i])));
13501329
}
13511330
} else {
1352-
invalid_scale_shape = true;
1353-
}
1354-
1355-
if (invalid_scale_shape) {
13561331
fail_shape_inference(
1357-
"Number of elements of input 'scales' must be same as rank of input 'X'."
1358-
);
1332+
"Number of elements of input 'scales' must be same as rank of input 'X'.");
13591333
}
13601334
} else {
13611335
fail_shape_inference(
1362-
"Input scales's element type must be float.");
1336+
"Input scales's element type must be float.");
13631337
}
13641338
} else {
13651339
// Infer output shape's rank in any case.

onnx/defs/tensor/old.cc

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -437,38 +437,21 @@ ONNX_OPERATOR_SET_SCHEMA(
437437
if (nullptr != scales) {
438438
// Infer output shape's dimension value if 'scales' is known.
439439
if (scales->data_type() == TensorProto::FLOAT) {
440-
bool invalid_scale_shape = false;
441-
if (scales->has_raw_data()) {
442-
const auto& data = ParseRawData<float>(scales);
443-
if (static_cast<int>(data.size()) == input_shape.dim_size()) {
444-
for (int i = 0; i < input_shape.dim_size(); ++i) {
445-
float dim_value =
446-
static_cast<float>(input_shape.dim(i).dim_value());
447-
output_shape->add_dim()->set_dim_value(static_cast<int64_t>(
448-
std::floor(dim_value * data[i])));
449-
}
450-
} else {
451-
invalid_scale_shape = true;
452-
}
453-
} else if (scales->float_data_size() == input_shape.dim_size()) {
440+
const auto& data = ParseData<float>(scales);
441+
if (static_cast<int>(data.size()) == input_shape.dim_size()) {
454442
for (int i = 0; i < input_shape.dim_size(); ++i) {
455443
float dim_value =
456-
static_cast<float>(input_shape.dim(i).dim_value());
457-
output_shape->add_dim()->set_dim_value(static_cast<int64_t>(
458-
std::floor(dim_value * scales->float_data(i))));
444+
static_cast<float>(input_shape.dim(i).dim_value());
445+
output_shape->add_dim()->set_dim_value(
446+
static_cast<int64_t>(std::floor(dim_value * data[i])));
459447
}
460448
} else {
461-
invalid_scale_shape = true;
462-
}
463-
464-
if (invalid_scale_shape){
465449
fail_shape_inference(
466-
"Number of elements of input 'scales' must be same as rank of input 'X'."
467-
);
450+
"Number of elements of input 'scales' must be same as rank of input 'X'.");
468451
}
469452
} else {
470453
fail_shape_inference(
471-
"Input scales's element type must be float.");
454+
"Input scales's element type must be float.");
472455
}
473456
} else {
474457
// Infer output shape's rank in any case.

onnx/defs/tensor_proto_util.cc

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,45 +27,47 @@ namespace ONNX_NAMESPACE {
2727
return t; \
2828
}
2929

30-
inline bool is_platform_little_endian() {
31-
int num = 1;
32-
if (*(char*)&num == 1)
33-
return true;
34-
return false;
35-
}
30+
inline bool is_platform_little_endian() {
31+
int num = 1;
32+
if (*(char*)&num == 1)
33+
return true;
34+
return false;
35+
}
3636

37-
#define DEFINE_PARSE_RAW_DATA(type) \
37+
#define DEFINE_PARSE_DATA(type, typed_data_fetch) \
3838
template <> \
39-
const std::vector<type> ParseRawData(const TensorProto* tensor_proto) { \
39+
const std::vector<type> ParseData(const TensorProto* tensor_proto) { \
4040
std::vector<type> res; \
41-
if (!tensor_proto->has_raw_data()) \
41+
if (!tensor_proto->has_raw_data()) { \
42+
const auto& data = tensor_proto->typed_data_fetch(); \
43+
res.insert(res.end(), data.begin(), data.end()); \
4244
return res; \
45+
} \
4346
/* make copy as we may have to reverse bytes */ \
4447
std::string raw_data = tensor_proto->raw_data(); \
4548
/* okay to remove const qualifier as we have already made a copy */ \
4649
char* bytes = const_cast<char*>(raw_data.c_str()); \
4750
/*onnx is little endian serialized always-tweak byte order if needed*/ \
4851
if (!is_platform_little_endian()) { \
49-
const size_t element_size = sizeof(type); \
50-
const size_t num_elements = raw_data.size() / element_size; \
51-
for (size_t i = 0; i < num_elements; ++i) { \
52-
char* start_byte = bytes + i * element_size; \
53-
char* end_byte = start_byte + element_size - 1; \
54-
/* keep swapping */ \
55-
for (size_t count = 0; count < element_size / 2; ++count) { \
56-
char temp = *start_byte; \
57-
*start_byte = *end_byte; \
58-
*end_byte = temp; \
59-
++start_byte; \
60-
--end_byte; \
61-
} \
52+
const size_t element_size = sizeof(type); \
53+
const size_t num_elements = raw_data.size() / element_size; \
54+
for (size_t i = 0; i < num_elements; ++i) { \
55+
char* start_byte = bytes + i * element_size; \
56+
char* end_byte = start_byte + element_size - 1; \
57+
/* keep swapping */ \
58+
for (size_t count = 0; count < element_size / 2; ++count) { \
59+
char temp = *start_byte; \
60+
*start_byte = *end_byte; \
61+
*end_byte = temp; \
62+
++start_byte; \
63+
--end_byte; \
6264
} \
65+
} \
6366
} \
6467
res.insert( \
6568
res.end(), \
6669
reinterpret_cast<const type*>(bytes), \
67-
reinterpret_cast<const type*>(bytes + raw_data.size()) \
68-
); \
70+
reinterpret_cast<const type*>(bytes + raw_data.size())); \
6971
return res; \
7072
}
7173

@@ -85,8 +87,8 @@ DEFINE_TO_TENSOR_LIST(uint64_t, TensorProto_DataType_UINT64, uint64)
8587
DEFINE_TO_TENSOR_LIST(double, TensorProto_DataType_DOUBLE, double)
8688
DEFINE_TO_TENSOR_LIST(std::string, TensorProto_DataType_STRING, string)
8789

88-
DEFINE_PARSE_RAW_DATA(int32_t)
89-
DEFINE_PARSE_RAW_DATA(int64_t)
90-
DEFINE_PARSE_RAW_DATA(float)
90+
DEFINE_PARSE_DATA(int32_t, int32_data)
91+
DEFINE_PARSE_DATA(int64_t, int64_data)
92+
DEFINE_PARSE_DATA(float, float_data)
9193

9294
} // namespace ONNX_NAMESPACE

onnx/defs/tensor_proto_util.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@ template <typename T>
1414
TensorProto ToTensor(const std::vector<T>& values);
1515

1616
template <typename T>
17-
const std::vector<T> ParseRawData(const TensorProto* tensor_proto);
17+
const std::vector<T> ParseData(const TensorProto* tensor_proto);
1818

1919
} // namespace ONNX_NAMESPACE

onnx/test/shape_inference_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,19 @@ def test_upsample(self): # type: () -> None
248248
[make_tensor_value_info('y', TensorProto.INT32, (2, 4, 3, 9))],
249249
opset_imports=[helper.make_opsetid("", 9)])
250250

251+
def test_upsample_raw_data(self): # type: () -> None
252+
graph = self._make_graph(
253+
[('x', TensorProto.INT32, (2, 4, 3, 5)),
254+
('scales', TensorProto.FLOAT, (4,))],
255+
[make_node("Upsample", ['x', 'scales'], ['y'])],
256+
[],
257+
initializer=[make_tensor('scales', TensorProto.FLOAT, (4,),
258+
vals=np.array([1.0, 1.1, 1.3, 1.9], dtype='<f4').tobytes(), raw=True)]) # Feed raw bytes (force little endian ordering like onnx standard) for test purpose
259+
self._assert_inferred(
260+
graph,
261+
[make_tensor_value_info('y', TensorProto.INT32, (2, 4, 3, 9))],
262+
opset_imports=[helper.make_opsetid("", 9)])
263+
251264
def test_resize(self): # type: () -> None
252265
graph = self._make_graph(
253266
[('x', TensorProto.INT32, (2, 4, 3, 5)),
@@ -786,6 +799,27 @@ def test_topk(self): # type: () -> None
786799
[make_tensor_value_info('y', TensorProto.FLOAT, (3, 4, 2, 10)),
787800
make_tensor_value_info('z', TensorProto.INT64, (3, 4, 2, 10))])
788801

802+
def test_topk_raw_data(self): # type: () -> None
803+
graph = self._make_graph(
804+
[('x', TensorProto.FLOAT, (3, 4, 5, 10))],
805+
[make_node('TopK', ['x', 'k'], ['y', 'z'], axis=2)],
806+
[],
807+
initializer=[make_tensor('k', TensorProto.INT64, (1,),
808+
vals=np.array([3], dtype='<i8').tobytes(), raw=True)]) # Feed raw bytes (force little endian ordering like onnx standard) for test purpose
809+
self._assert_inferred(graph,
810+
[make_tensor_value_info('y', TensorProto.FLOAT, (3, 4, 3, 10)),
811+
make_tensor_value_info('z', TensorProto.INT64, (3, 4, 3, 10))])
812+
813+
def test_topk_missing_k_value_output_rank_check(self): # type: () -> None
814+
graph = self._make_graph(
815+
[('x', TensorProto.FLOAT, (3, 4, 5, 10)),
816+
('k', TensorProto.INT64, (1,))],
817+
[make_node('TopK', ['x', 'k'], ['y', 'z'], axis=2)],
818+
[])
819+
self._assert_inferred(graph,
820+
[make_tensor_value_info('y', TensorProto.FLOAT, (None, None, None, None)), # type: ignore
821+
make_tensor_value_info('z', TensorProto.INT64, (None, None, None, None))]) # type: ignore
822+
789823
def test_gemm(self): # type: () -> None
790824
graph = self._make_graph(
791825
[('x', TensorProto.FLOAT, (7, 5)),

0 commit comments

Comments
 (0)