
Commit 6e13146

Mikhail Zolotukhin authored and facebook-github-bot committed
[TensorExpr] TensorExprKernel: don't do any compilation or lowering in run(). (#37948)
Summary:
Pull Request resolved: #37948

The input JIT graph has all the information we need to perform the entire compilation at construction time, so we don't need to postpone any steps until execution time. Also, from the graph we always know what device we will be executing on, and thus we don't need a CodeGen cache in TensorExprKernel - we always have one and only one CodeGen.

Test Plan: Imported from OSS

Reviewed By: protonu

Differential Revision: D21432145

Pulled By: ZolotukhinM

fbshipit-source-id: 8dc86b891713056b2c62f30170cd4a168912f027
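In other words, the commit replaces a lazy, per-device codegen cache with a single CodeGen built in the constructor. A minimal sketch of the before/after shape of that change, using a simplified stand-in CodeGen (the LazyKernel, EagerKernel, and compileFor names below are illustrative, not PyTorch APIs):

    #include <iostream>
    #include <memory>
    #include <unordered_map>

    // Hypothetical stand-ins for the real CodeGen and device types.
    struct CodeGen {
      explicit CodeGen(int device) : device_(device) {}
      void call() { std::cout << "running kernel on device " << device_ << "\n"; }
      int device_;
    };

    static std::unique_ptr<CodeGen> compileFor(int device) {
      std::cout << "compiling for device " << device << "\n";
      return std::make_unique<CodeGen>(device);
    }

    // Before: compilation is deferred to run(); a per-device cache is needed
    // because the device is only discovered from the runtime inputs.
    struct LazyKernel {
      std::unordered_map<int, std::unique_ptr<CodeGen>> codegenCache_;
      void run(int device) {
        if (!codegenCache_.count(device)) {
          codegenCache_.emplace(device, compileFor(device));  // first run pays for compilation
        }
        codegenCache_.at(device)->call();
      }
    };

    // After: the graph already names the device, so the constructor compiles
    // exactly once and run() only dispatches to the single CodeGen.
    struct EagerKernel {
      explicit EagerKernel(int deviceFromGraph)
          : device_(deviceFromGraph), codegen_(compileFor(deviceFromGraph)) {}
      void run() { codegen_->call(); }
      int device_;
      std::unique_ptr<CodeGen> codegen_;
    };

    int main() {
      LazyKernel lazy;
      lazy.run(/*device=*/0);  // compiles here, at first execution

      EagerKernel eager(/*deviceFromGraph=*/0);  // compiles here, at construction
      eager.run();
    }

The trade-off is that compilation cost moves from the first run() call to construction time, and each kernel is pinned to the single device recorded in its graph.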
1 parent eac54f1 commit 6e13146

3 files changed: 32 additions & 56 deletions


test/cpp/tensorexpr/test_kernel.cpp

Lines changed: 3 additions & 3 deletions
@@ -35,7 +35,7 @@ void testKernel_1() {
   auto ref = a * (a * b);
   TensorExprKernel k(graph);
   std::vector<at::Tensor> inputs = {a, b};
-  Stmt* s = k.getStmtForInputs(fmap<IValue>(inputs));
+  Stmt* s = k.getCodeGenStmt();
   // TODO: verify stmt
 
   std::vector<IValue> stack = fmap<IValue>(inputs);
@@ -65,7 +65,7 @@ void testKernel_2() {
   auto ref = a * (a * b);
   TensorExprKernel k(graph);
   std::vector<at::Tensor> inputs = {a, b};
-  Stmt* s = k.getStmtForInputs(fmap<IValue>(inputs));
+  Stmt* s = k.getCodeGenStmt();
   // TODO: verify stmt
 
   std::vector<IValue> stack = fmap<IValue>(inputs);
@@ -95,7 +95,7 @@ void testKernel_3() {
   auto ref = a * (a * b);
   TensorExprKernel k(graph);
   std::vector<at::Tensor> inputs = {a, b};
-  Stmt* s = k.getStmtForInputs(fmap<IValue>(inputs));
+  Stmt* s = k.getCodeGenStmt();
   // TODO: verify stmt
 
   std::vector<IValue> stack = fmap<IValue>(inputs);
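For callers, the practical effect is that the lowered statement exists as soon as the kernel is constructed; a minimal sketch reusing the names from the tests above:

    TensorExprKernel k(graph);     // the entire compilation happens here
    Stmt* s = k.getCodeGenStmt();  // no runtime inputs needed to inspect the IR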

torch/csrc/jit/tensorexpr/kernel.cpp

Lines changed: 23 additions & 43 deletions
@@ -1158,7 +1158,7 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) {
   return stmt;
 }
 
-std::string TensorExprKernel::getCodegenName(BackendType backendType) {
+std::string TensorExprKernel::getCodeGenName(BackendType backendType) {
   switch (backendType) {
     case kCudaCodeGen:
       return "cuda_codegen";
@@ -1272,10 +1272,11 @@ static void checkInputs(
 }
 
 at::Device TensorExprKernel::pickDeviceType(
-    const at::ArrayRef<IValue>& inputs) {
+    const at::ArrayRef<torch::jit::Value*>& inputs) {
   for (auto const& input : inputs) {
-    if (input.isTensor()) {
-      return input.toTensor().device();
+    auto tt = input->type()->cast<TensorType>();
+    if (tt && tt->device()) {
+      return *tt->device();
     }
   }
   throw std::runtime_error("No tensor inputs");
@@ -1390,6 +1391,16 @@ void TensorExprKernel::compile() {
     tensorOutputs_.emplace_back(tensors_.at(output->unique()));
     tensors_.erase(output->unique());
   }
+
+  device_ = pickDeviceType(graph_->inputs());
+  BackendType backendType = inferBackendTypeFromDevice(device_);
+  Stmt* stmt = generateStmt(backendType);
+
+  // Set up formal params (inputs, then outputs) for kernel.
+  std::vector<CodeGen::BufferArg> params = prepareBufferArgs();
+
+  // Generate code.
+  codegen_ = CreateCodeGen(getCodeGenName(backendType), stmt, params, device_);
 }
 
 TensorExprKernel::TensorExprKernel(const std::shared_ptr<Graph>& subgraph)
@@ -1426,8 +1437,7 @@ void TensorExprKernel::run(Stack& stack) {
 
 std::vector<CodeGen::CallArg> TensorExprKernel::prepareRunArgs(
     const at::ArrayRef<IValue>& inputs,
-    std::vector<at::Tensor>& outputs,
-    at::Device device) {
+    std::vector<at::Tensor>& outputs) {
   std::map<const Expr*, int32_t> varToSize;
 
   std::vector<CodeGen::CallArg> runArgs;
@@ -1468,57 +1478,27 @@ std::vector<CodeGen::CallArg> TensorExprKernel::prepareRunArgs(
     }
 
     outputs.push_back(at::empty(
-        tensorSize, c10::TensorOptions(tensorType(o)).device(device)));
+        tensorSize, c10::TensorOptions(tensorType(o)).device(device_)));
     runArgs.emplace_back(outputs.back().data_ptr());
   }
   return runArgs;
 }
 
-void TensorExprKernel::lowerToBackend(const at::ArrayRef<IValue>& inputs) {
-  checkInputs(inputs, inputTypes_);
-
-  at::Device device = pickDeviceType(inputs);
-  if (!codegenCache_.count(torch::get_hash(device))) {
-    BackendType backendType = inferBackendTypeFromDevice(device);
-    Stmt* stmt = generateStmt(backendType);
-
-    // Set up formal params (inputs, then outputs) for kernel.
-    std::vector<CodeGen::BufferArg> params = prepareBufferArgs();
-
-    // Generate code.
-    codegenCache_.emplace(
-        torch::get_hash(device),
-        CreateCodeGen(getCodegenName(backendType), stmt, params, device));
-  }
-}
-
-void TensorExprKernel::codegenRun(
-    at::Device device,
-    const std::vector<CodeGen::CallArg>& runArgs) {
-  codegenCache_.at(torch::get_hash(device))->call(runArgs);
-}
-
-Stmt* TensorExprKernel::getStmtForInputs(const at::ArrayRef<IValue>& inputs) {
-  lowerToBackend(inputs);
-  at::Device device = pickDeviceType(inputs);
-  return codegenCache_.at(torch::get_hash(device))->stmt();
+Stmt* TensorExprKernel::getCodeGenStmt() {
+  return codegen_->stmt();
 }
 
 void TensorExprKernel::runKernel(Stack& stack) {
   KernelScope kernelScope(&kernelArena_);
+
   // Set up arguments (inputs, then outputs) for kernel call.
   auto inputs = last(stack, nInputs_);
-
-  lowerToBackend(inputs);
-
-  at::Device device = pickDeviceType(inputs);
-
   std::vector<at::Tensor> outputs;
-  std::vector<CodeGen::CallArg> runArgs =
-      prepareRunArgs(inputs, outputs, device);
+
+  std::vector<CodeGen::CallArg> runArgs = prepareRunArgs(inputs, outputs);
 
   // Call the kernel.
-  codegenRun(device, runArgs);
+  codegen_->call(runArgs);
 
   // Update the stack.
   drop(stack, nInputs_);
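What makes the eager compilation possible is the reworked pickDeviceType: instead of asking runtime tensors for their device, it reads the device from the static TensorType annotations on the graph's Value* inputs, so it can run inside compile() before any inputs exist. The heart of that lookup, as it appears in the diff above:

    auto tt = input->type()->cast<TensorType>();
    if (tt && tt->device()) {
      return *tt->device();  // device is known statically from the graph
    }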

torch/csrc/jit/tensorexpr/kernel.h

Lines changed: 6 additions & 10 deletions
@@ -52,7 +52,7 @@ class TORCH_API TensorExprKernel {
     InterpreterState(code_).run(stack);
   }
 
-  Stmt* getStmtForInputs(const at::ArrayRef<IValue>& inputs);
+  Stmt* getCodeGenStmt();
 
  private:
   enum BackendType {
@@ -63,7 +63,6 @@
   };
 
   void compile();
-  void lowerToBackend(const at::ArrayRef<IValue>& inputs);
 
   void runKernel(Stack& stack);
 
@@ -160,17 +159,13 @@
   Stmt* generateStmt(BackendType backendType);
   std::vector<CodeGen::BufferArg> prepareBufferArgs();
 
-  std::string getCodegenName(BackendType backendType);
-  void codegenRun(
-      at::Device device,
-      const std::vector<CodeGen::CallArg>& runArgs);
+  std::string getCodeGenName(BackendType backendType);
 
   std::vector<CodeGen::CallArg> prepareRunArgs(
      const at::ArrayRef<IValue>& inputs,
-      std::vector<at::Tensor>& outputs,
-      at::Device device);
+      std::vector<at::Tensor>& outputs);
   BackendType inferBackendTypeFromDevice(at::Device device);
-  at::Device pickDeviceType(const at::ArrayRef<IValue>& inputs);
+  at::Device pickDeviceType(const at::ArrayRef<torch::jit::Value*>& inputs);
 
   void bindInput(const torch::jit::Value* input);
 
@@ -215,7 +210,8 @@
   std::vector<Tensor*> flatTensorOutputs_;
   std::unordered_map<int64_t, Tensor*> tensors_;
   std::unordered_map<int64_t, VarHandle> scalars_;
-  std::unordered_map<size_t, std::unique_ptr<CodeGen>> codegenCache_;
+  std::unique_ptr<CodeGen> codegen_;
+  at::Device device_ = at::kCPU;
   KernelArena kernelArena_;
   std::vector<TypePtr> inputTypes_;
   std::shared_ptr<Graph> graph_;