#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/guard_elimination.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <memory>
#include <unordered_set>
99
1010namespace torch {
1111namespace jit {
1212
1313struct GuardElimination {
1414 GuardElimination (std::shared_ptr<Graph> graph)
15- : graph_(std::move(graph)),
16- aliasDb_ (std::make_unique<AliasDb>(graph_)) {}
15+ : graph_(std::move(graph)), aliasDb_(std::make_unique<AliasDb>(graph_)) {}
1716
1817 void run () {
1918 const size_t MAX_ATTEMPTS = 5 ;
@@ -123,8 +122,11 @@ struct GuardElimination {
123122 auto it = guard;
124123 while (it != output) {
125124 if (it->kind () != prim::Guard && it->kind () != prim::Constant) {
126- GRAPH_DEBUG (" found an unexpected node " , *it,
127- " while trying to eliminate " , *guard);
125+ GRAPH_DEBUG (
126+ " found an unexpected node " ,
127+ *it,
128+ " while trying to eliminate " ,
129+ *guard);
128130 return false ;
129131 }
130132 it = it->prev ();
@@ -160,7 +162,10 @@ struct GuardElimination {
160162 // `checkInputs` check the invariants specified in `removableGuard`
161163 // on inputs to `n`. The invariants must hold, or an input must
162164 // be a `prim::Constant` or be included as an exception in `except`
163- bool checkInputs (Node *n, const std::unordered_set<size_t > &except, bool allow_numbers) {
165+ bool checkInputs (
166+ Node* n,
167+ const std::unordered_set<size_t >& except,
168+ bool allow_numbers) {
164169 bool all_inputs_guarded = true ;
165170 size_t i = 0 ;
166171 for (auto input : n->inputs ()) {
@@ -173,8 +178,11 @@ struct GuardElimination {
173178 input->node ()->kind () != prim::Guard ||
174179 input->type ()->expect <TensorType>());
175180 } else {
176- GRAPH_DEBUG (" input " , input->debugName (), " isn't guarded, type " ,
177- *input->type ());
181+ GRAPH_DEBUG (
182+ " input " ,
183+ input->debugName (),
184+ " isn't guarded, type " ,
185+ *input->type ());
178186 all_inputs_guarded = false ;
179187 break ;
180188 }
@@ -183,7 +191,7 @@ struct GuardElimination {
183191 return all_inputs_guarded;
184192 }
185193
186- private:
194+ private:
187195 // `removableGuard` relies on the properties checked by `isSummarized()`
188196 // and passes shouldn't insert nodes between a guard and its uses that
189197 // may alter those properties.
@@ -210,154 +218,161 @@ struct GuardElimination {
210218 // Guards can be removed if all inputs are guarded and `isSummarized()`
211219 // returns
212220 // false or inputs are `prim::Constant`
213- bool removableGuard (Node *n) {
214-
221+ bool removableGuard (Node* n) {
215222 const static auto no_exceptions = std::unordered_set<size_t >{};
216223 switch (n->kind ()) {
217- case aten::add:
218- case aten::sub:
219- case aten::mul:
220- case aten::div:
221- case aten::t:
222- case aten::sigmoid:
223- case aten::sin:
224- case aten::cos:
225- case aten::tan:
226- case aten::sinh:
227- case aten::cosh:
228- case aten::tanh:
229- case aten::asin:
230- case aten::acos:
231- case aten::atan:
232- case aten::atan2:
233- case aten::floor:
234- case aten::fmod:
235- case aten::ceil:
236- case aten::trunc:
237- case aten::sqrt:
238- case aten::rsqrt:
239- case aten::remainder:
240- case aten::mm:
241- case aten::min:
242- case aten::max:
243- case aten::type_as:
244- case aten::ge:
245- case aten::gt:
246- case aten::lt:
247- case aten::le:
248- case aten::eq:
249- case aten::ne:
250- case aten::neg:
251- case prim::ConstantChunk:
252- case aten::size:
253- case aten::abs:
254- case aten::sign:
255- case aten::pow:
256- case aten::relu:
257- case aten::threshold:
258- case prim::AutogradAdd:
259- case prim::AutogradZero:
260- case aten::rand_like:
261- case aten::erf:
262- case aten::erfc:
263- case aten::exp:
264- case aten::expm1:
265- case aten::log:
266- case aten::log2:
267- case aten::log10:
268- case aten::frac:
269- case aten::lerp:
270- case aten::lgamma:
271- case aten::reciprocal:
272- case aten::addcmul:
273- case aten::where:
274- return checkInputs (n, no_exceptions, true );
275- case aten::avg_pool2d:
276- return checkInputs (n, no_exceptions, false );
277- case aten::slice:
278- return !n->input (0 )->type ()->expect <TensorType>()->isSummarized () &&
279- // check that the dimension argument is constant
280- n->input (1 )->node ()->kind () == prim::Constant &&
281- // the start offset is constant
282- n->input (2 )->node ()->kind () == prim::Constant &&
283- // the end offset is constant
284- n->input (3 )->node ()->kind () == prim::Constant &&
285- // the stride is constant
286- n->input (4 )->node ()->kind () == prim::Constant;
287- case aten::max_pool1d:
288- case aten::max_pool2d:
289- case aten::max_pool3d:
290- return !n->input (0 )->type ()->expect <TensorType>()->isSummarized () &&
291- // check that the kernel size is constant
292- n->input (1 )->node ()->kind () == prim::Constant &&
293- // check that the stride is constant
294- n->input (2 )->node ()->kind () == prim::Constant &&
295- // check that the padding is constant
296- n->input (3 )->node ()->kind () == prim::Constant &&
297- // check that the dilation is constant
298- n->input (4 )->node ()->kind () == prim::Constant &&
299- // check that the ceil_mode is constant
300- n->input (5 )->node ()->kind () == prim::Constant;
301- case aten::unsqueeze:
302- // check that the dimension argument is constant
303- return !n->input (0 )->type ()->expect <TensorType>()->isSummarized () &&
224+ case aten::add:
225+ case aten::sub:
226+ case aten::mul:
227+ case aten::div:
228+ case aten::t:
229+ case aten::sigmoid:
230+ case aten::sin:
231+ case aten::cos:
232+ case aten::tan:
233+ case aten::sinh:
234+ case aten::cosh:
235+ case aten::tanh:
236+ case aten::asin:
237+ case aten::acos:
238+ case aten::atan:
239+ case aten::atan2:
240+ case aten::floor:
241+ case aten::fmod:
242+ case aten::ceil:
243+ case aten::trunc:
244+ case aten::sqrt:
245+ case aten::rsqrt:
246+ case aten::remainder:
247+ case aten::mm:
248+ case aten::min:
249+ case aten::max:
250+ case aten::type_as:
251+ case aten::ge:
252+ case aten::gt:
253+ case aten::lt:
254+ case aten::le:
255+ case aten::eq:
256+ case aten::ne:
257+ case aten::neg:
258+ case prim::ConstantChunk:
259+ case aten::size:
260+ case aten::abs:
261+ case aten::sign:
262+ case aten::pow:
263+ case aten::relu:
264+ case aten::threshold:
265+ case prim::AutogradAdd:
266+ case prim::AutogradZero:
267+ case aten::rand_like:
268+ case aten::erf:
269+ case aten::erfc:
270+ case aten::exp:
271+ case aten::expm1:
272+ case aten::log:
273+ case aten::log2:
274+ case aten::log10:
275+ case aten::frac:
276+ case aten::lerp:
277+ case aten::lgamma:
278+ case aten::reciprocal:
279+ case aten::addcmul:
280+ case aten::where:
281+ case aten::_cast_Float:
282+ case aten::_sigmoid_backward:
283+ case aten::_tanh_backward:
284+ case aten::__and__:
285+ case aten::__or__:
286+ case aten::__xor__:
287+ case aten::__lshift__:
288+ case aten::__rshift__:
289+ return checkInputs (n, no_exceptions, true );
290+ case aten::avg_pool2d:
291+ return checkInputs (n, no_exceptions, false );
292+ case aten::slice:
293+ return !n->input (0 )->type ()->expect <TensorType>()->isSummarized () &&
294+ // check that the dimension argument is constant
295+ n->input (1 )->node ()->kind () == prim::Constant &&
296+ // the start offset is constant
297+ n->input (2 )->node ()->kind () == prim::Constant &&
298+ // the end offset is constant
299+ n->input (3 )->node ()->kind () == prim::Constant &&
300+ // the stride is constant
301+ n->input (4 )->node ()->kind () == prim::Constant;
302+ case aten::max_pool1d:
303+ case aten::max_pool2d:
304+ case aten::max_pool3d:
305+ return !n->input (0 )->type ()->expect <TensorType>()->isSummarized () &&
306+ // check that the kernel size is constant
307+ n->input (1 )->node ()->kind () == prim::Constant &&
308+ // check that the stride is constant
309+ n->input (2 )->node ()->kind () == prim::Constant &&
310+ // check that the padding is constant
311+ n->input (3 )->node ()->kind () == prim::Constant &&
312+ // check that the dilation is constant
313+ n->input (4 )->node ()->kind () == prim::Constant &&
314+ // check that the ceil_mode is constant
315+ n->input (5 )->node ()->kind () == prim::Constant;
316+ case aten::unsqueeze:
317+ // check that the dimension argument is constant
318+ return !n->input (0 )->type ()->expect <TensorType>()->isSummarized () &&
304319 n->input (1 )->node ()->kind () == prim::Constant;
305- case aten::cat:
306- // check that the dimension argument is constant
307- return n->input (1 )->node ()->kind () == prim::Constant &&
308- n->input (0 )->node ()->kind () == prim::ListConstruct &&
309- // no extra nodes in between aten::cat and prim::ListConstruct
310- n->prev () == n->input (0 )->node () &&
311- // check the inputs to prim::ListConstruct (not aten::cat)
312- checkInputs (n->input (0 )->node (), no_exceptions, false );
313- case aten::clamp:
314- // the second and third args do not affect shapes
315- return checkInputs (n, std::unordered_set<size_t >{1 , 2 }, false );
316- // after some optimizations we might end up with two Guards back-to-back
317- // which case we can remove the one whose input is also prim::Guard
318- case aten::_grad_sum_to_size:
319- // skip checking size argument
320- if (checkInputs (n, std::unordered_set<size_t >{1 }, false )) {
321- auto asize = n->input (1 )->node ();
322- if (asize->kind () == prim::Constant) {
323- return true ;
324- } else if (asize->matches (" aten::size(Tensor self) -> int[]" )) {
325- // aten::size is effectively a constant
326- if (asize->input ()
327- ->type ()
328- ->expect <TensorType>()
329- ->sizes ()
330- .concrete_sizes ()) {
320+ case aten::cat:
321+ // check that the dimension argument is constant
322+ return n->input (1 )->node ()->kind () == prim::Constant &&
323+ n->input (0 )->node ()->kind () == prim::ListConstruct &&
324+ // no extra nodes in between aten::cat and prim::ListConstruct
325+ n->prev () == n->input (0 )->node () &&
326+ // check the inputs to prim::ListConstruct (not aten::cat)
327+ checkInputs (n->input (0 )->node (), no_exceptions, false );
328+ case aten::clamp:
329+ // the second and third args do not affect shapes
330+ return checkInputs (n, std::unordered_set<size_t >{1 , 2 }, false );
331+ // after some optimizations we might end up with two Guards back-to-back
332+ // which case we can remove the one whose input is also prim::Guard
333+ case aten::_grad_sum_to_size:
334+ // skip checking size argument
335+ if (checkInputs (n, std::unordered_set<size_t >{1 }, false )) {
336+ auto asize = n->input (1 )->node ();
337+ if (asize->kind () == prim::Constant) {
331338 return true ;
339+ } else if (asize->matches (" aten::size(Tensor self) -> int[]" )) {
340+ // aten::size is effectively a constant
341+ if (asize->input ()
342+ ->type ()
343+ ->expect <TensorType>()
344+ ->sizes ()
345+ .concrete_sizes ()) {
346+ return true ;
347+ }
332348 }
333349 }
334- }
335- return false ;
336-
337- // this is checked by one of the tests in test_jit_fuser.py
338- case prim::ListUnpack: {
339- // check if the input is a constant chunk
340- // used for LSTM fusions
341- auto chunk = n->input (0 )->node ();
342- if (chunk->kind () != aten::chunk) {
343350 return false ;
351+
352+ // this is checked by one of the tests in test_jit_fuser.py
353+ case prim::ListUnpack: {
354+ // check if the input is a constant chunk
355+ // used for LSTM fusions
356+ auto chunk = n->input (0 )->node ();
357+ if (chunk->kind () != aten::chunk) {
358+ return false ;
359+ }
360+ return checkInputs (chunk, no_exceptions, false );
344361 }
345- return checkInputs (chunk, no_exceptions, false );
346- }
347- // this is checked by one of the tests in test_jit_fuser.py
348- case aten::broadcast_tensors: {
349- auto list_construct = n-> input ( 0 )-> node () ;
350- if (list_construct-> kind () != prim::ListConstruct) {
351- return false ;
362+ // this is checked by one of the tests in test_jit_fuser.py
363+ case aten::broadcast_tensors: {
364+ auto list_construct = n-> input ( 0 )-> node ();
365+ if (list_construct-> kind () != prim::ListConstruct) {
366+ return false ;
367+ }
368+ return checkInputs (list_construct, no_exceptions, false ) ;
352369 }
353- return checkInputs (list_construct, no_exceptions, false );
354- }
355- case prim::Guard:
356- case prim::GradOf:
357- return true ;
358- default :
359- GRAPH_DEBUG (" cannot remove " , n->kind ().toQualString ());
360- return false ;
370+ case prim::Guard:
371+ case prim::GradOf:
372+ return true ;
373+ default :
374+ GRAPH_DEBUG (" cannot remove " , n->kind ().toQualString ());
375+ return false ;
361376 }
362377 }
363378
@@ -366,7 +381,6 @@ struct GuardElimination {
366381 static std::unordered_set<Symbol> simple_ops_;
367382};
368383
369-
370384void EliminateRedundantGuards (std::shared_ptr<Graph> graph) {
371385 GuardElimination ge (std::move (graph));
372386 ge.run ();
0 commit comments