But I'm getting an assertion error when doing so. Judging by the comments in the example, this is probably a case that is not yet supported.
Trying to use ExpressionEvaluator to infer launch configuration.
This modified example should work.
// Evaluate expressions in a simple IR.
//
// Variant of the basic expression-evaluation test that lowers the fusion
// (GPULower) *before* evaluating extents, with the goal of using
// ExpressionEvaluator to infer the launch configuration (the BIDx / TIDx
// extents) from bound input sizes.
void testGPU_FusionExprEvalBasic() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Create a non-trivial IR
  TensorView* tv0 = makeDummyTensor(2);
  TensorView* tv1 = makeDummyTensor(2);

  fusion.addInput(tv0);
  fusion.addInput(tv1);

  TensorView* tv2 = add(tv1, new Float(2.0));
  TensorView* tv3 = add(tv0, tv2);

  fusion.addOutput(tv3);

  tv3->split(0, 4);

  tv0->computeAt(tv3, 1);
  tv1->computeAt(tv3, 1);

  tv3->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(1)->parallelize(ParallelType::Unroll);
  tv3->axis(1)->parallelize(ParallelType::Unroll);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);

  // Expressions standing in for the launch configuration: the extent of the
  // axis bound to BIDx and the extent of the axis bound to TIDx. They are
  // created before lowering, on the un-lowered fusion.
  auto* bid_x = add(tv3->axis(0)->rawExtent(), new Int(0));
  auto* tid_x = add(tv3->axis(-1)->rawExtent(), new Int(0));

  // Lowering the fusion before evaluation is the step that appears to
  // trigger the reported assertion failure.
  GPULower gpulw(&fusion);
  std::stringstream cdg;
  gpulw.printKernel(cdg);
  std::cout << cdg.str() << std::endl;

  // 1. Create an evaluation context
  EvaluationContext eval_context(&fusion);

  // 2. Bind values
  //
  // IMPORTANT:
  // a. The bindings are only as stable as the Vals are in the fusion graph
  // b. You must use the original (rootDomain) extents
  //    (ex. `tv0->getRootDomain()[0]->extent()`
  //    instead of `tv0->axis(0)->extent()`)
  //
  eval_context.bind(tv0->getRootDomain()[0]->extent(), 6);
  eval_context.bind(tv0->getRootDomain()[1]->extent(), 128);
  eval_context.bind(tv1->getRootDomain()[0]->extent(), 6);
  eval_context.bind(tv1->getRootDomain()[1]->extent(), 128);

  // 3. Evaluate and check result values
  TORCH_CHECK(tv2->domain()->nDims() == 3);
  checkIntValue(&eval_context, tv2->axis(0)->rawExtent(), 2);
  checkIntValue(&eval_context, tv2->axis(1)->rawExtent(), 4);
  checkIntValue(&eval_context, tv2->axis(2)->rawExtent(), 128);

  TORCH_CHECK(tv3->domain()->nDims() == 3);
  checkIntValue(&eval_context, tv3->axis(0)->rawExtent(), 2);
  checkIntValue(&eval_context, tv3->axis(1)->rawExtent(), 4);
  checkIntValue(&eval_context, tv3->axis(2)->rawExtent(), 128);

  // 4. Evaluate the launch-configuration expressions. Check that each
  //    evaluation produced a value before dereferencing it, so a failure is
  //    reported with a clear message.
  const auto bid_x_val = ExpressionEvaluator::evaluate(bid_x, &eval_context);
  TORCH_CHECK(bid_x_val.has_value(), "Failed to evaluate the BIDx extent");
  std::cout << "bid x value " << bid_x_val.value() << std::endl;

  // Evaluate tid_x here (the original evaluated bid_x a second time and
  // printed it as the thread-dimension value).
  const auto tid_x_val = ExpressionEvaluator::evaluate(tid_x, &eval_context);
  TORCH_CHECK(tid_x_val.has_value(), "Failed to evaluate the TIDx extent");
  std::cout << "tid x value " << tid_x_val.value() << std::endl;
}
🚀 Feature
I tried to modify the cpp test `FusionExprEvalBasic` so that code lowering happens before expression evaluation. But I'm getting an assertion error when doing so. Judging by the comments in the example, this is probably a case that is not yet supported.
Motivation
Trying to use ExpressionEvaluator to infer launch configuration.
Pitch
This modified example should work.