Skip to content

Commit 8aec4e2

Browse files
authored
[Anderspapitto patch] fix the shape inference for broadcasting (#1368)
* fix case in shape inference where dimension is unknown on both sides * fix case in shape inference where dimension is unknown on both sides (#1367)
1 parent 1b09eb1 commit 8aec4e2

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

onnx/defs/shape_inference.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,9 @@ inline void multidirectionalBroadcastShapeInference(
388388
resultShape.add_dim()->set_dim_value(dim_value);
389389
} else if (num_symbolic_dims == 1) {
390390
*resultShape.add_dim() = symbolic_dim;
391-
}
391+
} else {
392+
resultShape.add_dim();
393+
}
392394
}
393395
}
394396

onnx/defs/tensor/defs.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,10 @@ ONNX_OPERATOR_SET_SCHEMA(
300300
fail_shape_inference("Required attribute axis is missing");
301301
}
302302
int axis = static_cast<int>(axisAttr->i());
303-
if (axis < 0 || rank <= axis) {
303+
if (rank <= axis) {
304+
fail_shape_inference("rank must be greater than axis");
305+
}
306+
if (axis < 0) {
304307
return; // TODO: check if negative axis must be supported
305308
}
306309

0 commit comments

Comments
 (0)