File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -348,6 +348,11 @@ Operation createTensorExprOp(const Node* node) {
348348 std::make_shared<tensorexpr::TensorExprKernel>(node->g (attr::Subgraph));
349349 return [kernel](Stack& stack) {
350350 RECORD_FUNCTION (" TensorExpr" , std::vector<c10::IValue>());
351+ if (!tensorexpr::fallbackAllowed ()) {
352+ kernel->run (stack);
353+ return 0 ;
354+ }
355+
351356 try {
352357 kernel->run (stack);
353358 } catch (const std::runtime_error& e) {
Original file line number Diff line number Diff line change @@ -495,6 +495,8 @@ void initJITBindings(PyObject* module) {
495495 })
496496 .def (" _jit_set_texpr_fuser_enabled" , &setTensorExprFuserEnabled)
497497 .def (" _jit_texpr_fuser_enabled" , &tensorExprFuserEnabled)
498+ .def (" _jit_texpr_fallback_allowed" , &tensorexpr::fallbackAllowed)
499+ .def (" _jit_texpr_set_fallback_allowed" , &tensorexpr::setFallbackAllowed)
498500 .def (
499501 " _jit_pass_fuse_tensorexprs" ,
500502 [](std::shared_ptr<Graph>& g) { return FuseTensorExprs (g); })
Original file line number Diff line number Diff line change @@ -17,6 +17,24 @@ namespace tensorexpr {
1717static int te_cuda_pointwise_loop_levels = -1 ;
1818static int te_cuda_pointwise_block_count = -1 ;
1919static int te_cuda_pointwise_block_size = -1 ;
20+ static bool fallback_allowed = true ;
21+
22+ bool setFallbackAllowed (bool value) {
23+ bool old_value = fallback_allowed;
24+ fallback_allowed = value;
25+ return old_value;
26+ }
27+
28+ bool fallbackAllowed () {
29+ static const char * enable_c_str = std::getenv (" PYTORCH_TENSOREXPR_FALLBACK" );
30+ if (!enable_c_str) {
31+ return fallback_allowed;
32+ }
33+ if (std::string (enable_c_str) == " 0" ) {
34+ return false ;
35+ }
36+ return true ;
37+ }
2038
2139int & getTECudaPointwiseLoopLevels () {
2240 return te_cuda_pointwise_loop_levels;
@@ -1366,6 +1384,11 @@ void TensorExprKernel::compile() {
13661384
13671385TensorExprKernel::TensorExprKernel (const std::shared_ptr<Graph>& subgraph)
13681386 : graph_(subgraph), code_(subgraph, " " ) {
1387+ if (!fallbackAllowed ()) {
1388+ compile ();
1389+ return ;
1390+ }
1391+
13691392 try {
13701393 compile ();
13711394 } catch (...) {
@@ -1374,6 +1397,11 @@ TensorExprKernel::TensorExprKernel(const std::shared_ptr<Graph>& subgraph)
13741397}
13751398
13761399void TensorExprKernel::run (Stack& stack) {
1400+ if (!fallbackAllowed ()) {
1401+ runKernel (stack);
1402+ return ;
1403+ }
1404+
13771405 if (fallback_) {
13781406 fallback (stack);
13791407 return ;
Original file line number Diff line number Diff line change @@ -228,6 +228,8 @@ class TORCH_API TensorExprKernel {
228228TORCH_API int & getTECudaPointwiseLoopLevels ();
229229TORCH_API int & getTECudaPointwiseBlockCount ();
230230TORCH_API int & getTECudaPointwiseBlockSize ();
231+ TORCH_API bool fallbackAllowed ();
232+ TORCH_API bool setFallbackAllowed (bool value);
231233
232234} // namespace tensorexpr
233235} // namespace jit
You can’t perform that action at this time.
0 commit comments