Skip to content

Commit 2873fea

Browse files
hariharans29wschin
authored andcommitted
Fix spec and shape inference for Unsqueeze op (#2347)
* Fix spec for Unsqueeze * Update Changelog.md * Refine doc * Refine * PR comments * Update Changelog.md * Update shape inference test file
1 parent a3c9145 commit 2873fea

9 files changed

Lines changed: 157 additions & 23 deletions

File tree

docs/Changelog.md

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13803,11 +13803,18 @@ This version of the operator has been available since version 11 of the default
1380313803

1380413804
### <a name="Unsqueeze-11"></a>**Unsqueeze-11**</a>
1380513805

13806-
Insert single-dimensional entries to the shape of a tensor.
13807-
Takes one required argument `axes`, a list of dimensions that will be inserted.
13808-
Dimension indices in `axes` are as seen in the output tensor. For example:
13809-
Given a tensor such that tensor with shape [3, 4, 5], then
13810-
Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1]
13806+
Insert single-dimensional entries to the shape of an input tensor (`data`).
13807+
Takes one required argument `axes` - which contains a list of dimension indices and this operator will insert a dimension of value `1` into the corresponding index of the output tensor (`expanded`).
13808+
13809+
For example:
13810+
Given an input tensor (`data`) of shape [3, 4, 5], then
13811+
Unsqueeze(data, axes=[0, 4]) outputs a tensor (`expanded`) containing same data as `data` but with shape [1, 3, 4, 5, 1].
13812+
13813+
The attribute `axes` should not contain any duplicate entries. It is an error if it contains duplicates.
13814+
The rank of the output tensor (`output_rank`) is the rank of the input tensor (`data`) plus the number of values in `axes`.
13815+
Each value in `axes` should be within the (inclusive) range [-output_rank , output_rank - 1].
13816+
The order of values in `axes` does not matter and can come in any order.
13817+
1381113818

1381213819
#### Version
1381313820

@@ -13817,7 +13824,7 @@ This version of the operator has been available since version 11 of the default
1381713824

1381813825
<dl>
1381913826
<dt><tt>axes</tt> : list of ints (required)</dt>
13820-
<dd>List of integers indicating the dimensions to be inserted. Negative value means counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(data).</dd>
13827+
<dd>List of integers indicating the dimensions to be inserted. Negative value means counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(expanded).</dd>
1382113828
</dl>
1382213829

1382313830
#### Inputs

docs/Operators.md

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18530,11 +18530,18 @@ expect(node_sorted, inputs=[x], outputs=[y, indices, inverse_indices, counts], n
1853018530

1853118531
### <a name="Unsqueeze"></a><a name="unsqueeze">**Unsqueeze**</a>
1853218532

18533-
Insert single-dimensional entries to the shape of a tensor.
18534-
Takes one required argument `axes`, a list of dimensions that will be inserted.
18535-
Dimension indices in `axes` are as seen in the output tensor. For example:
18536-
Given a tensor such that tensor with shape [3, 4, 5], then
18537-
Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1]
18533+
Insert single-dimensional entries to the shape of an input tensor (`data`).
18534+
Takes one required argument `axes` - which contains a list of dimension indices and this operator will insert a dimension of value `1` into the corresponding index of the output tensor (`expanded`).
18535+
18536+
For example:
18537+
Given an input tensor (`data`) of shape [3, 4, 5], then
18538+
Unsqueeze(data, axes=[0, 4]) outputs a tensor (`expanded`) containing same data as `data` but with shape [1, 3, 4, 5, 1].
18539+
18540+
The attribute `axes` should not contain any duplicate entries. It is an error if it contains duplicates.
18541+
The rank of the output tensor (`output_rank`) is the rank of the input tensor (`data`) plus the number of values in `axes`.
18542+
Each value in `axes` should be within the (inclusive) range [-output_rank , output_rank - 1].
18543+
The order of values in `axes` does not matter and can come in any order.
18544+
1853818545

1853918546
#### Version
1854018547

@@ -18546,7 +18553,7 @@ Other versions of this operator: <a href="Changelog.md#Unsqueeze-1">Unsqueeze-1<
1854618553

1854718554
<dl>
1854818555
<dt><tt>axes</tt> : list of ints (required)</dt>
18549-
<dd>List of integers indicating the dimensions to be inserted. Negative value means counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(data).</dd>
18556+
<dd>List of integers indicating the dimensions to be inserted. Negative value means counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(expanded).</dd>
1855018557
</dl>
1855118558

