File tree Expand file tree Collapse file tree 1 file changed +9
-2
lines changed
Expand file tree Collapse file tree 1 file changed +9
-2
lines changed Original file line number Diff line number Diff line change @@ -1265,9 +1265,16 @@ ONNX_OPERATOR_SET_SCHEMA(
12651265 ctx.getOutputType (0 )->mutable_tensor_type ()->mutable_shape ();
12661266 const auto & input_shape = ctx.getInputType (0 )->tensor_type ().shape ();
12671267 const auto input_ndim = input_shape.dim_size ();
1268+ std::transform (
1269+ axes.begin (),
1270+ axes.end (),
1271+ axes.begin (),
1272+ [&](int64_t axis) -> int64_t {
1273+ return axis < 0 ? axis + input_ndim : axis;
1274+ });
1275+
12681276 for (int i = 0 , j = 0 ; i < input_ndim; ++i) {
1269- auto axis_j = axes[j] < 0 ? axes[j] + input_ndim : axes[j];
1270- if (static_cast <size_t >(j) < axes.size () && axis_j == i) {
1277+ if (std::find (axes.begin (), axes.end (), i) != axes.end ()) {
12711278 if (input_shape.dim (i).has_dim_value () &&
12721279 input_shape.dim (i).dim_value () != 1 ) {
12731280 fail_shape_inference (
You can’t perform that action at this time.
0 commit comments