Skip to content

Commit 3c5e396

Browse files
yaeldMS authored and facebook-github-bot committed
[ONNX] Squeeze operator should give an error when trying to apply to a dimension with shape > 1 (#38476)
Summary: The ONNX spec for the Squeeze operator: > Remove single-dimensional entries from the shape of a tensor. Takes a parameter axes with a list of axes to squeeze. If axes is not provided, all the single dimensions will be removed from the shape. If an axis is selected with shape entry not equal to one, an error is raised. Currently, as explained in issue #36796, it is possible to export such a model to ONNX, and this results in an exception from ONNX runtime. Fixes #36796. Pull Request resolved: #38476 Reviewed By: hl475 Differential Revision: D22158024 Pulled By: houseroad fbshipit-source-id: bed625f3c626eabcbfb2ea83ec2f992963defa19
1 parent cd96dfd commit 3c5e396

7 files changed

Lines changed: 142 additions & 35 deletions

File tree

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -686,13 +686,64 @@ def forward(self, input1, input2, input3):
686686
self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5)
687687
self.run_test(ScriptModel(), (x1, x2, x3), atol=10e-5)
688688

689-
def test_squeeze(self):
689+
def squeeze_model_tests(self, d, x1, x2):
690690
class Squeeze(torch.nn.Module):
691691
def forward(self, x):
692-
return torch.torch.squeeze(x, dim=-2)
692+
if d is not None:
693+
return torch.squeeze(x, dim=d)
694+
else:
695+
return torch.squeeze(x)
693696

697+
x2 = [] if x2 is None else [x2]
698+
self.run_test(Squeeze(), x1, input_names=['input'], dynamic_axes={'input': {0: '0', 1: '1', 2: '2'}}, test_with_inputs=x2)
699+
700+
def test_squeeze_without_no_op(self):
694701
x = torch.randn(2, 1, 4)
695-
self.run_test(Squeeze(), x)
702+
self.squeeze_model_tests(1, x, None)
703+
704+
@skipIfUnsupportedMinOpsetVersion(11)
705+
def test_squeeze(self):
706+
x_squeeze = torch.randn(2, 1, 4)
707+
x_noop = torch.randn(2, 2, 3)
708+
self.squeeze_model_tests(1, x_squeeze, x_noop)
709+
710+
def test_squeeze_neg_without_no_op(self):
711+
x = torch.randn(2, 1, 4)
712+
self.squeeze_model_tests(-2, x, None)
713+
714+
@skipIfUnsupportedMinOpsetVersion(11)
715+
def test_squeeze_neg(self):
716+
x_squeeze = torch.randn(2, 1, 4)
717+
x_noop = torch.randn(2, 2, 3)
718+
self.squeeze_model_tests(-2, x_squeeze, x_noop)
719+
720+
def test_squeeze_all_dims(self):
721+
x_squeeze = torch.randn(2, 1, 4)
722+
x_noop = torch.randn(2, 2, 3)
723+
self.squeeze_model_tests(None, x_squeeze, x_noop)
724+
725+
@skipIfUnsupportedMinOpsetVersion(11)
726+
def test_squeeze_no_op(self):
727+
x_noop = torch.randn(2, 1, 4)
728+
x_squeeze = torch.randn(2, 2, 1)
729+
self.squeeze_model_tests(2, x_noop, x_squeeze)
730+
731+
def test_squeeze_no_op_without_additional_inputs(self):
732+
x_noop = torch.randn(2, 1, 4)
733+
self.squeeze_model_tests(2, x_noop, None)
734+
735+
@skipIfUnsupportedMinOpsetVersion(11)
736+
def test_squeeze_runtime_dim(self):
737+
class Squeeze(torch.nn.Module):
738+
def forward(self, d1, d2):
739+
t = torch.zeros(d1[0], d2[0])
740+
return t.squeeze(0)
741+
742+
d1 = torch.tensor([1])
743+
d3 = torch.tensor([3])
744+
d4 = torch.tensor([4])
745+
self.run_test(Squeeze(), (d1, d4), test_with_inputs=[(d3, d4)])
746+
self.run_test(Squeeze(), (d3, d4), test_with_inputs=[(d1, d3)])
696747

697748
def test_unsqueeze(self):
698749
class Unsqueeze(torch.nn.Module):
@@ -1568,6 +1619,22 @@ def forward(self, input):
15681619
x = torch.randn(3, 4, 5, requires_grad=True)
15691620
self.run_test(IndexCopyModel(), x)
15701621

