Skip to content

Issue with block reduction then grid reduction #367

@csarofeen

Description

@csarofeen

When a block reduction is split off from a grid reduction, the work-buffer computation is wrong: accesses go out of bounds of the allocated memory. In the case below, this corrupts the sync buffer, so no output is ever written.
Repro:

  // Build a fusion computing tv5[M, N] = sum over K of tv0[M, K] * tv1[K, N],
  // i.e. a matrix multiply expressed as broadcast + pointwise mul + reduction.
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Symbolic integers we will use for runtime tiling
  // NOTE(review): raw `new`; presumably the active Fusion takes ownership of
  // IR nodes created under the FusionGuard — confirm.
  Int* symbolic_m_tile_dim = new Int();
  Int* symbolic_split_k_tile_dim = new Int();
  Int* symbolic_block_k_tile_dim = new Int();
  // Compile-time integer for tiling
  int n_smem_tile = 32;

  // Symbolic 2D tensors TV0[M, K], TV1[K, N]
  TensorView* tv0 = makeDummyTensor(2);
  TensorView* tv1 = makeDummyTensor(2);

  // Broadcast tv0 to [M, K, *]
  TensorView* tv2 = broadcast(tv0, {false, false, true});
  // Broadcast tv1 to [*, K, N]
  TensorView* tv3 = broadcast(tv1, {true, false, false});

  // Pointwise multiplication resulting in tv4[M, K, N]
  TensorView* tv4 = mul(tv2, tv3);

  // Sum the K-dim
  TensorView* tv5 = sum(tv4, {1});

  // Register inputs and outputs
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  fusion.addOutput(tv5);

  // Register runtime tile dims as inputs.
  // Presumably these are bound positionally to runtime arguments in addInput
  // order (after tv0, tv1): m_tile, split_k, block_k.
  fusion.addInput(symbolic_m_tile_dim);
  fusion.addInput(symbolic_split_k_tile_dim);
  fusion.addInput(symbolic_block_k_tile_dim);

  // Make a 3D tile, mix of symbolic and constant, do in reverse order because
  // dims are inserted
  tv5->split(2, n_smem_tile);
  tv5->split(1, symbolic_block_k_tile_dim);
  tv5->split(1, symbolic_split_k_tile_dim);
  tv5->split(0, symbolic_m_tile_dim);

  // tv5[M/m_tile, m_tile, r{K/split_k/block_k}, r{split_k}, r{block_k}, N/32,
  // 32]
  tv5->reorder({{1, 5}, {5, 1}});
  // tv5[M/m_tile, N/32, r{K/split_k/block_k}, r{split_k}, r{block_k},  m_tile,
  // 32]

  // Two successive rFactors of axis 2 of tv5: tv6 is split off first, then
  // tv7 from what remains. Per the issue text, dropping the tv7 rFactor (and
  // doing the block + grid reduction directly on tv5) avoids the failure.
  // NOTE(review): exactly which reduction axes tv6 / tv7 / tv5 end up owning
  // depends on rFactor's axis bookkeeping, which is not visible here.
  auto tv6 = tv5->rFactor({2});
  auto tv7 = tv5->rFactor({2});

  // Scope computations
  tv6->computeAt(tv5, 2);

  // Old-axis -> new-axis map: moves axes 2 and 3 to the last two positions
  // (-2, -1) and shifts axes 4, 5, 6 down to positions 2, 3, 4.
  // Assumes tv6 has 7 axes at this point — TODO confirm.
  tv6->reorder({
      {2, -2},
      {3, -1},
      {4, 2},
      {5, 3},
      {6, 4},
  });

  // Place the producers inside tv6's loop nest: tv0/tv1 at position 3,
  // tv4 at position -1.
  tv0->computeAt(tv6, 3);
  tv1->computeAt(tv6, 3);
  tv4->computeAt(tv6, -1);

  // Cache smem tiles
  tv2->setMemoryType(MemoryType::Shared);
  tv3->setMemoryType(MemoryType::Shared);
  tv4->setMemoryType(MemoryType::Local);
  tv6->setMemoryType(MemoryType::Local);
  tv7->setMemoryType(MemoryType::Local);

  // Output tile axes map to blocks: axis 0 -> BIDz, axis 1 -> BIDy.
  tv5->axis(0)->parallelize(ParallelType::BIDz);
  tv5->axis(1)->parallelize(ParallelType::BIDy);

  // The two innermost axes of each tensor in tv_list map to TIDz / TIDy.
  std::vector<TensorView*> tv_list = {tv2, tv3, tv4, tv5, tv6, tv7};
  for (auto tv : tv_list) {
    tv->axis(-2)->parallelize(ParallelType::TIDz);
    tv->axis(-1)->parallelize(ParallelType::TIDy);
  }
  // TIDx binding; note tv7 uses axis 2 while the other tensors use axis 3.
  tv2->axis(3)->parallelize(ParallelType::TIDx);
  tv3->axis(3)->parallelize(ParallelType::TIDx);
  tv4->axis(3)->parallelize(ParallelType::TIDx);
  tv6->axis(3)->parallelize(ParallelType::TIDx);
  tv7->axis(2)->parallelize(ParallelType::TIDx);

  // BIDx binding. Per the issue, the tv7 -> tv5 hand-off configured here is
  // the block-reduction-then-grid-reduction case whose work-buffer
  // computation goes wrong.
  tv2->axis(4)->parallelize(ParallelType::BIDx);
  tv3->axis(4)->parallelize(ParallelType::BIDx);
  tv4->axis(4)->parallelize(ParallelType::BIDx);
  tv6->axis(4)->parallelize(ParallelType::BIDx);
  tv7->axis(3)->parallelize(ParallelType::BIDx);
  tv5->axis(2)->parallelize(ParallelType::BIDx);

  // Debug output; presumably dumps the fusion IR and the generated kernel.
  fusion.printMath();
  fusion.printKernel();

  // Problem sizes. Assuming the scalar args below bind as m_tile=2,
  // split_k=2, block_k=3, then K = 6 = 1 * 2 * 3, so the outer
  // K/split_k/block_k axis has extent 1. Neither M = 3 (vs m_tile = 2) nor
  // N = 16 (vs n_smem_tile = 32) divides evenly.
  constexpr int M = 3, K = 6, N = 16;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({M, K}, options);
  at::Tensor t1 = at::randn({K, N}, options);

  torch::jit::fuser::cuda::FusionExecutor fe;
  fe.compileFusion(&fusion);
  // Inputs in addInput order: tv0, tv1, then the three symbolic tile dims.
  // NOTE(review): LaunchParams(-1, ...) presumably means "infer every launch
  // dimension from the fusion" — confirm.
  auto outputs = fe.runFusion(
      {t0, t1, 2, 2, 3},
      torch::jit::fuser::cuda::LaunchParams(-1, -1, -1, -1, -1, -1));

  // Reference: [M, K, 1] * [1, K, N] summed over K, i.e. t0 @ t1.
  at::Tensor aten_output = mul(t0.unsqueeze(2), t1.unsqueeze(0)).sum(1);
  std::cout << aten_output << "\n" << outputs[0] << std::endl;
  TORCH_CHECK(
      aten_output.allclose(outputs[0], 1e-5, 1e-5),
      "Error of: ",
      aten_output.sub(outputs[0]).abs().max());

Removing the TV7 rfactor and doing both the block and grid reduction directly on TV5 appears to work (together with other modifications on today's top of tree (ToT)).

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions