Skip to content

Commit a31cb90

Browse files
committed
Add guard elimination cases for operators encountered on an RL workload.
1 parent fd57e09 commit a31cb90

1 file changed

Lines changed: 17 additions & 0 deletions

File tree

torch/csrc/jit/passes/guard_elimination.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ struct GuardElimination {
222222
const static auto no_exceptions = std::unordered_set<size_t>{};
223223
switch (n->kind()) {
224224
case aten::add:
225+
case aten::add_:
225226
case aten::sub:
226227
case aten::mul:
227228
case aten::div:
@@ -279,16 +280,32 @@ struct GuardElimination {
279280
case aten::addcmul:
280281
case aten::where:
281282
case aten::_cast_Float:
283+
case aten::_cast_Long:
282284
case aten::_sigmoid_backward:
283285
case aten::_tanh_backward:
284286
case aten::__and__:
285287
case aten::__or__:
286288
case aten::__xor__:
287289
case aten::__lshift__:
288290
case aten::__rshift__:
291+
case aten::bitwise_not:
292+
case aten::bitwise_and:
293+
case aten::bitwise_or:
294+
case aten::bitwise_xor:
289295
return checkInputs(n, no_exceptions, true);
296+
case aten::softmax:
297+
return checkInputs(n, std::unordered_set<size_t>{1}, true);
298+
case aten::multinomial:
299+
return checkInputs(n, std::unordered_set<size_t>{2, 3}, false);
300+
case aten::flatten:
301+
case aten::argmax:
302+
case aten::squeeze:
290303
case aten::avg_pool2d:
291304
return checkInputs(n, no_exceptions, false);
305+
case aten::conv1d:
306+
case aten::conv2d:
307+
case aten::conv3d:
308+
return checkInputs(n, std::unordered_set<size_t>{2, 6}, false);
292309
case aten::slice:
293310
return !n->input(0)->type()->expect<TensorType>()->isSummarized() &&
294311
// check that the dimension argument is constant

0 commit comments

Comments
 (0)