@@ -191,50 +191,157 @@ TEST(ExternalCall, Conv2d_nobias_noargs) {
191191 ASSERT_TRUE (at::allclose (nnc_result, ref));
192192}
193193
194- TEST (ExternalCall, Matmul ) {
194+ TEST (ExternalCall, BinaryFloat ) {
195195 KernelScope kernel_scope;
196- Placeholder A (" A" , kFloat , {10 , 3 , 100 , 200 });
197- Placeholder B (" " , kFloat , {10 , 3 , 200 , 300 });
198- BufHandle ResultBuf (" Result" , {10 , 3 , 100 , 300 }, kFloat );
196+ using TensorFunc = std::function<at::Tensor (at::Tensor, at::Tensor)>;
197+ using Test = std::tuple<
198+ std::vector<int64_t >,
199+ std::vector<int64_t >,
200+ std::vector<int64_t >,
201+ TensorFunc,
202+ std::string>;
203+ std::vector<Test> tests = {};
204+ tests.push_back (
205+ Test{{100 , 200 }, {200 , 300 }, {100 , 300 }, at::matmul, " nnc_aten_matmul" });
206+ tests.push_back (Test{{100 , 300 }, {300 }, {100 }, at::mv, " nnc_aten_mv" });
207+ tests.push_back (
208+ Test{{100 , 200 }, {200 , 300 }, {100 , 300 }, at::mm, " nnc_aten_mm" });
209+ for (auto curTest : tests) {
210+ std::vector<int64_t > aShape, bShape, resShape;
211+ TensorFunc torchFunc;
212+ std::string externCallName;
213+ std::tie (aShape, bShape, resShape, torchFunc, externCallName) = curTest;
214+ auto toExprHandleVec = [](std::vector<int64_t > v) {
215+ auto intV = std::vector<int >(v.begin (), v.end ());
216+ return std::vector<ExprHandle>(intV.begin (), intV.end ());
217+ };
218+ Placeholder A (" A" , kFloat , toExprHandleVec (aShape));
219+ Placeholder B (" " , kFloat , toExprHandleVec (bShape));
220+ BufHandle ResultBuf (" Result" , toExprHandleVec (resShape), kFloat );
221+
222+ Tensor* Result = new Tensor (
223+ ResultBuf.node (),
224+ ExternalCall::make (
225+ ResultBuf,
226+ externCallName,
227+ {BufHandle (A.data ()), BufHandle (B.data ())},
228+ {}));
229+ LoopNest l ({Result});
230+ l.prepareForCodegen ();
231+ l.simplify ();
232+
233+ auto options = at::TensorOptions ()
234+ .dtype (at::kFloat )
235+ .layout (at::kStrided )
236+ .device (at::kCPU )
237+ .requires_grad (false );
238+ at::Tensor a = at::ones (c10::IntArrayRef (aShape), options) * 5 .f ;
239+ at::Tensor b = at::ones (c10::IntArrayRef (bShape), options) * 6 .f ;
240+ at::Tensor ref = torchFunc (a, b);
241+
242+ auto prod = [](std::vector<int64_t > v) {
243+ return std::accumulate (v.begin (), v.end (), 1 , std::multiplies<int64_t >());
244+ };
245+
246+ at::Tensor nnc_result;
247+ std::vector<float > a_buf (prod (aShape), 5 .f );
248+ std::vector<float > b_buf (prod (bShape), 6 .f );
249+ std::vector<float > result_buf (prod (resShape), -1 .f );
199250
200- Tensor* Result = new Tensor (
201- ResultBuf.node (),
202- ExternalCall::make (
203- ResultBuf,
204- " nnc_aten_matmul" ,
205- {BufHandle (A.data ()), BufHandle (B.data ())},
206- {}));
207- LoopNest l ({Result});
208- l.prepareForCodegen ();
209- l.simplify ();
251+ #ifdef TORCH_ENABLE_LLVM
252+ LLVMCodeGen llvm_codegen (l.root_stmt (), {A, B, Result});
210253
211- auto options = at::TensorOptions ()
212- .dtype (at::kFloat )
213- .layout (at::kStrided )
214- .device (at::kCPU )
215- .requires_grad (false );
216- at::Tensor a = at::ones ({10 , 3 , 100 , 200 }, options) * 5 .f ;
217- at::Tensor b = at::ones ({10 , 3 , 200 , 300 }, options) * 6 .f ;
218- at::Tensor ref = at::matmul (a, b);
254+ llvm_codegen.call ({a_buf, b_buf, result_buf});
255+ nnc_result =
256+ at::from_blob (result_buf.data (), c10::IntArrayRef (resShape), options);
257+ ASSERT_TRUE (at::allclose (nnc_result, ref));
258+ #endif
219259
220- at::Tensor nnc_result;
221- std::vector<float > a_buf (10 * 3 * 100 * 200 , 5 .f );
222- std::vector<float > b_buf (10 * 3 * 200 * 300 , 6 .f );
223- std::vector<float > result_buf (10 * 3 * 100 * 300 , -1 .f );
260+ SimpleIREvaluator ir_eval (l.root_stmt (), {A, B, Result});
261+ ir_eval.call ({a_buf, b_buf, result_buf});
262+ nnc_result =
263+ at::from_blob (result_buf.data (), c10::IntArrayRef (resShape), options);
264+ ASSERT_TRUE (at::allclose (nnc_result, ref));
265+ }
266+ }
267+
268+ TEST (ExternalCall, UnaryFloat) {
269+ KernelScope kernel_scope;
270+ using TensorFunc = std::function<at::Tensor (at::Tensor)>;
271+ auto toExprHandleVec = [](std::vector<int64_t > v) {
272+ auto intV = std::vector<int >(v.begin (), v.end ());
273+ return std::vector<ExprHandle>(intV.begin (), intV.end ());
274+ };
275+ using Test = std::tuple<
276+ std::vector<int64_t >,
277+ std::vector<int64_t >,
278+ TensorFunc,
279+ std::string,
280+ std::vector<ExprHandle>>;
281+ std::vector<Test> tests = {};
282+ tests.push_back (Test{
283+ {1 , 64 , 8 , 9 },
284+ {1 , 64 , 5 , 7 },
285+ [](at::Tensor x) {
286+ return at::adaptive_avg_pool2d (x, {5 , 7 });
287+ },
288+ " nnc_aten_adaptive_avg_pool2d" ,
289+ toExprHandleVec ({5 , 7 })});
290+ tests.push_back (Test{
291+ {100 , 200 },
292+ {100 },
293+ [](at::Tensor x) { return at::mean (x, {1 }); },
294+ " nnc_aten_mean" ,
295+ toExprHandleVec ({1 })});
296+ for (auto curTest : tests) {
297+ std::vector<int64_t > aShape, resShape;
298+ TensorFunc torchFunc;
299+ std::string externCallName;
300+ std::vector<ExprHandle> externCallArgs;
301+ std::tie (aShape, resShape, torchFunc, externCallName, externCallArgs) =
302+ curTest;
303+ Placeholder A (" A" , kFloat , toExprHandleVec (aShape));
304+ BufHandle ResultBuf (" Result" , toExprHandleVec (resShape), kFloat );
305+
306+ Tensor* Result = new Tensor (
307+ ResultBuf.node (),
308+ ExternalCall::make (
309+ ResultBuf, externCallName, {BufHandle (A.data ())}, externCallArgs));
310+ LoopNest l ({Result});
311+ l.prepareForCodegen ();
312+ l.simplify ();
313+
314+ auto options = at::TensorOptions ()
315+ .dtype (at::kFloat )
316+ .layout (at::kStrided )
317+ .device (at::kCPU )
318+ .requires_grad (false );
319+ at::Tensor a = at::ones (c10::IntArrayRef (aShape), options) * 5 .f ;
320+ at::Tensor ref = torchFunc (a);
321+
322+ auto prod = [](std::vector<int64_t > v) {
323+ return std::accumulate (v.begin (), v.end (), 1 , std::multiplies<int64_t >());
324+ };
325+
326+ at::Tensor nnc_result;
327+ std::vector<float > a_buf (prod (aShape), 5 .f );
328+ std::vector<float > result_buf (prod (resShape), -1 .f );
224329
225330#ifdef TORCH_ENABLE_LLVM
226- LLVMCodeGen llvm_codegen (l.root_stmt (), {A, B , Result});
331+ LLVMCodeGen llvm_codegen (l.root_stmt (), {A, Result});
227332
228- llvm_codegen.call ({a_buf, b_buf, result_buf});
229- nnc_result = at::from_blob (result_buf.data (), {10 , 3 , 100 , 300 }, options);
230- ASSERT_TRUE (at::allclose (nnc_result, ref));
333+ llvm_codegen.call ({a_buf, result_buf});
334+ nnc_result =
335+ at::from_blob (result_buf.data (), c10::IntArrayRef (resShape), options);
336+ ASSERT_TRUE (at::allclose (nnc_result, ref));
231337#endif
232338
233- SimpleIREvaluator ir_eval (l.root_stmt (), {A, B, Result});
234-
235- ir_eval.call ({a_buf, b_buf, result_buf});
236- nnc_result = at::from_blob (result_buf.data (), {10 , 3 , 100 , 300 }, options);
237- ASSERT_TRUE (at::allclose (nnc_result, ref));
339+ SimpleIREvaluator ir_eval (l.root_stmt (), {A, Result});
340+ ir_eval.call ({a_buf, result_buf});
341+ nnc_result =
342+ at::from_blob (result_buf.data (), c10::IntArrayRef (resShape), options);
343+ ASSERT_TRUE (at::allclose (nnc_result, ref));
344+ }
238345}
239346
240347TEST (ExternalCall, ComputeInterop) {
0 commit comments