Skip to content

Commit ea694bf

Browse files
ArmenAghouseroad
authored and committed
implement fuse reduce->unsqueeze + fix assumption in nop_dropout pass (#1565)
* implement fuse reduce->unsqueeze + fix assumption in nop_dropout pass * fix bugs * remove nop code * correct output shape calculation * fix linting issue
1 parent 6db386e commit ea694bf

File tree

6 files changed

+129
-4
lines changed

6 files changed

+129
-4
lines changed

onnx/common/interned_strings.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ namespace ONNX_NAMESPACE {
7979
_(ratio) \
8080
_(size) \
8181
_(dim) \
82+
_(keepdims) \
8283
_(perm) \
8384
_(shape) \
8485
_(axes) \
@@ -145,7 +146,17 @@ namespace ONNX_NAMESPACE {
145146
_(__control_inputs) \
146147
_(count_include_pad) \
147148
_(storage_order) \
148-
_(Unsqueeze)
149+
_(Unsqueeze) \
150+
_(ReduceL1) \
151+
_(ReduceL2) \
152+
_(ReduceLogSum) \
153+
_(ReduceLogSumExp) \
154+
_(ReduceMax) \
155+
_(ReduceMean) \
156+
_(ReduceMin) \
157+
_(ReduceProd) \
158+
_(ReduceSum) \
159+
_(ReduceSumSquare)
149160

150161
enum BuiltinSymbol {
151162
#define DEFINE_SYMBOL(s) k##s,

onnx/examples/optimize_onnx.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
"\tfuse_bn_into_conv\n",
6161
"\tfuse_consecutive_concats\n",
6262
"\tfuse_consecutive_log_softmax\n",
63+
"\tfuse_consecutive_reduce_unsqueeze\n",
6364
"\tfuse_consecutive_squeezes\n",
6465
"\tfuse_consecutive_transposes\n",
6566
"\tfuse_transpose_into_gemm\n",
@@ -120,7 +121,7 @@
120121
"name": "python",
121122
"nbconvert_exporter": "python",
122123
"pygments_lexer": "ipython3",
123-
"version": "3.6.4"
124+
"version": "3.7.1"
124125
}
125126
},
126127
"nbformat": 4,

onnx/optimizer/pass_registry.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
#include "onnx/common/ir_pb_converter.h"
88
#include "onnx/common/stl_backports.h"
99
#include "onnx/optimizer/passes/eliminate_deadend.h"
10-
#include "onnx/optimizer/passes/eliminate_nop_dropout.h"
1110
#include "onnx/optimizer/passes/eliminate_identity.h"
11+
#include "onnx/optimizer/passes/eliminate_nop_dropout.h"
1212
#include "onnx/optimizer/passes/eliminate_nop_monotone_argmax.h"
1313
#include "onnx/optimizer/passes/eliminate_nop_pad.h"
1414
#include "onnx/optimizer/passes/eliminate_nop_transpose.h"
@@ -18,6 +18,7 @@
1818
#include "onnx/optimizer/passes/fuse_bn_into_conv.h"
1919
#include "onnx/optimizer/passes/fuse_consecutive_concats.h"
2020
#include "onnx/optimizer/passes/fuse_consecutive_log_softmax.h"
21+
#include "onnx/optimizer/passes/fuse_consecutive_reduce_unsqueeze.h"
2122
#include "onnx/optimizer/passes/fuse_consecutive_squeezes.h"
2223
#include "onnx/optimizer/passes/fuse_consecutive_transposes.h"
2324
#include "onnx/optimizer/passes/fuse_transpose_into_gemm.h"
@@ -51,6 +52,7 @@ struct GlobalPassRegistry {
5152
registerPass<FuseBNIntoConv>();
5253
registerPass<FuseConsecutiveConcats>();
5354
registerPass<FuseConsecutiveLogSoftmax>();
55+
registerPass<FuseConsecutiveReduceUnsqueeze>();
5456
registerPass<FuseConsecutiveSqueezes>();
5557
registerPass<FuseConsecutiveTransposes>();
5658
registerPass<FuseTransposeIntoGemm>();

onnx/optimizer/passes/eliminate_nop_dropout.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@ struct EliminateNopDropout final : public PredicateBasedPass {
2626

2727
bool runTransform(Node* node, Graph&, NodeDestroyType& destroy_current)
2828
override {
29-
node->output()->replaceAllUsesWith(node->input());
29+
// Don't assume that theres only one output.
30+
for (size_t i = 0; i < node->outputs().size(); ++i) {
31+
node->outputs()[i]->replaceAllUsesWith(node->input());
32+
}
3033
destroy_current = NodeDestroyType::DestroyOne;
3134
return true;
3235
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// ATTENTION: The code in this file is highly EXPERIMENTAL.
// Adventurous users should note that the APIs will probably change.

#pragma once

#include <unordered_set>

#include "onnx/optimizer/pass.h"

namespace ONNX_NAMESPACE {
namespace optimization {

// All ONNX reduction operators that carry the axes/keepdims attributes.
const std::unordered_set<NodeKind> reduction_operators{kReduceL1,
                                                       kReduceL2,
                                                       kReduceLogSum,
                                                       kReduceLogSumExp,
                                                       kReduceMax,
                                                       kReduceMean,
                                                       kReduceMin,
                                                       kReduceProd,
                                                       kReduceSum,
                                                       kReduceSumSquare};

// Fuses Reduce*(keepdims=0) -> Unsqueeze over the same axes into a single
// Reduce*(keepdims=1). The two forms produce identical shapes, making the
// Unsqueeze redundant.
struct FuseConsecutiveReduceUnsqueeze final : public PredicateBasedPass {
  explicit FuseConsecutiveReduceUnsqueeze()
      : PredicateBasedPass(
            PassType::Fuse,
            PassEfficiency::Complete,
            PassOptimizationType::Compute) {}

  std::string getPassName() const override {
    return "fuse_consecutive_reduce_unsqueeze";
  }
  bool patternMatchPredicate(Node* node) override {
    // check that the current node is of type Unsqueeze and has defined axes
    bool cur_node_check =
        node->kind() == kUnsqueeze && node->hasAttribute(kaxes);
    if (cur_node_check) {
      // The fusion flips the reduction's keepdims flag, which changes the
      // shape seen by every consumer of its output — so the Unsqueeze must
      // be the reduction's only consumer.
      if (node->input()->uses().size() != 1) {
        return false;
      }
      Node* prev_node = node->input()->node();
      // check that the previous node is a reduction operator and has defined
      // axes/keepdims
      bool reduction_node_check = reduction_operators.find(prev_node->kind()) !=
              reduction_operators.end() &&
          prev_node->hasAttribute(kaxes) && prev_node->hasAttribute(kkeepdims);
      if (reduction_node_check) {
        // ensure that keepdims is currently set to false and that the
        // Unsqueeze reinserts exactly the dimensions the reduction removed
        // (NOTE(review): assumes both axes lists are non-negative and sorted
        // the same way — negative axes simply won't match and are skipped)
        return prev_node->i(kkeepdims) == 0 &&
            node->is(kaxes) == prev_node->is(kaxes);
      }
    }
    return false;
  }
  bool runTransform(Node* node, Graph&, NodeDestroyType& destroy_current)
      override {
    Node* reduction_op = node->input()->node();
    // set keepdims flag to be true
    reduction_op->i_(kkeepdims, 1);
    // the reduction now produces the Unsqueeze's output shape/type directly,
    // so propagate that metadata and drop the unnecessary Unsqueeze
    reduction_op->output()->setSizes(node->output()->sizes());
    reduction_op->output()->setElemType(node->output()->elemType());
    node->output()->replaceAllUsesWith(node->input());
    destroy_current = NodeDestroyType::DestroyOne;
    return true;
  }
};

} // namespace optimization
} // namespace ONNX_NAMESPACE

onnx/test/optimizer_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,6 +1120,49 @@ def test_eliminate_nop_dropout(self): # type: () -> None
11201120
assert len(optimized_model.graph.node) == 1
11211121
assert optimized_model.graph.node[0].op_type == "Log"
11221122

1123+
def test_fuse_reduction_unsqueeze(self):  # type: () -> None
    """Reduce*(keepdims=0) + Unsqueeze over the same axes must fuse into a
    single Reduce*(keepdims=1); every other combination is left untouched."""

    def _expected_shape(input_shape, reduction_axes, unsqueeze_axes, keepdim):
        # type: (Tuple[int, ...], List[int], List[int], bool) -> Tuple[int, ...]
        # Shape after the reduction step.
        if keepdim:
            shape = [1 if i in reduction_axes else d
                     for i, d in enumerate(input_shape)]
        else:
            shape = [d for i, d in enumerate(input_shape)
                     if i not in reduction_axes]
        # Shape after Unsqueeze reinserts singleton dimensions.
        for ax in unsqueeze_axes:
            shape.insert(ax, 1)
        return tuple(shape)

    reductions = ["ReduceL1", "ReduceL2", "ReduceLogSum", "ReduceLogSumExp",
                  "ReduceMax", "ReduceMean", "ReduceMin", "ReduceProd",
                  "ReduceSum", "ReduceSumSquare"]
    for reduction in reductions:
        for axes1 in [[1], [1, 2], [2]]:
            for axes2 in [[1], [1, 2], [2]]:
                for keepdim in [False, True]:
                    input_shape = (5, 7, 9)
                    output_shape = _expected_shape(
                        input_shape, axes1, axes2, keepdim)  # type: Tuple[int, ...]
                    node = helper.make_node(
                        reduction, ["X"], ["Y"], axes=axes1, keepdims=keepdim)
                    node1 = helper.make_node(
                        "Unsqueeze", ["Y"], ["Z"], axes=axes2)
                    graph = helper.make_graph(
                        [node, node1],
                        "test",
                        [helper.make_tensor_value_info(
                            "X", TensorProto.FLOAT, input_shape)],
                        [helper.make_tensor_value_info(
                            "Z", TensorProto.FLOAT, output_shape)])
                    optimized_model = self._optimized(
                        graph, ["fuse_consecutive_reduce_unsqueeze"], False)

                    if keepdim or axes1 != axes2:
                        # Fusion must not fire: graph stays byte-identical.
                        assert optimized_model.graph == graph
                    else:
                        # Fused into one Reduce* node with the original axes
                        # and the fully unsqueezed output shape.
                        assert len(optimized_model.graph.output) == 1
                        assert len(optimized_model.graph.node) == 1
                        assert optimized_model.graph.output[0].type.tensor_type.elem_type == TensorProto.FLOAT
                        assert optimized_model.graph.node[-1].op_type == reduction
                        assert optimized_model.graph.node[-1].attribute[0].name == "axes"
                        assert optimized_model.graph.node[-1].attribute[0].ints == axes1
                        optimized_output_shape = tuple(
                            x.dim_value
                            for x in optimized_model.graph.output[0].type.tensor_type.shape.dim)
                        assert optimized_output_shape == output_shape
1165+
11231166

11241167
if __name__ == '__main__':
11251168
unittest.main()

0 commit comments

Comments
 (0)