1622+
def test_select(self):
1623+
class Select(torch.nn.Module):
1624+
def forward(self, x):
1625+
return x[:, 1]
1626+
1627+
x = torch.randn(3, 4)
1628+
self.run_test(Select(), x)
1629+
1630+
def test_select_negative_index(self):
1631+
class Select(torch.nn.Module):
1632+
def forward(self, x):
1633+
return x[:, -1]
1634+
1635+
x = torch.randn(3, 4)
1636+
self.run_test(Select(), x)
1637+
15711638
# TODO: enable for opset 10 when ONNXRuntime version will be updated
15721639

15731640
def test_index_select_constant_scaler_index(self):

torch/csrc/jit/passes/onnx/helper.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,5 +49,15 @@ void buildParamsMapFromValueToParamsMap(
4949
paramsDict.insert(nameTensorParamPair.second);
5050
}
5151
}
52+
53+
Node* addNodeToBlock(Block* block, Value* input, Symbol kind) {
54+
auto new_node = block->appendNode(block->owningGraph()->create(kind));
55+
auto new_input = new_node->addInput(input);
56+
for (size_t i = 0; i < new_node->outputs().size(); i++) {
57+
auto output = new_node->outputs()[i];
58+
block->registerOutput(output);
59+
}
60+
return new_node;
61+
}
5262
} // namespace jit
5363
} // namespace torch

torch/csrc/jit/passes/onnx/helper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ void eraseUnusedBlockInputs(Block* b);
2626
void buildParamsMapFromValueToParamsMap(
2727
const ValueToParamPairMap& valsToParamsMap,
2828
ParamMap& paramsDict);
29+
Node* addNodeToBlock(Block* block, Value* input, Symbol kind);
2930

3031
} // namespace jit
3132
} // namespace torch

torch/csrc/jit/python/python_ir.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <torch/csrc/jit/ir/alias_analysis.h>
55
#include <torch/csrc/jit/ir/ir.h>
66
#include <torch/csrc/jit/passes/canonicalize.h>
7+
#include <torch/csrc/jit/passes/onnx/helper.h>
78
#include <torch/csrc/jit/passes/shape_analysis.h>
89
#include <torch/csrc/jit/python/pybind.h>
910
#include <torch/csrc/jit/python/python_tracer.h>
@@ -13,7 +14,6 @@
1314
#include <torch/csrc/python_headers.h>
1415
#include <torch/csrc/utils/pybind.h>
1516
#include <torch/csrc/utils/python_strings.h>
16-
1717
#include <iostream>
1818
#include <sstream>
1919

@@ -467,7 +467,10 @@ void initPythonIRBindings(PyObject* module_) {
467467
return py::make_iterator(b.outputs().begin(), b.outputs().end());
468468
})
469469
.def("returnNode", [](Block& b) { return b.return_node(); })
470-
.def("paramNode", [](Block& b) { return b.param_node(); });
470+
.def("paramNode", [](Block& b) { return b.param_node(); })
471+
.def("addNode", [](Block& b, Value& input, const char* str) {
472+
return addNodeToBlock(&b, &input, Symbol::fromQualString(str));
473+
});
471474

472475
#define NS(name) def(#name, &Node ::name)
473476
py::class_<Node, std::unique_ptr<Node, py::nodelete>>(m, "Node")

torch/onnx/symbolic_opset11.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -497,14 +497,21 @@ def size(g, self, dim=None):
497497

498498
def squeeze(g, self, dim=None):
499499
if dim is None:
500-
dims = []
501-
for i, size in enumerate(self.type().sizes()):
502-
if size == 1:
503-
dims.append(i)
504-
else:
505-
dims = [sym_help._get_const(dim, 'i', 'dim')]
506-
return g.op("Squeeze", self, axes_i=dims)
507-
500+
return g.op("Squeeze", self)
501+
502+
dim = sym_help._get_const(dim, 'i', 'dim')
503+
504+
# create 'cond' node (condition is shape[i]==1)
505+
dim_constant = g.op("Constant", value_t=torch.tensor([dim]))
506+
size = sym_help._size_helper(g, self, dim_constant)
507+
const_one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))
508+
cond = g.op("Equal", size, const_one)
509+
# create the 'If' node and add the 'then' and 'else' blocks to it.
510+
if_node_outputs = g.op("If", cond)
511+
if_node = if_node_outputs.node()
512+
torch.onnx.utils._add_block(if_node, self, "onnx::Squeeze", axes_i=[dim])
513+
torch.onnx.utils._add_block(if_node, self, "onnx::Identity")
514+
return if_node_outputs
508515

