Skip to content

Commit 067ad31

Browse files
Chillee authored and facebook-github-bot committed
[NNC] Added some more external function bindings (#53420)
Summary: Fixes #{issue number} Pull Request resolved: #53420 Reviewed By: navahgar Differential Revision: D26876784 Pulled By: Chillee fbshipit-source-id: 05e7c782a72de5159879f88a104f1a273e0345eb
1 parent c72473f commit 067ad31

2 files changed

Lines changed: 233 additions & 37 deletions

File tree

test/cpp/tensorexpr/test_external_calls.cpp

Lines changed: 142 additions & 35 deletions
Original file line number · Diff line number · Diff line change
@@ -191,50 +191,157 @@ TEST(ExternalCall, Conv2d_nobias_noargs) {
191191
ASSERT_TRUE(at::allclose(nnc_result, ref));
192192
}
193193

194-
TEST(ExternalCall, Matmul) {
194+
TEST(ExternalCall, BinaryFloat) {
195195
KernelScope kernel_scope;
196-
Placeholder A("A", kFloat, {10, 3, 100, 200});
197-
Placeholder B("", kFloat, {10, 3, 200, 300});
198-
BufHandle ResultBuf("Result", {10, 3, 100, 300}, kFloat);
196+
using TensorFunc = std::function<at::Tensor(at::Tensor, at::Tensor)>;
197+
using Test = std::tuple<
198+
std::vector<int64_t>,
199+
std::vector<int64_t>,
200+
std::vector<int64_t>,
201+
TensorFunc,
202+
std::string>;
203+
std::vector<Test> tests = {};
204+
tests.push_back(
205+
Test{{100, 200}, {200, 300}, {100, 300}, at::matmul, "nnc_aten_matmul"});
206+
tests.push_back(Test{{100, 300}, {300}, {100}, at::mv, "nnc_aten_mv"});
207+
tests.push_back(
208+
Test{{100, 200}, {200, 300}, {100, 300}, at::mm, "nnc_aten_mm"});
209+
for (auto curTest : tests) {
210+
std::vector<int64_t> aShape, bShape, resShape;
211+
TensorFunc torchFunc;
212+
std::string externCallName;
213+
std::tie(aShape, bShape, resShape, torchFunc, externCallName) = curTest;
214+
auto toExprHandleVec = [](std::vector<int64_t> v) {
215+
auto intV = std::vector<int>(v.begin(), v.end());
216+
return std::vector<ExprHandle>(intV.begin(), intV.end());
217+
};
218+
Placeholder A("A", kFloat, toExprHandleVec(aShape));
219+
Placeholder B("", kFloat, toExprHandleVec(bShape));
220+
BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat);
221+
222+
Tensor* Result = new Tensor(
223+
ResultBuf.node(),
224+
ExternalCall::make(
225+
ResultBuf,
226+
externCallName,
227+
{BufHandle(A.data()), BufHandle(B.data())},
228+
{}));
229+
LoopNest l({Result});
230+
l.prepareForCodegen();
231+
l.simplify();
232+
233+
auto options = at::TensorOptions()
234+
.dtype(at::kFloat)
235+
.layout(at::kStrided)
236+
.device(at::kCPU)
237+
.requires_grad(false);
238+
at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f;
239+
at::Tensor b = at::ones(c10::IntArrayRef(bShape), options) * 6.f;
240+
at::Tensor ref = torchFunc(a, b);
241+
242+
auto prod = [](std::vector<int64_t> v) {
243+
return std::accumulate(v.begin(), v.end(), 1, std::multiplies<int64_t>());
244+
};
245+
246+
at::Tensor nnc_result;
247+
std::vector<float> a_buf(prod(aShape), 5.f);
248+
std::vector<float> b_buf(prod(bShape), 6.f);
249+
std::vector<float> result_buf(prod(resShape), -1.f);
199250

200-
Tensor* Result = new Tensor(
201-
ResultBuf.node(),
202-
ExternalCall::make(
203-
ResultBuf,
204-
"nnc_aten_matmul",
205-
{BufHandle(A.data()), BufHandle(B.data())},
206-
{}));
207-
LoopNest l({Result});
208-
l.prepareForCodegen();
209-
l.simplify();
251+
#ifdef TORCH_ENABLE_LLVM
252+
LLVMCodeGen llvm_codegen(l.root_stmt(), {A, B, Result});
210253

211-
auto options = at::TensorOptions()
212-
.dtype(at::kFloat)
213-
.layout(at::kStrided)
214-
.device(at::kCPU)
215-
.requires_grad(false);
216-
at::Tensor a = at::ones({10, 3, 100, 200}, options) * 5.f;
217-
at::Tensor b = at::ones({10, 3, 200, 300}, options) * 6.f;
218-
at::Tensor ref = at::matmul(a, b);
254+
llvm_codegen.call({a_buf, b_buf, result_buf});
255+
nnc_result =
256+
at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
257+
ASSERT_TRUE(at::allclose(nnc_result, ref));
258+
#endif
219259

220-
at::Tensor nnc_result;
221-
std::vector<float> a_buf(10 * 3 * 100 * 200, 5.f);
222-
std::vector<float> b_buf(10 * 3 * 200 * 300, 6.f);
223-
std::vector<float> result_buf(10 * 3 * 100 * 300, -1.f);
260+
SimpleIREvaluator ir_eval(l.root_stmt(), {A, B, Result});
261+
ir_eval.call({a_buf, b_buf, result_buf});
262+
nnc_result =
263+
at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
264+
ASSERT_TRUE(at::allclose(nnc_result, ref));
265+
}
266+
}
267+
268+
TEST(ExternalCall, UnaryFloat) {
269+
KernelScope kernel_scope;
270+
using TensorFunc = std::function<at::Tensor(at::Tensor)>;
271+
auto toExprHandleVec = [](std::vector<int64_t> v) {
272+
auto intV = std::vector<int>(v.begin(), v.end());
273+
return std::vector<ExprHandle>(intV.begin(), intV.end());
274+
};
275+
using Test = std::tuple<
276+
std::vector<int64_t>,
277+
std::vector<int64_t>,
278+
TensorFunc,
279+
std::string,
280+
std::vector<ExprHandle>>;
281+
std::vector<Test> tests = {};
282+
tests.push_back(Test{
283+
{1, 64, 8, 9},
284+
{1, 64, 5, 7},
285+
[](at::Tensor x) {
286+
return at::adaptive_avg_pool2d(x, {5, 7});
287+
},
288+
"nnc_aten_adaptive_avg_pool2d",
289+
toExprHandleVec({5, 7})});
290+
tests.push_back(Test{
291+
{100, 200},
292+
{100},
293+
[](at::Tensor x) { return at::mean(x, {1}); },
294+
"nnc_aten_mean",
295+
toExprHandleVec({1})});
296+
for (auto curTest : tests) {
297+
std::vector<int64_t> aShape, resShape;
298+
TensorFunc torchFunc;
299+
std::string externCallName;
300+
std::vector<ExprHandle> externCallArgs;
301+
std::tie(aShape, resShape, torchFunc, externCallName, externCallArgs) =
302+
curTest;
303+
Placeholder A("A", kFloat, toExprHandleVec(aShape));
304+
BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat);
305+
306+
Tensor* Result = new Tensor(
307+
ResultBuf.node(),
308+
ExternalCall::make(
309+
ResultBuf, externCallName, {BufHandle(A.data())}, externCallArgs));
310+
LoopNest l({Result});
311+
l.prepareForCodegen();
312+
l.simplify();
313+
314+
auto options = at::TensorOptions()
315+
.dtype(at::kFloat)
316+
.layout(at::kStrided)
317+
.device(at::kCPU)
318+
.requires_grad(false);
319+
at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f;
320+
at::Tensor ref = torchFunc(a);
321+
322+
auto prod = [](std::vector<int64_t> v) {
323+
return std::accumulate(v.begin(), v.end(), 1, std::multiplies<int64_t>());
324+
};
325+
326+
at::Tensor nnc_result;
327+
std::vector<float> a_buf(prod(aShape), 5.f);
328+
std::vector<float> result_buf(prod(resShape), -1.f);
224329

225330
#ifdef TORCH_ENABLE_LLVM
226-
LLVMCodeGen llvm_codegen(l.root_stmt(), {A, B, Result});
331+
LLVMCodeGen llvm_codegen(l.root_stmt(), {A, Result});
227332

228-
llvm_codegen.call({a_buf, b_buf, result_buf});
229-
nnc_result = at::from_blob(result_buf.data(), {10, 3, 100, 300}, options);
230-
ASSERT_TRUE(at::allclose(nnc_result, ref));
333+
llvm_codegen.call({a_buf, result_buf});
334+
nnc_result =
335+
at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
336+
ASSERT_TRUE(at::allclose(nnc_result, ref));
231337
#endif
232338

233-
SimpleIREvaluator ir_eval(l.root_stmt(), {A, B, Result});
234-
235-
ir_eval.call({a_buf, b_buf, result_buf});
236-
nnc_result = at::from_blob(result_buf.data(), {10, 3, 100, 300}, options);
237-
ASSERT_TRUE(at::allclose(nnc_result, ref));
339+
SimpleIREvaluator ir_eval(l.root_stmt(), {A, Result});
340+
ir_eval.call({a_buf, result_buf});
341+
nnc_result =
342+
at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
343+
ASSERT_TRUE(at::allclose(nnc_result, ref));
344+
}
238345
}
239346

