Skip to content

Commit fe55bda

Browse files
committed
port ge changes from bert/pytorch_fusion
1 parent 95f1cb3 commit fe55bda

1 file changed

Lines changed: 165 additions & 151 deletions

File tree

torch/csrc/jit/passes/guard_elimination.cpp

Lines changed: 165 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
1-
#include <memory>
2-
#include <torch/csrc/jit/runtime/graph_executor.h>
3-
#include <torch/csrc/jit/jit_log.h>
41
#include <torch/csrc/jit/ir/alias_analysis.h>
2+
#include <torch/csrc/jit/jit_log.h>
53
#include <torch/csrc/jit/passes/constant_propagation.h>
64
#include <torch/csrc/jit/passes/guard_elimination.h>
75
#include <torch/csrc/jit/passes/peephole.h>
6+
#include <torch/csrc/jit/runtime/graph_executor.h>
7+
#include <memory>
88
#include <unordered_set>
99

1010
namespace torch {
1111
namespace jit {
1212

1313
struct GuardElimination {
1414
GuardElimination(std::shared_ptr<Graph> graph)
15-
: graph_(std::move(graph)),
16-
aliasDb_(std::make_unique<AliasDb>(graph_)) {}
15+
: graph_(std::move(graph)), aliasDb_(std::make_unique<AliasDb>(graph_)) {}
1716

1817
void run() {
1918
const size_t MAX_ATTEMPTS = 5;
@@ -123,8 +122,11 @@ struct GuardElimination {
123122
auto it = guard;
124123
while (it != output) {
125124
if (it->kind() != prim::Guard && it->kind() != prim::Constant) {
126-
GRAPH_DEBUG("found an unexpected node ", *it,
127-
" while trying to eliminate ", *guard);
125+
GRAPH_DEBUG(
126+
"found an unexpected node ",
127+
*it,
128+
" while trying to eliminate ",
129+
*guard);
128130
return false;
129131
}
130132
it = it->prev();
@@ -160,7 +162,10 @@ struct GuardElimination {
160162
// `checkInputs` check the invariants specified in `removableGuard`
161163
// on inputs to `n`. The invariants must hold, or an input must
162164
// be a `prim::Constant` or be included as an exception in `except`
163-
bool checkInputs(Node *n, const std::unordered_set<size_t> &except, bool allow_numbers) {
165+
bool checkInputs(
166+
Node* n,
167+
const std::unordered_set<size_t>& except,
168+
bool allow_numbers) {
164169
bool all_inputs_guarded = true;
165170
size_t i = 0;
166171
for (auto input : n->inputs()) {
@@ -173,8 +178,11 @@ struct GuardElimination {
173178
input->node()->kind() != prim::Guard ||
174179
input->type()->expect<TensorType>());
175180
} else {
176-
GRAPH_DEBUG("input ", input->debugName(), " isn't guarded, type ",
177-
*input->type());
181+
GRAPH_DEBUG(
182+
"input ",
183+
input->debugName(),
184+
" isn't guarded, type ",
185+
*input->type());
178186
all_inputs_guarded = false;
179187
break;
180188
}
@@ -183,7 +191,7 @@ struct GuardElimination {
183191
return all_inputs_guarded;
184192
}
185193

186-
private:
194+
private:
187195
// `removableGuard` relies on the properties checked by `isSummarized()`
188196
// and passes shouldn't insert nodes between a guard and its uses that
189197
// may alter those properties.
@@ -210,154 +218,161 @@ struct GuardElimination {
210218
// Guards can be removed if all inputs are guarded and `isSummarized()`
211219
// returns
212220
// false or inputs are `prim::Constant`
213-
bool removableGuard(Node *n) {
214-
221+
bool removableGuard(Node* n) {
215222
const static auto no_exceptions = std::unordered_set<size_t>{};
216223
switch (n->kind()) {
217-
case aten::add:
218-
case aten::sub:
219-
case aten::mul:
220-
case aten::div:
221-
case aten::t:
222-
case aten::sigmoid:
223-
case aten::sin:
224-
case aten::cos:
225-
case aten::tan:
226-
case aten::sinh:
227-
case aten::cosh:
228-
case aten::tanh:
229-
case aten::asin:
230-
case aten::acos:
231-
case aten::atan:
232-
case aten::atan2:
233-
case aten::floor:
234-
case aten::fmod:
235-
case aten::ceil:
236-
case aten::trunc:
237-
case aten::sqrt:
238-
case aten::rsqrt:
239-
case aten::remainder:
240-
case aten::mm:
241-
case aten::min:
242-
case aten::max:
243-
case aten::type_as:
244-
case aten::ge:
245-
case aten::gt:
246-
case aten::lt:
247-
case aten::le:
248-
case aten::eq:
249-
case aten::ne:
250-
case aten::neg:
251-
case prim::ConstantChunk:
252-
case aten::size:
253-
case aten::abs:
254-
case aten::sign:
255-
case aten::pow:
256-
case aten::relu:
257-
case aten::threshold:
258-
case prim::AutogradAdd:
259-
case prim::AutogradZero:
260-
case aten::rand_like:
261-
case aten::erf:
262-
case aten::erfc:
263-
case aten::exp:
264-
case aten::expm1:
265-
case aten::log:
266-
case aten::log2:
267-
case aten::log10:
268-
case aten::frac:
269-
case aten::lerp:
270-
case aten::lgamma:
271-
case aten::reciprocal:
272-
case aten::addcmul:
273-
case aten::where:
274-
return checkInputs(n, no_exceptions, true);
275-
case aten::avg_pool2d:
276-
return checkInputs(n, no_exceptions, false);
277-
case aten::slice:
278-
return !n->input(0)->type()->expect<TensorType>()->isSummarized() &&
279-
// check that the dimension argument is constant
280-
n->input(1)->node()->kind() == prim::Constant &&
281-
// the start offset is constant
282-
n->input(2)->node()->kind() == prim::Constant &&
283-
// the end offset is constant
284-
n->input(3)->node()->kind() == prim::Constant &&
285-
// the stride is constant
286-
n->input(4)->node()->kind() == prim::Constant;
287-
case aten::max_pool1d:
288-
case aten::max_pool2d:
289-
case aten::max_pool3d:
290-
return !n->input(0)->type()->expect<TensorType>()->isSummarized() &&
291-
// check that the kernel size is constant
292-
n->input(1)->node()->kind() == prim::Constant &&
293-
// check that the stride is constant
294-
n->input(2)->node()->kind() == prim::Constant &&
295-
// check that the padding is constant
296-
n->input(3)->node()->kind() == prim::Constant &&
297-
// check that the dilation is constant
298-
n->input(4)->node()->kind() == prim::Constant &&
299-
// check that the ceil_mode is constant
300-
n->input(5)->node()->kind() == prim::Constant;
301-
case aten::unsqueeze:
302-
// check that the dimension argument is constant
303-
return !n->input(0)->type()->expect<TensorType>()->isSummarized() &&
224+
case aten::add:
225+
case aten::sub:
226+
case aten::mul:
227+
case aten::div:
228+
case aten::t:
229+
case aten::sigmoid:
230+
case aten::sin:
231+
case aten::cos:
232+
case aten::tan:
233+
case aten::sinh:
234+
case aten::cosh:
235+
case aten::tanh:
236+
case aten::asin:
237+
case aten::acos:
238+
case aten::atan:
239+
case aten::atan2:
240+
case aten::floor:
241+
case aten::fmod:
242+
case aten::ceil:
243+
case aten::trunc:
244+
case aten::sqrt:
245+
case aten::rsqrt:
246+
case aten::remainder:
247+
case aten::mm:
248+
case aten::min:
249+
case aten::max:
250+
case aten::type_as:
251+
case aten::ge:
252+
case aten::gt:
253+
case aten::lt:
254+
case aten::le:
255+
case aten::eq:
256+
case aten::ne:
257+
case aten::neg:
258+
case prim::ConstantChunk:
259+
case aten::size:
260+
case aten::abs:
261+
case aten::sign:
262+
case aten::pow:
263+
case aten::relu:
264+
case aten::threshold:
265+
case prim::AutogradAdd:
266+
case prim::AutogradZero:
267+
case aten::rand_like:
268+
case aten::erf:
269+
case aten::erfc:
270+
case aten::exp:
271+
case aten::expm1:
272+
case aten::log:
273+
case aten::log2:
274+
case aten::log10:
275+
case aten::frac:
276+
case aten::lerp:
277+
case aten::lgamma:
278+
case aten::reciprocal:
279+
case aten::addcmul:
280+
case aten::where:
281+
case aten::_cast_Float:
282+
case aten::_sigmoid_backward:
283+
case aten::_tanh_backward:
284+
case aten::__and__:
285+
case aten::__or__:
286+
case aten::__xor__:
287+
case aten::__lshift__:
288+
case aten::__rshift__:
289+
return checkInputs(n, no_exceptions, true);
290+
case aten::avg_pool2d:
291+
return checkInputs(n, no_exceptions, false);
292+
case aten::slice:
293+
return !n->input(0)->type()->expect<TensorType>()->isSummarized() &&
294+
// check that the dimension argument is constant
295+
n->input(1)->node()->kind() == prim::Constant &&
296+
// the start offset is constant
297+
n->input(2)->node()->kind() == prim::Constant &&
298+
// the end offset is constant
299+
n->input(3)->node()->kind() == prim::Constant &&
300+
// the stride is constant
301+
n->input(4)->node()->kind() == prim::Constant;
302+
case aten::max_pool1d:
303+
case aten::max_pool2d:
304+
case aten::max_pool3d:
305+
return !n->input(0)->type()->expect<TensorType>()->isSummarized() &&
306+
// check that the kernel size is constant
307+
n->input(1)->node()->kind() == prim::Constant &&
308+
// check that the stride is constant
309+
n->input(2)->node()->kind() == prim::Constant &&
310+
// check that the padding is constant
311+
n->input(3)->node()->kind() == prim::Constant &&
312+
// check that the dilation is constant
313+
n->input(4)->node()->kind() == prim::Constant &&
314+
// check that the ceil_mode is constant
315+
n->input(5)->node()->kind() == prim::Constant;
316+
case aten::unsqueeze:
317+
// check that the dimension argument is constant
318+
return !n->input(0)->type()->expect<TensorType>()->isSummarized() &&
304319
n->input(1)->node()->kind() == prim::Constant;
305-
case aten::cat:
306-
// check that the dimension argument is constant
307-
return n->input(1)->node()->kind() == prim::Constant &&
308-
n->input(0)->node()->kind() == prim::ListConstruct &&
309-
// no extra nodes in between aten::cat and prim::ListConstruct
310-
n->prev() == n->input(0)->node() &&
311-
// check the inputs to prim::ListConstruct (not aten::cat)
312-
checkInputs(n->input(0)->node(), no_exceptions, false);
313-
case aten::clamp:
314-
// the second and third args do not affect shapes
315-
return checkInputs(n, std::unordered_set<size_t>{1, 2}, false);
316-
// after some optimizations we might end up with two Guards back-to-back
317-
// in which case we can remove the one whose input is also prim::Guard
318-
case aten::_grad_sum_to_size:
319-
// skip checking size argument
320-
if (checkInputs(n, std::unordered_set<size_t>{1}, false)) {
321-
auto asize = n->input(1)->node();
322-
if (asize->kind() == prim::Constant) {
323-
return true;
324-
} else if (asize->matches("aten::size(Tensor self) -> int[]")) {
325-
// aten::size is effectively a constant
326-
if (asize->input()
327-
->type()
328-
->expect<TensorType>()
329-
->sizes()
330-
.concrete_sizes()) {
320+
case aten::cat:
321+
// check that the dimension argument is constant
322+
return n->input(1)->node()->kind() == prim::Constant &&
323+
n->input(0)->node()->kind() == prim::ListConstruct &&
324+
// no extra nodes in between aten::cat and prim::ListConstruct
325+
n->prev() == n->input(0)->node() &&
326+
// check the inputs to prim::ListConstruct (not aten::cat)
327+
checkInputs(n->input(0)->node(), no_exceptions, false);
328+
case aten::clamp:
329+
// the second and third args do not affect shapes
330+
return checkInputs(n, std::unordered_set<size_t>{1, 2}, false);
331+
// after some optimizations we might end up with two Guards back-to-back
332+
// in which case we can remove the one whose input is also prim::Guard
333+
case aten::_grad_sum_to_size:
334+
// skip checking size argument
335+
if (checkInputs(n, std::unordered_set<size_t>{1}, false)) {
336+
auto asize = n->input(1)->node();
337+
if (asize->kind() == prim::Constant) {
331338
return true;
339+
} else if (asize->matches("aten::size(Tensor self) -> int[]")) {
340+
// aten::size is effectively a constant
341+
if (asize->input()
342+
->type()
343+
->expect<TensorType>()
344+
->sizes()
345+
.concrete_sizes()) {
346+
return true;
347+
}
332348
}
333349
}
334-
}
335-
return false;
336-
337-
// this is checked by one of the tests in test_jit_fuser.py
338-
case prim::ListUnpack: {
339-
// check if the input is a constant chunk
340-
// used for LSTM fusions
341-
auto chunk = n->input(0)->node();
342-
if (chunk->kind() != aten::chunk) {
343350
return false;
351+
352+
// this is checked by one of the tests in test_jit_fuser.py
353+
case prim::ListUnpack: {
354+
// check if the input is a constant chunk
355+
// used for LSTM fusions
356+
auto chunk = n->input(0)->node();
357+
if (chunk->kind() != aten::chunk) {
358+
return false;
359+
}
360+
return checkInputs(chunk, no_exceptions, false);
344361
}
345-
return checkInputs(chunk, no_exceptions, false);
346-
}
347-
// this is checked by one of the tests in test_jit_fuser.py
348-
case aten::broadcast_tensors: {
349-
auto list_construct = n->input(0)->node();
350-
if (list_construct->kind() != prim::ListConstruct) {
351-
return false;
362+
// this is checked by one of the tests in test_jit_fuser.py
363+
case aten::broadcast_tensors: {
364+
auto list_construct = n->input(0)->node();
365+
if (list_construct->kind() != prim::ListConstruct) {
366+
return false;
367+
}
368+
return checkInputs(list_construct, no_exceptions, false);
352369
}
353-
return checkInputs(list_construct, no_exceptions, false);
354-
}
355-
case prim::Guard:
356-
case prim::GradOf:
357-
return true;
358-
default:
359-
GRAPH_DEBUG("cannot remove ", n->kind().toQualString());
360-
return false;
370+
case prim::Guard:
371+
case prim::GradOf:
372+
return true;
373+
default:
374+
GRAPH_DEBUG("cannot remove ", n->kind().toQualString());
375+
return false;
361376
}
362377
}
363378

@@ -366,7 +381,6 @@ struct GuardElimination {
366381
static std::unordered_set<Symbol> simple_ops_;
367382
};
368383

369-
370384
void EliminateRedundantGuards(std::shared_ptr<Graph> graph) {
371385
GuardElimination ge(std::move(graph));
372386
ge.run();

0 commit comments

Comments
 (0)