509516
@parse_args('v', 'i')
510517
def unsqueeze(g, self, dim):

torch/onnx/symbolic_opset9.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -559,29 +559,42 @@ def select(g, self, dim, index):
559559

560560
def squeeze(g, self, dim=None):
561561
if dim is None:
562-
dims = []
563-
for i, size in enumerate(self.type().sizes()):
564-
if size == 1:
565-
dims.append(i)
566-
else:
567-
dims = [sym_help._get_const(dim, 'i', 'dim')]
568-
# Handle negative dims
569-
for i, dim in enumerate(dims):
570-
if dim < 0:
571-
rank = self.type().dim()
572-
if rank:
573-
warnings.warn("ONNX export squeeze with negative axis " + str(dim) +
574-
" might cause the onnx model to be incorrect. " +
575-
"Negative axis is not supported in ONNX. " +
576-
"Axis is converted to " + str(dim + rank) +
577-
" based on input shape at export time. " +
578-
"Passing an tensor of different rank in execution will be incorrect.")
579-
dims[i] += rank
580-
else:
581-
return _unimplemented('squeeze', 'negative axis with unknown input rank')
582-
583-
return g.op("Squeeze", self, axes_i=dims)
562+
return g.op("Squeeze", self)
563+
564+
squeeze_dim = sym_help._get_const(dim, 'i', 'dim')
565+
# Handle negative dims
566+
if squeeze_dim < 0:
567+
rank = self.type().dim()
568+
if rank:
569+
warnings.warn("ONNX export squeeze with negative axis " + str(squeeze_dim) +
570+
" might cause the onnx model to be incorrect. " +
571+
"Negative axis is not supported in ONNX. " +
572+
"Axis is converted to " + str(squeeze_dim + rank) +
573+
" based on input shape at export time. " +
574+
"Passing an tensor of different rank in execution will be incorrect.")
575+
squeeze_dim += rank
576+
else:
577+
return _unimplemented('squeeze', 'negative axis with unknown input rank')
578+
579+
input_shape = self.type().sizes()
580+
if input_shape is None:
581+
warnings.warn("This model contains a squeeze operation on dimension " + str(squeeze_dim) + " on an input " +
582+
"with unknown shape. Note that if the size of dimension " + str(squeeze_dim) + " of the input " +
583+
"is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on " +
584+
"non-singleton dimensions, it is recommended to export this model using opset " +
585+
"version 11 or higher.")
586+
return g.op("Squeeze", self, axes_i=[squeeze_dim])
587+
if input_shape[squeeze_dim] > 1:
588+
warnings.warn("This model contains a squeeze operation on dimension " + str(squeeze_dim) + ". The size of " +
589+
"this dimension in the given input is " + str(input_shape[squeeze_dim]) + ". The model will " +
590+
"be exported without the squeeze node. If the model is intended to be used with dynamic " +
591+
"input shapes, please use opset version 11 to " +
592+
"export the model.")
593+
return self
584594

595+
warnings.warn("This model contains a squeeze operation on dimension " + str(squeeze_dim) + ". If the model is " +
596+
"intended to be used with dynamic input shapes, please use opset version 11 to export the model.")
597+
return g.op("Squeeze", self, axes_i=[squeeze_dim])
585598

586599
def prelu(g, self, weight):
587600
if self.isCompleteTensor():

torch/onnx/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,12 @@ def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
986986
value_dict[x] = str(key) + '_dynamic_axes_' + str(i + 1)
987987
dynamic_axes[key] = value_dict
988988

989+
def _add_block(node, input_node, op_name, **kwargs):
990+
new_block = node.addBlock()
991+
new_node = new_block.addNode(input_node, op_name)
992+
for k, v in kwargs.items():
993+
_add_attribute(new_node, k, v, False)
994+
989995
torch._C.Graph.op = _graph_op
990996
torch._C.Graph.at = _graph_at
991997
torch._C.Graph.constant = _graph_constant

0 commit comments

Comments (0)