240347
TEST(ExternalCall, ComputeInterop) {

torch/csrc/jit/tensorexpr/external_functions.cpp

Lines changed: 91 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -111,12 +111,101 @@ void nnc_aten_matmul(
111111
}
112112
}
113113

114-
static RegisterNNCExternalFunction nnc_conv2d(
114+
void nnc_aten_mv(
115+
int64_t bufs_num,
116+
void** buf_data,
117+
int64_t* buf_ranks,
118+
int64_t* buf_dims,
119+
int8_t* buf_dtypes,
120+
int64_t args_num,
121+
int64_t* extra_args) {
122+
std::vector<at::Tensor> tensors =
123+
constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
124+
125+
at::Tensor& r = tensors[0];
126+
const at::Tensor& x = tensors[1];
127+
const at::Tensor& w = tensors[2];
128+
try {
129+
at::mv_out(r, x, w);
130+
} catch (...) {
131+
}
132+
}
133+
134+
void nnc_aten_mm(
135+
int64_t bufs_num,
136+
void** buf_data,
137+
int64_t* buf_ranks,
138+
int64_t* buf_dims,
139+
int8_t* buf_dtypes,
140+
int64_t args_num,
141+
int64_t* extra_args) {
142+
std::vector<at::Tensor> tensors =
143+
constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
144+
145+
at::Tensor& r = tensors[0];
146+
const at::Tensor& x = tensors[1];
147+
const at::Tensor& w = tensors[2];
148+
try {
149+
at::mm_out(r, x, w);
150+
} catch (...) {
151+
}
152+
}
153+
154+
void nnc_aten_adaptive_avg_pool2d(
155+
int64_t bufs_num,
156+
void** buf_data,
157+
int64_t* buf_ranks,
158+
int64_t* buf_dims,
159+
int8_t* buf_dtypes,
160+
int64_t args_num,
161+
int64_t* extra_args) {
162+
std::vector<at::Tensor> tensors =
163+
constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
164+
165+
at::Tensor& r = tensors[0];
166+
const at::Tensor& x = tensors[1];
167+
int64_t H = extra_args[0];
168+
int64_t W = extra_args[1];
169+
try {
170+
at::adaptive_avg_pool2d_out(r, x, {H, W});
171+
} catch (...) {
172+
}
173+
}
174+
175+
void nnc_aten_mean(
176+
int64_t bufs_num,
177+
void** buf_data,
178+
int64_t* buf_ranks,
179+
int64_t* buf_dims,
180+
int8_t* buf_dtypes,
181+
int64_t args_num,
182+
int64_t* extra_args) {
183+
std::vector<at::Tensor> tensors =
184+
constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
185+
186+
at::Tensor& r = tensors[0];
187+
const at::Tensor& x = tensors[1];
188+
int64_t dim = extra_args[0];
189+
try {
190+
at::mean_out(r, x, {dim});
191+
} catch (...) {
192+
}
193+
}
194+
195+
const static RegisterNNCExternalFunction nnc_conv2d(
115196
"nnc_aten_conv2d",
116197
nnc_aten_conv2d);
117-
static RegisterNNCExternalFunction nnc_matmul(
198+
const static RegisterNNCExternalFunction nnc_matmul(
118199
"nnc_aten_matmul",
119200
nnc_aten_matmul);
201+
const static RegisterNNCExternalFunction nnc_mv("nnc_aten_mv", nnc_aten_mv);
202+
const static RegisterNNCExternalFunction nnc_mm("nnc_aten_mm", nnc_aten_mm);
203+
const static RegisterNNCExternalFunction nnc_adaptive_avg_pool2d(
204+
"nnc_aten_adaptive_avg_pool2d",
205+
nnc_aten_adaptive_avg_pool2d);
206+
const static RegisterNNCExternalFunction nnc_mean(
207+
"nnc_aten_mean",
208+
nnc_aten_mean);
120209

121210
} // namespace tensorexpr
122211
} // namespace jit

0 commit comments

Comments (0)