
codegen gives wrong code for branching #278

@jjsjann123

Description


🐛 Bug

I ran into an issue where a fusion has two branches that join later (tv3 = tv0 + 1 feeds both tv4 and tv5, which are then added to form tv6 in the example below).
When I naively call computeAt from all inputs to the output, the generated code gets the dependency order wrong and reads an intermediate tensor before it has been written.

Here's the example:

void testGPU_FusionBranches() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up your input tensor views
  TensorView* tv0 = makeDummyTensor(2);
  TensorView* tv1 = makeDummyTensor(2);
  TensorView* tv2 = makeDummyTensor(2);
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  fusion.addInput(tv2);

  auto tv3 = add(tv0, new Float(1.0));
  auto tv4 = add(tv3, tv1);
  auto tv5 = add(tv3, tv2);
  auto tv6 = add(tv4, tv5);

  fusion.addOutput(tv6);

  constexpr int x = 63, y = 33;

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

  at::Tensor t0 = at::randn({x, y}, options);
  at::Tensor t1 = at::randn({x, y}, options);
  at::Tensor t2 = at::randn({x, y}, options);

  torch::jit::fuser::cuda::FusionExecutor fe;
  //fuser::cuda::scheduleFusion(&fusion, {t0, t1, t2});
  tv6->merge(0);
  tv6->split(0, 128);
  tv6->split(0, 4);

  tv6->axis(0)->parallelize(ParallelType::BIDx);

  tv0->computeAt(tv6, 1);
  tv1->computeAt(tv6, 1);
  tv2->computeAt(tv6, 1);

  tv3->axis(-2)->parallelize(ParallelType::Unroll);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);
  tv4->axis(-2)->parallelize(ParallelType::Unroll);
  tv4->axis(-1)->parallelize(ParallelType::TIDx);
  tv5->axis(-2)->parallelize(ParallelType::Unroll);
  tv5->axis(-1)->parallelize(ParallelType::TIDx);

  fe.compileFusion(&fusion);
  auto outputs = fe.runFusion({t0, t1, t2});

  auto t3 = t0.add(1.0);
  auto t4 = t3.add(t1);
  auto t5 = t3.add(t2);
  auto t6 = t4.add(t5);

  TORCH_CHECK(t6.allclose(outputs[0]));
}

Generated code (T3 is used before it is declared and filled):

__global__ void kernel1(Tensor<float, 2> T0, Tensor<float, 2> T1, Tensor<float, 2> T2, Tensor<float, 2> T6){
  float T4[4];
  if ( ( ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + ( 4 - 1 ) ) / T6.size[1] ) < T6.size[0] ) && ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + ( 4 - 1 ) ) % T6.size[1] ) < T6.size[1] ) ) ) {
    for(size_t i47 = 0; i47 < 4; ++i47 ) {
      T4[ i47 ]
         = T3[ i47 ]
         + T1[ ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i47 ) / T6.size[1] ) * T1.stride[0] ) + ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i47 ) % T6.size[1] ) * T1.stride[1] ) ];
    }
  } else {
    for(size_t i47 = 0; i47 < 4; ++i47 ) {
      if ( ( ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i47 ) / T6.size[1] ) < T6.size[0] ) && ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i47 ) % T6.size[1] ) < T6.size[1] ) ) ) {
        T4[ i47 ]
           = T3[ i47 ]
           + T1[ ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i47 ) / T6.size[1] ) * T1.stride[0] ) + ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i47 ) % T6.size[1] ) * T1.stride[1] ) ];
      }
    } 
  } 
  float T3[4];
  if ( ( ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + ( 4 - 1 ) ) / T6.size[1] ) < T6.size[0] ) && ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + ( 4 - 1 ) ) % T6.size[1] ) < T6.size[1] ) ) ) {
    for(size_t i48 = 0; i48 < 4; ++i48 ) {
      T3[ i48 ]
         = T0[ ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i48 ) / T6.size[1] ) * T0.stride[0] ) + ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i48 ) % T6.size[1] ) * T0.stride[1] ) ]
         + float(1);
    }
  } else {
    for(size_t i48 = 0; i48 < 4; ++i48 ) {
      if ( ( ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i48 ) / T6.size[1] ) < T6.size[0] ) && ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i48 ) % T6.size[1] ) < T6.size[1] ) ) ) {
        T3[ i48 ]
           = T0[ ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i48 ) / T6.size[1] ) * T0.stride[0] ) + ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i48 ) % T6.size[1] ) * T0.stride[1] ) ]
           + float(1);
      }
    }
  }
  float T5[4];
  if ( ( ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + ( 4 - 1 ) ) / T6.size[1] ) < T6.size[0] ) && ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + ( 4 - 1 ) ) % T6.size[1] ) < T6.size[1] ) ) ) {
    for(size_t i49 = 0; i49 < 4; ++i49 ) {
      T5[ i49 ]
         = T3[ i49 ]
         + T2[ ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i49 ) / T6.size[1] ) * T2.stride[0] ) + ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i49 ) % T6.size[1] ) * T2.stride[1] ) ];
    }
  } else {
    for(size_t i49 = 0; i49 < 4; ++i49 ) {
      if ( ( ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i49 ) / T6.size[1] ) < T6.size[0] ) && ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i49 ) % T6.size[1] ) < T6.size[1] ) ) ) {
        T5[ i49 ]
           = T3[ i49 ]
           + T2[ ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i49 ) / T6.size[1] ) * T2.stride[0] ) + ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i49 ) % T6.size[1] ) * T2.stride[1] ) ];
      }
    }
  }
  for(size_t i50 = 0; i50 < 128; ++i50 ) {
    for(size_t i51 = 0; i51 < 4; ++i51 ) {
      if ( ( ( ( ( ( blockIdx.x * 128 ) + i50 ) * 4 ) + i51 ) < ( T6.size[0] * T6.size[1] ) ) ) {
        T6[ ( ( ( ( blockIdx.x * 128 ) + i50 ) * 4 ) + i51 ) ]
           = T4[ ( i50 * 4 ) + i51 ]
           + T5[ ( i50 * 4 ) + i51 ];
      }
    }
  }
}
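
Note that the T4 loop at the top of the kernel reads T3[ i47 ], but float T3[4] is only declared and populated further down. For comparison, the fusion's data dependencies require T3 to be materialized before either of its consumers; a hand-written kernel with the correct ordering would look roughly like the sketch below (illustration only, not what the scheduler should emit: indexing is simplified to one element per thread, inputs are assumed contiguous, and the nvFuser Tensor<> wrappers, unrolling, and predication are omitted):

__global__ void kernel1_expected_order(
    const float* T0, const float* T1, const float* T2, float* T6, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    float T3 = T0[i] + 1.0f; // tv3 = tv0 + 1, produced before its consumers
    float T4 = T3 + T1[i];   // tv4 = tv3 + tv1
    float T5 = T3 + T2[i];   // tv5 = tv3 + tv2
    T6[i] = T4 + T5;         // tv6 = tv4 + tv5
  }
}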

To Reproduce

I have pushed the repro here:
https://github.com/csarofeen/pytorch/tree/failed_branching_in_scheduling

You should just build it and run the cpp test with:
CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_FUSER_DEBUG=1 ./test_jit --gtest_filter="*GPU_FusionBranches*"
