Skip to content

Commit a44824c

Browse files
Mikhail Zolotukhinfacebook-github-bot
authored andcommitted
[TensorExpr] Allow to enable/disable fallback mechanism thru an envvar PYTORCH_TENSOREXPR_FALLBACK. (#37971)
Summary: Pull Request resolved: #37971 Test Plan: Imported from OSS Reviewed By: protonu Differential Revision: D21444831 Pulled By: ZolotukhinM fbshipit-source-id: c75f58772a4730e8f40f05491f9e5afa4aa3ed30
1 parent 067f08c commit a44824c

4 files changed

Lines changed: 37 additions & 0 deletions

File tree

torch/csrc/jit/passes/tensorexpr_fuser.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,11 @@ Operation createTensorExprOp(const Node* node) {
348348
std::make_shared<tensorexpr::TensorExprKernel>(node->g(attr::Subgraph));
349349
return [kernel](Stack& stack) {
350350
RECORD_FUNCTION("TensorExpr", std::vector<c10::IValue>());
351+
if (!tensorexpr::fallbackAllowed()) {
352+
kernel->run(stack);
353+
return 0;
354+
}
355+
351356
try {
352357
kernel->run(stack);
353358
} catch (const std::runtime_error& e) {

torch/csrc/jit/python/init.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,8 @@ void initJITBindings(PyObject* module) {
495495
})
496496
.def("_jit_set_texpr_fuser_enabled", &setTensorExprFuserEnabled)
497497
.def("_jit_texpr_fuser_enabled", &tensorExprFuserEnabled)
498+
.def("_jit_texpr_fallback_allowed", &tensorexpr::fallbackAllowed)
499+
.def("_jit_texpr_set_fallback_allowed", &tensorexpr::setFallbackAllowed)
498500
.def(
499501
"_jit_pass_fuse_tensorexprs",
500502
[](std::shared_ptr<Graph>& g) { return FuseTensorExprs(g); })

torch/csrc/jit/tensorexpr/kernel.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,24 @@ namespace tensorexpr {
1717
static int te_cuda_pointwise_loop_levels = -1;
1818
static int te_cuda_pointwise_block_count = -1;
1919
static int te_cuda_pointwise_block_size = -1;
20+
static bool fallback_allowed = true;
21+
22+
bool setFallbackAllowed(bool value) {
23+
bool old_value = fallback_allowed;
24+
fallback_allowed = value;
25+
return old_value;
26+
}
27+
28+
bool fallbackAllowed() {
29+
static const char* enable_c_str = std::getenv("PYTORCH_TENSOREXPR_FALLBACK");
30+
if (!enable_c_str) {
31+
return fallback_allowed;
32+
}
33+
if (std::string(enable_c_str) == "0") {
34+
return false;
35+
}
36+
return true;
37+
}
2038

2139
int& getTECudaPointwiseLoopLevels() {
2240
return te_cuda_pointwise_loop_levels;
@@ -1366,6 +1384,11 @@ void TensorExprKernel::compile() {
13661384

13671385
TensorExprKernel::TensorExprKernel(const std::shared_ptr<Graph>& subgraph)
13681386
: graph_(subgraph), code_(subgraph, "") {
1387+
if (!fallbackAllowed()) {
1388+
compile();
1389+
return;
1390+
}
1391+
13691392
try {
13701393
compile();
13711394
} catch (...) {
@@ -1374,6 +1397,11 @@ TensorExprKernel::TensorExprKernel(const std::shared_ptr<Graph>& subgraph)
13741397
}
13751398

13761399
void TensorExprKernel::run(Stack& stack) {
1400+
if (!fallbackAllowed()) {
1401+
runKernel(stack);
1402+
return;
1403+
}
1404+
13771405
if (fallback_) {
13781406
fallback(stack);
13791407
return;

torch/csrc/jit/tensorexpr/kernel.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ class TORCH_API TensorExprKernel {
228228
TORCH_API int& getTECudaPointwiseLoopLevels();
229229
TORCH_API int& getTECudaPointwiseBlockCount();
230230
TORCH_API int& getTECudaPointwiseBlockSize();
231+
TORCH_API bool fallbackAllowed();
232+
TORCH_API bool setFallbackAllowed(bool value);
231233

232234
} // namespace tensorexpr
233235
} // namespace jit

0 commit comments

Comments
 (0)