1855218559
#### Inputs
@@ -18659,6 +18666,29 @@ expect(node, inputs=[x], outputs=[y],
1865918666
</details>
1866018667

1866118668

18669+
<details>
18670+
<summary>unsqueeze_unsorted_axes</summary>
18671+
18672+
```python
18673+
x = np.random.randn(3, 4, 5).astype(np.float32)
18674+
18675+
node = onnx.helper.make_node(
18676+
'Unsqueeze',
18677+
inputs=['x'],
18678+
outputs=['y'],
18679+
axes=[5, 4, 2],
18680+
)
18681+
y = np.expand_dims(x, axis=2)
18682+
y = np.expand_dims(y, axis=4)
18683+
y = np.expand_dims(y, axis=5)
18684+
18685+
expect(node, inputs=[x], outputs=[y],
18686+
name='test_unsqueeze_unsorted_axes')
18687+
```
18688+
18689+
</details>
18690+
18691+
1866218692
### <a name="Upsample"></a><a name="upsample">**Upsample** (deprecated)</a>
1866318693

1866418694
Upsample the input tensor.

docs/TestCoverage.md

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10358,7 +10358,7 @@ expect(node_sorted, inputs=[x], outputs=[y, indices, inverse_indices, counts], n
1035810358

1035910359

1036010360
### Unsqueeze
10361-
There are 4 test cases, listed as following:
10361+
There are 5 test cases, listed as following:
1036210362
<details>
1036310363
<summary>unsqueeze_negative_axes</summary>
1036410364

@@ -10436,6 +10436,27 @@ expect(node, inputs=[x], outputs=[y],
1043610436
name='test_unsqueeze_two_axes')
1043710437
```
1043810438

10439+
</details>
10440+
<details>
10441+
<summary>unsqueeze_unsorted_axes</summary>
10442+
10443+
```python
10444+
x = np.random.randn(3, 4, 5).astype(np.float32)
10445+
10446+
node = onnx.helper.make_node(
10447+
'Unsqueeze',
10448+
inputs=['x'],
10449+
outputs=['y'],
10450+
axes=[5, 4, 2],
10451+
)
10452+
y = np.expand_dims(x, axis=2)
10453+
y = np.expand_dims(y, axis=4)
10454+
y = np.expand_dims(y, axis=5)
10455+
10456+
expect(node, inputs=[x], outputs=[y],
10457+
name='test_unsqueeze_unsorted_axes')
10458+
```
10459+
1043910460
</details>
1044010461

1044110462

onnx/backend/test/case/node/unsqueeze.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,23 @@ def export_unsqueeze_three_axes(): # type: () -> None
6161
expect(node, inputs=[x], outputs=[y],
6262
name='test_unsqueeze_three_axes')
6363

64+
@staticmethod
65+
def export_unsqueeze_unsorted_axes(): # type: () -> None
66+
x = np.random.randn(3, 4, 5).astype(np.float32)
67+
68+
node = onnx.helper.make_node(
69+
'Unsqueeze',
70+
inputs=['x'],
71+
outputs=['y'],
72+
axes=[5, 4, 2],
73+
)
74+
y = np.expand_dims(x, axis=2)
75+
y = np.expand_dims(y, axis=4)
76+
y = np.expand_dims(y, axis=5)
77+
78+
expect(node, inputs=[x], outputs=[y],
79+
name='test_unsqueeze_unsorted_axes')
80+
6481
@staticmethod
6582
def export_unsqueeze_negative_axes(): # type: () -> None
6683
node = onnx.helper.make_node(
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
 backend-test:�
2+
"
3+
xy" Unsqueeze*
4+
axes@@@�test_unsqueeze_unsorted_axesZ
5+
x
6+

7+

8+

9+
b#
10+
y
11+

12+

13+

14+

15+

16+

17+
B
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
BxJ�x��?h��>��z?�j@$ �?�.z��8s?b��hdӽ�9�>(�>�%�?^�B?�0�= B�>]ת>�=�?R�iJ�>�Z�/d#��S'?�K]?��=��C@�(��Hm;= �?�2�?��?��>���>�Ec������!��� >*z�?��?�Oƾmǚ��6��&õ�gڿ��?�x�FKྙ[��� G?4�ο��Y�L=e��> �����k��QN�>.:�=�ݚ>�b"�6���
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ByJ�x��?h��>��z?�j@$ �?�.z��8s?b��hdӽ�9�>(�>�%�?^�B?�0�= B�>]ת>�=�?R�iJ�>�Z�/d#��S'?�K]?��=��C@�(��Hm;= �?�2�?��?��>���>�Ec������!��� >*z�?��?�Oƾmǚ��6��&õ�gڿ��?�x�FKྙ[��� G?4�ο��Y�L=e��> �����k��QN�>.:�=�ݚ>�b"�6���

onnx/defs/tensor/defs.cc

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,11 +1301,18 @@ ONNX_OPERATOR_SET_SCHEMA(
13011301
}));
13021302

13031303
static const char* Unsqueeze_ver11_doc = R"DOC(
1304-
Insert single-dimensional entries to the shape of a tensor.
1305-
Takes one required argument `axes`, a list of dimensions that will be inserted.
1306-
Dimension indices in `axes` are as seen in the output tensor. For example:
1307-
Given a tensor such that tensor with shape [3, 4, 5], then
1308-
Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1]
1304+
Insert single-dimensional entries to the shape of an input tensor (`data`).
1305+
Takes one required argument `axes` - which contains a list of dimension indices and this operator will insert a dimension of value `1` into the corresponding index of the output tensor (`expanded`).
1306+
1307+
For example:
1308+
Given an input tensor (`data`) of shape [3, 4, 5], then
1309+
Unsqueeze(data, axes=[0, 4]) outputs a tensor (`expanded`) containing same data as `data` but with shape [1, 3, 4, 5, 1].
1310+
1311+
The attribute `axes` should not contain any duplicate entries. It is an error if it contains duplicates.
1312+
The rank of the output tensor (`output_rank`) is the rank of the input tensor (`data`) plus the number of values in `axes`.
1313+
Each value in `axes` should be within the (inclusive) range [-output_rank , output_rank - 1].
1314+
The order of values in `axes` does not matter and can come in any order.
1315+
13091316
)DOC";
13101317

13111318
ONNX_OPERATOR_SET_SCHEMA(
@@ -1315,7 +1322,7 @@ ONNX_OPERATOR_SET_SCHEMA(
13151322
.Attr(
13161323
"axes",
13171324
"List of integers indicating the dimensions to be inserted. Negative value means counting dimensions "
1318-
"from the back. Accepted range is [-r, r-1] where r = rank(data).",
1325+
"from the back. Accepted range is [-r, r-1] where r = rank(expanded).",
13191326
AttributeProto::INTS)
13201327
.SetDoc(Unsqueeze_ver11_doc)
13211328
.Input(0, "data", "Original tensor", "T")
@@ -1334,7 +1341,16 @@ ONNX_OPERATOR_SET_SCHEMA(
13341341
if (!getRepeatedAttribute(ctx, "axes", axes)) {
13351342
return;
13361343
}
1337-
std::sort(axes.begin(), axes.end());
1344+
1345+
// validate 'axes' for duplicate entries
1346+
std::unordered_set<int64_t> unique_values;
1347+
for (const auto val : axes) {
1348+
if (unique_values.find(val) != unique_values.end()) {
1349+
fail_shape_inference(
1350+
"'axes' attribute must not contain any duplicates");
1351+
}
1352+
unique_values.insert(val);
1353+
}
13381354

13391355
if (!ctx.getInputType(0)->tensor_type().has_shape()) {
13401356
return;
@@ -1343,10 +1359,20 @@ ONNX_OPERATOR_SET_SCHEMA(
13431359
ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
13441360
const auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
13451361
const auto input_ndim = input_shape.dim_size();
1362+
const auto output_ndim = input_ndim + static_cast<int>(axes.size());
13461363
for (size_t(i) = 0; i < axes.size(); ++i) {
1347-
if (axes[i] < 0)
1348-
axes[i] += input_ndim;
1364+
if (axes[i] < -output_ndim || axes[i] >= output_ndim) {
1365+
fail_shape_inference(
1366+
"values in 'axes' are beyond the bounds of the computed output shape");
1367+
}
1368+
if (axes[i] < 0) {
1369+
axes[i] += output_ndim;
1370+
}
13491371
}
1372+
1373+
// sort after correcting negative axes values (if any) in the previous step
1374+
std::sort(axes.begin(), axes.end());
1375+
13501376
int j = 0;
13511377
for (int i = 0; i < input_ndim; ++i) {
13521378
while (static_cast<size_t>(j) < axes.size() &&

onnx/test/shape_inference_test.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,13 +455,27 @@ def test_squeeze(self): # type: () -> None
455455
[])
456456
self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, (3, 2))])
457457

458-
def test_unsqueeze(self): # type: () -> None
458+
def test_unsqueeze_regular(self): # type: () -> None
459459
graph = self._make_graph(
460460
[('x', TensorProto.FLOAT, (3, 2))],
461461
[make_node('Unsqueeze', 'x', 'y', axes=[0, 1, 3, 5])],
462462
[])
463463
self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, (1, 1, 3, 1, 2, 1))])
464464

465+
def test_unsqueeze_unsorted_axes(self): # type: () -> None
466+
graph = self._make_graph(
467+
[('x', TensorProto.FLOAT, (3, 4, 5))],
468+
[make_node('Unsqueeze', 'x', 'y', axes=[4, 0])],
469+
[])
470+
self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, (1, 3, 4, 5, 1))])
471+
472+
def test_unsqueeze_negative_axes(self): # type: () -> None
473+
graph = self._make_graph(
474+
[('x', TensorProto.FLOAT, (3, 4, 5))],
475+
[make_node('Unsqueeze', 'x', 'y', axes=[0, -1])],
476+
[])
477+
self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, (1, 3, 4, 5, 1))])
478+
465479
def test_slice_without_input_shape(self): # type: () -> None
466480
graph = self._make_graph(
467481
[('x', TensorProto.FLOAT, (3, 2)), ('starts', TensorProto.INT64, (1,)), ('ends', TensorProto.INT64, (1,))],

0 commit comments

Comments
 (0)