@@ -1201,6 +1201,152 @@ void testLLVMEliminatedStmt() {
12011201 cg.call ({aData, cData});
12021202}
12031203
1204+ void testLLVMSimpleReduction () {
1205+ KernelScope kernel_scope;
1206+
1207+ int M = 128 ;
1208+ int N = 64 ;
1209+ const int kTotalSize = M * N;
1210+
1211+ Buffer a (" a" , kFloat , {1 , M, N});
1212+
1213+ // TODO: why doesn't implicit vector<DimArg> work?
1214+ std::vector<DimArg> axis = {DimArg (1 )};
1215+ std::vector<DimArg> reduce_axis = {DimArg (M), DimArg (N)};
1216+ Tensor* b = Reduce (" sum" , axis, Sum (), a, reduce_axis);
1217+ LoopNest loop ({b});
1218+
1219+ loop.prepareForCodegen ();
1220+ Stmt* s = loop.root_stmt ();
1221+ s = IRSimplifier::simplify (s);
1222+
1223+ LLVMCodeGen cg (s, {a, b});
1224+
1225+ PaddedBuffer<float > a_v (1 , M, N, " a_v" );
1226+ PaddedBuffer<float > b_v (1 , " b_v" );
1227+ PaddedBuffer<float > b_ref (1 , " b_ref" );
1228+
1229+ b_ref (0 ) = 0 ;
1230+ for (int i = 0 ; i < M; i++) {
1231+ for (int j = 0 ; j < N; j++) {
1232+ int v = i + j;
1233+ a_v (0 , i, j) = v;
1234+ b_ref (0 ) += v;
1235+ }
1236+ }
1237+
1238+ cg.call ({a_v, b_v});
1239+
1240+ ExpectAllNear (b_v, b_ref, 1e-5 );
1241+ }
1242+
1243+ void testLLVMRFactorReduction () {
1244+ KernelScope kernel_scope;
1245+
1246+ int M = 128 ;
1247+ int N = 64 ;
1248+ const int kTotalSize = M * N;
1249+
1250+ Buffer a (" a" , kFloat , {1 , M, N});
1251+
1252+ // TODO: why doesn't implicit vector<DimArg> work?
1253+ std::vector<DimArg> axis = {DimArg (1 )};
1254+ std::vector<DimArg> reduce_axis = {DimArg (M), DimArg (N)};
1255+ Tensor* b = Reduce (" sum" , axis, Sum (), a, reduce_axis);
1256+ LoopNest loop ({b});
1257+
1258+ std::vector<For*> loops = loop.getLoopStmtsFor (b);
1259+ For* loop_m = loops.at (1 );
1260+ For* loop_n = loops.at (2 );
1261+ loop.reorderAxis (b, loop_m, loop_n);
1262+
1263+ loops = loop.getLoopStmtsFor (b);
1264+ loop_m = loops.at (2 );
1265+ loop_n = loops.at (1 );
1266+ loop.rfactor (b->body (), loop_n->var (), loop_n->body ());
1267+
1268+ loop.prepareForCodegen ();
1269+ Stmt* s = loop.root_stmt ();
1270+ s = IRSimplifier::simplify (s);
1271+
1272+ LLVMCodeGen cg (s, {a, b});
1273+
1274+ PaddedBuffer<float > a_v (1 , M, N, " a_v" );
1275+ PaddedBuffer<float > b_v (1 , " b_v" );
1276+ PaddedBuffer<float > b_ref (1 , " b_ref" );
1277+
1278+ b_ref (0 ) = 0 ;
1279+ for (int i = 0 ; i < M; i++) {
1280+ for (int j = 0 ; j < N; j++) {
1281+ int v = i + j;
1282+ a_v (0 , i, j) = v;
1283+ b_ref (0 ) += v;
1284+ }
1285+ }
1286+
1287+ cg.call ({a_v, b_v});
1288+
1289+ ExpectAllNear (b_v, b_ref, 1e-5 );
1290+ }
1291+
1292+ void testLLVMRFactorVectorizedReduction () {
1293+ KernelScope kernel_scope;
1294+
1295+ int M = 128 ;
1296+ int N = 64 ;
1297+ const int kTotalSize = M * N;
1298+
1299+ Buffer a (" a" , kFloat , {1 , M, N});
1300+
1301+ // TODO: why doesn't implicit vector<DimArg> work?
1302+ std::vector<DimArg> axis = {DimArg (1 )};
1303+ std::vector<DimArg> reduce_axis = {DimArg (M), DimArg (N)};
1304+ Tensor* b = Reduce (" sum" , axis, Sum (), a, reduce_axis);
1305+ LoopNest loopnest ({b});
1306+ std::vector<For*> loops = loopnest.getLoopStmtsFor (b);
1307+ For* loop_k = loops.at (0 );
1308+ For* loop_m = loops.at (1 );
1309+ For* loop_n = loops.at (2 );
1310+ loopnest.reorderAxis (b, loop_n, loop_m);
1311+ loops = loopnest.getLoopStmtsFor (b);
1312+ loop_k = loops.at (0 );
1313+ loop_n = loops.at (1 );
1314+ loop_m = loops.at (2 );
1315+ // Case-III reductions
1316+ loopnest.rfactor (b->body (), loop_n->var ());
1317+ loopnest.prepareForCodegen ();
1318+ Stmt* s = loopnest.root_stmt ();
1319+ s = IRSimplifier::simplify (s);
1320+
1321+ Block* root_block = dynamic_cast <Block*>(s);
1322+ auto stmt_list = root_block->stmts ();
1323+ auto I = stmt_list.begin ();
1324+ ++I;
1325+
1326+ For* outer_loop = dynamic_cast <For*>(*I);
1327+ loopnest.vectorize (outer_loop);
1328+
1329+ s = IRSimplifier::simplify (s);
1330+ LLVMCodeGen cg (s, {a, b});
1331+
1332+ PaddedBuffer<float > a_v (1 , M, N, " a_v" );
1333+ PaddedBuffer<float > b_v (1 , " b_v" );
1334+ PaddedBuffer<float > b_ref (1 , " b_ref" );
1335+
1336+ b_ref (0 ) = 0 ;
1337+ for (int i = 0 ; i < M; i++) {
1338+ for (int j = 0 ; j < N; j++) {
1339+ int v = i + j;
1340+ a_v (0 , i, j) = v;
1341+ b_ref (0 ) += v;
1342+ }
1343+ }
1344+
1345+ cg.call ({a_v, b_v});
1346+
1347+ ExpectAllNear (b_v, b_ref, 1e-5 );
1348+ }
1349+
12041350} // namespace jit
12051351} // namespace torch
12061352
0 commit comments