When splitting a block reduction from a grid reduction there's an error in the work-buffer computation: the buffer indexing goes out of bounds, and in the case below it ends up corrupting the sync buffer, so nothing is ever written.
Repro:
// Repro: build a tiled [M,K] x [K,N] multiply-then-sum fusion with a
// split-K reduction, rFactor it twice, and run it end-to-end.
// NOTE(review): the double rFactor below is the trigger for the reported
// work-buffer bug — do not "simplify" this snippet.
Fusion fusion;
FusionGuard fg(&fusion);
// Symbolic integers we will use for runtime tiling
Int* symbolic_m_tile_dim = new Int();
Int* symbolic_split_k_tile_dim = new Int();
Int* symbolic_block_k_tile_dim = new Int();
// Compile-time integer for tiling
int n_smem_tile = 32;
// Symbolic 2D tensors TV0[M, K], TV1[K, N]
TensorView* tv0 = makeDummyTensor(2);
TensorView* tv1 = makeDummyTensor(2);
// Broadcast tv0 to [M, K, *]
TensorView* tv2 = broadcast(tv0, {false, false, true});
// Broadcast tv1 to [*, K, N]
TensorView* tv3 = broadcast(tv1, {true, false, false});
// Pointwise multiplication resulting in tv4[M, K, N]
TensorView* tv4 = mul(tv2, tv3);
// Sum the K-dim
TensorView* tv5 = sum(tv4, {1});
// Register inputs and outputs
fusion.addInput(tv0);
fusion.addInput(tv1);
fusion.addOutput(tv5);
// Register runtime tile dims as inputs
fusion.addInput(symbolic_m_tile_dim);
fusion.addInput(symbolic_split_k_tile_dim);
fusion.addInput(symbolic_block_k_tile_dim);
// Make a 3D tile, mix of symbolic and constant, do in reverse order because
// dims are inserted
tv5->split(2, n_smem_tile);
tv5->split(1, symbolic_block_k_tile_dim);
tv5->split(1, symbolic_split_k_tile_dim);
tv5->split(0, symbolic_m_tile_dim);
// tv5[M/m_tile, m_tile, r{K/split_k/block_k}, r{split_k}, r{block_k}, N/32,
// 32]
tv5->reorder({{1, 5}, {5, 1}});
// tv5[M/m_tile, N/32, r{K/split_k/block_k}, r{split_k}, r{block_k}, m_tile,
// 32]
// Two successive rFactors, both written at axis index 2: after the first
// rFactor, the remaining reduction axes shift, so the second call factors a
// different (the next) reduction axis into tv7.
// NOTE(review): per this report, removing the tv7 rFactor (and doing
// block+grid reduce directly on tv5) avoids the bug — confirm.
auto tv6 = tv5->rFactor({2});
auto tv7 = tv5->rFactor({2});
// Scope computations
tv6->computeAt(tv5, 2);
tv6->reorder({
{2, -2},
{3, -1},
{4, 2},
{5, 3},
{6, 4},
});
tv0->computeAt(tv6, 3);
tv1->computeAt(tv6, 3);
tv4->computeAt(tv6, -1);
// Cache smem tiles
tv2->setMemoryType(MemoryType::Shared);
tv3->setMemoryType(MemoryType::Shared);
tv4->setMemoryType(MemoryType::Local);
tv6->setMemoryType(MemoryType::Local);
tv7->setMemoryType(MemoryType::Local);
// Grid parallelization of the output's outer tile dims.
tv5->axis(0)->parallelize(ParallelType::BIDz);
tv5->axis(1)->parallelize(ParallelType::BIDy);
// All intermediates and the output share TIDz/TIDy on the innermost two axes.
std::vector<TensorView*> tv_list = {tv2, tv3, tv4, tv5, tv6, tv7};
for (auto tv : tv_list) {
tv->axis(-2)->parallelize(ParallelType::TIDz);
tv->axis(-1)->parallelize(ParallelType::TIDy);
}
// TIDx / BIDx land on different axis indices for tv7/tv5 because the
// rFactors removed reduction axes from their domains.
tv2->axis(3)->parallelize(ParallelType::TIDx);
tv3->axis(3)->parallelize(ParallelType::TIDx);
tv4->axis(3)->parallelize(ParallelType::TIDx);
tv6->axis(3)->parallelize(ParallelType::TIDx);
tv7->axis(2)->parallelize(ParallelType::TIDx);
tv2->axis(4)->parallelize(ParallelType::BIDx);
tv3->axis(4)->parallelize(ParallelType::BIDx);
tv4->axis(4)->parallelize(ParallelType::BIDx);
tv6->axis(4)->parallelize(ParallelType::BIDx);
tv7->axis(3)->parallelize(ParallelType::BIDx);
tv5->axis(2)->parallelize(ParallelType::BIDx);
// Dump the scheduled math and generated kernel for inspection.
fusion.printMath();
fusion.printKernel();
constexpr int M = 3, K = 6, N = 16;
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({M, K}, options);
at::Tensor t1 = at::randn({K, N}, options);
torch::jit::fuser::cuda::FusionExecutor fe;
fe.compileFusion(&fusion);
// Trailing scalars {2, 2, 3} bind the symbolic inputs in registration
// order: m_tile=2, split_k=2, block_k=3. All-(-1) LaunchParams lets the
// executor infer the launch configuration.
auto outputs = fe.runFusion(
{t0, t1, 2, 2, 3},
torch::jit::fuser::cuda::LaunchParams(-1, -1, -1, -1, -1, -1));
// ATen reference: broadcast-multiply then sum over K, matching the fusion.
at::Tensor aten_output = mul(t0.unsqueeze(2), t1.unsqueeze(0)).sum(1);
std::cout << aten_output << "\n" << outputs[0] << std::endl;
TORCH_CHECK(
aten_output.allclose(outputs[0], 1e-5, 1e-5),
"Error of: ",
aten_output.sub(outputs[0]).abs().max());
Removing the TV7 rFactor and instead parallelizing TV5 for block and grid reduction seems to work (along with other modifications from today's tip-of-tree (ToT)).