
Proper predication of grid-reduced tensors #219

@naoyam

Description


Currently, using the output tensors of grid reductions is not supported. Expressions can be created from such tensors, but correctness is not guaranteed. Because of this, for example, casting a reduction output is not supported.

Example:

```cpp
TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
TensorView* tv2 = unaryOp(UnaryOpType::Neg, tv1);
tv1->split(1, bdimx);
tv1->split(1, gdimx);

TensorView* tv1_rf = tv1->rFactor({1});

tv1->computeAt(tv2, -1);

tv1->axis(0)->parallelize(ParallelType::BIDy);
tv1_rf->axis(0)->parallelize(ParallelType::BIDy);
tv2->axis(0)->parallelize(ParallelType::BIDy);
tv1->axis(-2)->parallelize(ParallelType::BIDx);
tv1_rf->axis(-2)->parallelize(ParallelType::BIDx);
tv1->axis(-1)->parallelize(ParallelType::TIDx);
tv1_rf->axis(-1)->parallelize(ParallelType::TIDx);
```

See testGPU_FusionGridReduction7 in https://github.com/naoyam/pytorch/blob/fix-gridreduction-predicate/test/cpp/jit/test_gpu.cpp for the full test case. The test should pass only when gdimx is 1, since then there is a single thread block along the gridDim.x dimension; with anything larger it should almost always fail.

The generated code would look like:

```cuda
__global__ void kernel1(Tensor<float, 2> T0, Tensor<float, 2> T3, Tensor<float, 1> T2, Tensor<float, 1> T7, Tensor<int64_t, 1> T8){
  __shared__ float shared_mem[1024];
  float T4[1];
  if ( ( ( ( ( ( 0 * 1 ) + blockIdx.x ) * 128 ) + threadIdx.x ) < T3.size[1] ) ) {
    T4[ 0 ]
       = float(0);
  }
  for(size_t i29 = 0; i29 < ( ceilDiv(( ceilDiv(T3.size[1], 128) ), 1) ); ++i29 ) {
    if ( ( ( ( ( ( i29 * 1 ) + blockIdx.x ) * 128 ) + threadIdx.x ) < T3.size[1] ) ) {
      T4[ 0 ]
         = T4[ 0 ]
         + T0[ ( blockIdx.y * T0.stride[0] ) + ( ( ( ( ( i29 * 1 ) + blockIdx.x ) * 128 ) + threadIdx.x ) * T0.stride[1] ) ];
    }
  }
  float T1[1];
  T1[ 0 ]
     = float(0);
  float block_result;
  blockReduce< true, false, false > ( block_result, T4[ 0 ], reduction_add_float, threadIdx, blockDim, reinterpret_cast<float*>(shared_mem));
  // Allocate global tensor float T7[( T0.size[0] * 1 )];
  // Allocate global tensor int64_t T8[T0.size[0]];
  reduction::gridReduce< true, false, false, false, true, true > ( T1[ 0 ], block_result, reduction_add_float, &T7[0], T8, reinterpret_cast<float*>(shared_mem));
  if ( ( ( blockIdx.x == 0 ) && ( threadIdx.x == 0 ) ) ) {
    T2[ ( blockIdx.y * T2.stride[0] ) ]
       = -T1[ 0 ];
  }
}
```

Notice that the Neg unary op is predicated with blockIdx.x == 0 (and threadIdx.x == 0). That block is not necessarily the one holding the valid output of gridReduce, so the Neg operation may consume invalid values.
