Skip to content

Commit 20ba29d

Browse files
resistor and facebook-github-bot
authored and committed
Add support for reductions on CPU in tensorexpr (#37333)
Summary: Pull Request resolved: #37333 Differential Revision: D21290289 Pulled By: resistor fbshipit-source-id: ebba11f7af9e22b48c47e2eefb9497fa77acd17d
1 parent d3d10cc commit 20ba29d

4 files changed

Lines changed: 211 additions & 12 deletions

File tree

test/cpp/tensorexpr/test_llvm.cpp

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,6 +1201,152 @@ void testLLVMEliminatedStmt() {
12011201
cg.call({aData, cData});
12021202
}
12031203

1204+
void testLLVMSimpleReduction() {
1205+
KernelScope kernel_scope;
1206+
1207+
int M = 128;
1208+
int N = 64;
1209+
const int kTotalSize = M * N;
1210+
1211+
Buffer a("a", kFloat, {1, M, N});
1212+
1213+
// TODO: why doesn't implicit vector<DimArg> work?
1214+
std::vector<DimArg> axis = {DimArg(1)};
1215+
std::vector<DimArg> reduce_axis = {DimArg(M), DimArg(N)};
1216+
Tensor* b = Reduce("sum", axis, Sum(), a, reduce_axis);
1217+
LoopNest loop({b});
1218+
1219+
loop.prepareForCodegen();
1220+
Stmt* s = loop.root_stmt();
1221+
s = IRSimplifier::simplify(s);
1222+
1223+
LLVMCodeGen cg(s, {a, b});
1224+
1225+
PaddedBuffer<float> a_v(1, M, N, "a_v");
1226+
PaddedBuffer<float> b_v(1, "b_v");
1227+
PaddedBuffer<float> b_ref(1, "b_ref");
1228+
1229+
b_ref(0) = 0;
1230+
for (int i = 0; i < M; i++) {
1231+
for (int j = 0; j < N; j++) {
1232+
int v = i + j;
1233+
a_v(0, i, j) = v;
1234+
b_ref(0) += v;
1235+
}
1236+
}
1237+
1238+
cg.call({a_v, b_v});
1239+
1240+
ExpectAllNear(b_v, b_ref, 1e-5);
1241+
}
1242+
1243+
void testLLVMRFactorReduction() {
1244+
KernelScope kernel_scope;
1245+
1246+
int M = 128;
1247+
int N = 64;
1248+
const int kTotalSize = M * N;
1249+
1250+
Buffer a("a", kFloat, {1, M, N});
1251+
1252+
// TODO: why doesn't implicit vector<DimArg> work?
1253+
std::vector<DimArg> axis = {DimArg(1)};
1254+
std::vector<DimArg> reduce_axis = {DimArg(M), DimArg(N)};
1255+
Tensor* b = Reduce("sum", axis, Sum(), a, reduce_axis);
1256+
LoopNest loop({b});
1257+
1258+
std::vector<For*> loops = loop.getLoopStmtsFor(b);
1259+
For* loop_m = loops.at(1);
1260+
For* loop_n = loops.at(2);
1261+
loop.reorderAxis(b, loop_m, loop_n);
1262+
1263+
loops = loop.getLoopStmtsFor(b);
1264+
loop_m = loops.at(2);
1265+
loop_n = loops.at(1);
1266+
loop.rfactor(b->body(), loop_n->var(), loop_n->body());
1267+
1268+
loop.prepareForCodegen();
1269+
Stmt* s = loop.root_stmt();
1270+
s = IRSimplifier::simplify(s);
1271+
1272+
LLVMCodeGen cg(s, {a, b});
1273+
1274+
PaddedBuffer<float> a_v(1, M, N, "a_v");
1275+
PaddedBuffer<float> b_v(1, "b_v");
1276+
PaddedBuffer<float> b_ref(1, "b_ref");
1277+
1278+
b_ref(0) = 0;
1279+
for (int i = 0; i < M; i++) {
1280+
for (int j = 0; j < N; j++) {
1281+
int v = i + j;
1282+
a_v(0, i, j) = v;
1283+
b_ref(0) += v;
1284+
}
1285+
}
1286+
1287+
cg.call({a_v, b_v});
1288+
1289+
ExpectAllNear(b_v, b_ref, 1e-5);
1290+
}
1291+
1292+
void testLLVMRFactorVectorizedReduction() {
1293+
KernelScope kernel_scope;
1294+
1295+
int M = 128;
1296+
int N = 64;
1297+
const int kTotalSize = M * N;
1298+
1299+
Buffer a("a", kFloat, {1, M, N});
1300+
1301+
// TODO: why doesn't implicit vector<DimArg> work?
1302+
std::vector<DimArg> axis = {DimArg(1)};
1303+
std::vector<DimArg> reduce_axis = {DimArg(M), DimArg(N)};
1304+
Tensor* b = Reduce("sum", axis, Sum(), a, reduce_axis);
1305+
LoopNest loopnest({b});
1306+
std::vector<For*> loops = loopnest.getLoopStmtsFor(b);
1307+
For* loop_k = loops.at(0);
1308+
For* loop_m = loops.at(1);
1309+
For* loop_n = loops.at(2);
1310+
loopnest.reorderAxis(b, loop_n, loop_m);
1311+
loops = loopnest.getLoopStmtsFor(b);
1312+
loop_k = loops.at(0);
1313+
loop_n = loops.at(1);
1314+
loop_m = loops.at(2);
1315+
// Case-III reductions
1316+
loopnest.rfactor(b->body(), loop_n->var());
1317+
loopnest.prepareForCodegen();
1318+
Stmt* s = loopnest.root_stmt();
1319+
s = IRSimplifier::simplify(s);
1320+
1321+
Block* root_block = dynamic_cast<Block*>(s);
1322+
auto stmt_list = root_block->stmts();
1323+
auto I = stmt_list.begin();
1324+
++I;
1325+
1326+
For* outer_loop = dynamic_cast<For*>(*I);
1327+
loopnest.vectorize(outer_loop);
1328+
1329+
s = IRSimplifier::simplify(s);
1330+
LLVMCodeGen cg(s, {a, b});
1331+
1332+
PaddedBuffer<float> a_v(1, M, N, "a_v");
1333+
PaddedBuffer<float> b_v(1, "b_v");
1334+
PaddedBuffer<float> b_ref(1, "b_ref");
1335+
1336+
b_ref(0) = 0;
1337+
for (int i = 0; i < M; i++) {
1338+
for (int j = 0; j < N; j++) {
1339+
int v = i + j;
1340+
a_v(0, i, j) = v;
1341+
b_ref(0) += v;
1342+
}
1343+
}
1344+
1345+
cg.call({a_v, b_v});
1346+
1347+
ExpectAllNear(b_v, b_ref, 1e-5);
1348+
}
1349+
12041350
} // namespace jit
12051351
} // namespace torch
12061352

