File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -222,6 +222,7 @@ struct GuardElimination {
222222 const static auto no_exceptions = std::unordered_set<size_t >{};
223223 switch (n->kind ()) {
224224 case aten::add:
225+ case aten::add_:
225226 case aten::sub:
226227 case aten::mul:
227228 case aten::div:
@@ -279,16 +280,32 @@ struct GuardElimination {
279280 case aten::addcmul:
280281 case aten::where:
281282 case aten::_cast_Float:
283+ case aten::_cast_Long:
282284 case aten::_sigmoid_backward:
283285 case aten::_tanh_backward:
284286 case aten::__and__:
285287 case aten::__or__:
286288 case aten::__xor__:
287289 case aten::__lshift__:
288290 case aten::__rshift__:
291+ case aten::bitwise_not:
292+ case aten::bitwise_and:
293+ case aten::bitwise_or:
294+ case aten::bitwise_xor:
289295 return checkInputs (n, no_exceptions, true );
296+ case aten::softmax:
297+ return checkInputs (n, std::unordered_set<size_t >{1 }, true );
298+ case aten::multinomial:
299+ return checkInputs (n, std::unordered_set<size_t >{2 , 3 }, false );
300+ case aten::flatten:
301+ case aten::argmax:
302+ case aten::squeeze:
290303 case aten::avg_pool2d:
291304 return checkInputs (n, no_exceptions, false );
305+ case aten::conv1d:
306+ case aten::conv2d:
307+ case aten::conv3d:
308+ return checkInputs (n, std::unordered_set<size_t >{2 , 6 }, false );
292309 case aten::slice:
293310 return !n->input (0 )->type ()->expect <TensorType>()->isSummarized () &&
294311 // check that the dimension argument is constant
You can’t perform that action at this time.
0 commit comments