Skip to content

Commit 414285b

Browse files
committed
fix the buffer overflow problem in shape inference logic of Squeeze op
1 parent 797cdd0 commit 414285b

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

onnx/defs/tensor/defs.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1265,9 +1265,16 @@ ONNX_OPERATOR_SET_SCHEMA(
12651265
ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
12661266
const auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
12671267
const auto input_ndim = input_shape.dim_size();
1268+
std::transform(
1269+
axes.begin(),
1270+
axes.end(),
1271+
axes.begin(),
1272+
[&](int64_t axis) -> int64_t {
1273+
return axis < 0 ? axis + input_ndim : axis;
1274+
});
1275+
12681276
for (int i = 0, j = 0; i < input_ndim; ++i) {
1269-
auto axis_j = axes[j] < 0 ? axes[j] + input_ndim : axes[j];
1270-
if (static_cast<size_t>(j) < axes.size() && axis_j == i) {
1277+
if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
12711278
if (input_shape.dim(i).has_dim_value() &&
12721279
input_shape.dim(i).dim_value() != 1) {
12731280
fail_shape_inference(

0 commit comments

Comments
 (0)