test/cpp/tensorexpr/tests.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,10 @@ namespace jit {
296296
_(LLVMEmptyStmt) \
297297
_(LLVMEliminatedStmt) \
298298
_(LLVMIfThenElseTest) \
299-
_(LLVMVectorizerLoadStoreTest)
299+
_(LLVMVectorizerLoadStoreTest) \
300+
_(LLVMSimpleReduction) \
301+
_(LLVMRFactorReduction) \
302+
_(LLVMRFactorVectorizedReduction)
300303

301304
#define TH_FORALL_TENSOREXPR_TESTS_CUDA(_) \
302305
_(CudaTestVectorAdd01) \

torch/csrc/jit/tensorexpr/llvm_codegen.cpp

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class LLVMCodeGenImpl : public IRVisitor {
4242
llvm::BasicBlock* bb_;
4343
llvm::Value* value_{nullptr};
4444
llvm::JITTargetAddress kernelAddress_;
45+
std::unique_ptr<void* []> argv_ { nullptr };
4546

4647
#define LLVM_TYPE_DECLARE(_1, Name) llvm::Type* Name##Ty_;
4748
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, LLVM_TYPE_DECLARE);
@@ -66,6 +67,7 @@ class LLVMCodeGenImpl : public IRVisitor {
6667
~LLVMCodeGenImpl() = default;
6768

6869
llvm::JITTargetAddress getKernelAddress() const;
70+
void** getArgvAddress() const;
6971

7072
void visit(const Add* v) override;
7173
void visit(const Sub* v) override;
@@ -184,15 +186,16 @@ static void* argToPtr(
184186
}
185187

186188
void LLVMCodeGen::call(const std::vector<CallArg>& args) {
187-
if (args.size() != buffer_args().size()) {
189+
const auto& buf_args = buffer_args();
190+
if (args.size() != buf_args.size()) {
188191
throw malformed_input("wrong number of args in call");
189192
}
190193

191-
std::vector<void*> argv;
192-
for (size_t i = 0; i < buffer_args().size(); i++) {
193-
auto const& bufferArg = buffer_args()[i];
194+
void** argv = impl_->getArgvAddress();
195+
for (size_t i = 0, e = buf_args.size(); i < e; i++) {
196+
auto const& bufferArg = buf_args[i];
194197
auto const& callArg = args[i];
195-
argv.push_back(argToPtr(bufferArg, callArg));
198+
argv[i] = argToPtr(bufferArg, callArg);
196199
}
197200
value<float>(argv);
198201
USE_TRIGGER(llvm_codegen_executed);
@@ -206,6 +209,10 @@ llvm::JITTargetAddress LLVMCodeGenImpl::getKernelAddress() const {
206209
return kernelAddress_;
207210
}
208211

212+
// Returns the pre-allocated argument array (argv_) that callers fill with
// per-buffer pointers before invoking the generated kernel.
void** LLVMCodeGenImpl::getArgvAddress() const {
  return argv_.get();
}
215+
209216
LLVMCodeGenImpl::LLVMCodeGenImpl(
210217
Stmt* stmt,
211218
const std::vector<CodeGen::BufferArg>& args,
@@ -261,6 +268,7 @@ LLVMCodeGenImpl::LLVMCodeGenImpl(
261268
llvm::orc::ThreadSafeModule(std::move(module_), context_)));
262269
auto sym = jit_->findSymbol("wrapper");
263270
kernelAddress_ = cantFail(sym.getAddress());
271+
argv_ = std::make_unique<void*[]>(params.size());
264272

265273
USE_TRIGGER(llvm_codegen_created);
266274
}
@@ -919,7 +927,11 @@ void LLVMCodeGenImpl::visit(const For* v) {
919927
// Set up phi node for index variable.
920928
auto idx = irb_.CreatePHI(IntTy_, 2);
921929
idx->addIncoming(start, preheader);
922-
varToVal_.emplace(v->var(), idx);
930+
if (!varToVal_.count(v->var())) {
931+
varToVal_.emplace(v->var(), idx);
932+
} else {
933+
throw std::runtime_error("var should not exist before");
934+
}
923935

924936
// Create the body and exit blocks.
925937
auto body = llvm::BasicBlock::Create(getContext(), "body", fn_);
@@ -944,6 +956,8 @@ void LLVMCodeGenImpl::visit(const For* v) {
944956

945957
// Exit the loop.
946958
irb_.SetInsertPoint(exit);
959+
960+
varToVal_.erase(v->var());
947961
value_ = llvm::ConstantInt::get(IntTy_, 0);
948962
}
949963

@@ -1454,11 +1468,43 @@ void LLVMCodeGenImpl::visit(const FunctionCall* v) {
14541468
}
14551469

14561470
// Lowers an Allocate statement: computes the buffer's size in bytes and
// binds the buffer variable to freshly obtained storage in varToVal_.
void LLVMCodeGenImpl::visit(const Allocate* v) {
  // Total bytes = element byte size * product of all dimension extents.
  llvm::Value* size =
      llvm::ConstantInt::getSigned(LongTy_, v->dtype().byte_size());
  for (const Expr* e : v->dims()) {
    e->accept(this);
    size = irb_.CreateMul(size, irb_.CreateZExt(value_, LongTy_));
  }

  // Allocate is a statement; its "value" is a dummy 0.
  value_ = llvm::ConstantInt::get(IntTy_, 0);

  // Small buffers with a compile-time-constant size (< 512 bytes) are
  // placed on the stack via alloca; visit(Free) skips alloca pointers.
  // NOTE(review): `size` is in bytes but CreateAlloca treats its array-size
  // operand as an element count, so this over-allocates for multi-byte
  // dtypes — safe but wasteful; confirm intended.
  if (llvm::ConstantInt* CI = llvm::dyn_cast<llvm::ConstantInt>(size)) {
    if (CI->getSExtValue() < 512) {
      llvm::Value* alloca = irb_.CreateAlloca(dtypeToLLVM(v->dtype()), size);
      varToVal_[v->buffer_var()] = alloca;
      return;
    }
  }

  // Larger or dynamically-sized buffers go on the heap via malloc; the
  // matching Free statement emits the free call.
  llvm::Instruction* I = llvm::CallInst::CreateMalloc(
      irb_.GetInsertBlock(),
      LongTy_,
      dtypeToLLVM(v->dtype()),
      size,
      nullptr,
      nullptr);

  // Insert the bitcast into the block.
  irb_.SetInsertPoint(irb_.GetInsertBlock());
  llvm::Value* malloc = irb_.Insert(I);
  varToVal_[v->buffer_var()] = malloc;
}
14591501

14601502
// Lowers a Free statement: emits a free() call for heap buffers created in
// visit(Allocate). Stack (alloca) buffers are reclaimed automatically on
// function exit and are skipped.
void LLVMCodeGenImpl::visit(const Free* v) {
  // Free is a statement; its "value" is a dummy 0.
  value_ = llvm::ConstantInt::get(IntTy_, 0);
  llvm::Value* ptr = varToVal_.at(v->buffer_var());
  if (!llvm::isa<llvm::AllocaInst>(ptr)) {
    irb_.Insert(llvm::CallInst::CreateFree(ptr, irb_.GetInsertBlock()));
  }
}
14631509

14641510
void LLVMCodeGenImpl::visit(const Cond* v) {

torch/csrc/jit/tensorexpr/llvm_codegen.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,18 @@ class TORCH_API LLVMCodeGen : public CodeGen {
3232

3333
template <typename T>
3434
T value() {
35-
std::vector<void*> args;
36-
return value<T>(args);
35+
return value<T>(nullptr);
3736
}
3837

3938
template <typename T>
4039
T value(std::vector<void*>& args) {
40+
return value<T>(args.data());
41+
}
42+
43+
template <typename T>
44+
T value(void** args) {
4145
T (*fp)(void**) = (T(*)(void**))getKernelAddress(impl_.get());
42-
T rv = fp(args.data());
46+
T rv = fp(args);
4347
return rv;
4448
}
4549

0 commit comments

Comments
 (0)