Skip to content

reduction with fp16 cast generates wrong kernel indexing  #362

@jjsjann123

Description

@jjsjann123

🐛 Bug

The issue was discovered while working on Kevin's issue #357 and PR #361. For the given model below:

  TensorView* tv0 = makeDummyTensor(3, DataType::Half);                              
  fusion.addInput(tv0);                                                              
                                                                                     
  auto tv1 = castOp(DataType::Float, tv0);                                           
  auto tv2 = add(tv1, new Float(1.0));                                               
  auto tv3 = sum(tv2, {2});                                                          
  auto tv4 = castOp(DataType::Half, tv3);                                            
                                                                                     
  fusion.addOutput(tv4);

  const auto options =                                                               
      at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);                     
  at::Tensor input = at::randn({8, 8, 16}, options);                                 
                                                                                     
  // Apply reduction heuristic                                                       
  const at::ArrayRef<c10::IValue> inputs({input});                                   
                                                                                     
  TORCH_CHECK(                                                                       
      cuda::scheduleReduction(&fusion, inputs, tv3),                                 
      "Reduction schedule was not generated!");                                      
  GpuLower gpulw(&fusion);                                                           
  std::stringstream kernel;                                                          
  gpulw.printKernel(kernel);

the code lowering generates a kernel with wrong indexing.
Here's the printed fusion after transformation:

T1[ iS45{( ceilDiv(( i1 * i3 ), 64) )}, iS46{64}, iS50{4}, iS49{( ceilDiv(( ceilDiv(i5, 4) ), 4) )}, iS48{4} ] compute_at( T2, 5 )
   = __half2float(T0[ iS0{i1}, iS1{i3}, iS2{i5} ]);
T2[ iS38{( ceilDiv(( i1 * i3 ), 64) )}, iS39{64}, iS43{4}, iS42{( ceilDiv(( ceilDiv(i5, 4) ), 4) )}, iS41{4} ] compute_at( T5, 5 )
   = T1[ iS45{( ceilDiv(( i1 * i3 ), 64) )}, iS46{64}, iS50{4}, iS49{( ceilDiv(( ceilDiv(i5, 4) ), 4) )}, iS48{4} ] compute_at( T2, 5 )
   + float(1);
T5[ iS25{( ceilDiv(( i1 * i3 ), 64) )}, iS26{64}, iS30{4}rf, rS29{( ceilDiv(( ceilDiv(i5, 4) ), 4) )}rf, rU28{4}rf ] compute_at( T3, 3 ) = reduction( T2[ iS38{( ceilDiv(( i1 * i3 ), 64) )}, iS39{64}, iS43{4}, iS42{( ceilDiv(( ceilDiv(i5, 4) ), 4) )}, iS41{4} ] compute_at( T5, 5 ), op = add, initial value = float(0) )
T3[ iblockIdx.x34{gridDim.x}, ithreadIdx.y35{64}, rthreadIdx.x36{4} ] = reduction( T5[ iS25{( ceilDiv(( i1 * i3 ), 64) )}, iS26{64}, iS30{4}rf, rS29{( ceilDiv(( ceilDiv(i5, 4) ), 4) )}rf, rU28{4}rf ] compute_at( T3, 3 ), op = add, initial value = float(0) )
T4[ iS12{i1}, iS13{i3} ]
   = __float2half(T3[ iblockIdx.x34{gridDim.x}, ithreadIdx.y35{64}, rthreadIdx.x36{4} ]);

Here's the lowered code:

__global__ void CUDAGeneratedKernel(Tensor<__half, 3> T0, Tensor<__half, 2> T4){
  alignas(4) extern __shared__ char array[];
  void* shared_mem = array;
  float T3[1];
  if ( ( ( ( blockIdx.x * 64 ) + threadIdx.y ) < ( T0.size[0] * T0.size[1] ) ) ) {
    T3[ 0 ]
       = float(0);
  }
  float T5[1];
  if ( ( ( ( blockIdx.x * 64 ) + threadIdx.y ) < ( T0.size[0] * T0.size[1] ) ) ) {
    T5[ 0 ]
       = float(0);
  }
  for(size_t i11 = 0; i11 < ( ceilDiv(( ceilDiv(T0.size[2], 4) ), 4) ); ++i11 ) {
    float T2[4];
    float T1[4];
    if ( ( ( ( ( ( blockIdx.x * 64 ) + threadIdx.y ) < ( T0.size[0] * T0.size[1] ) ) && ( ( ( ( ( i11 * 4 ) + threadIdx.x ) * 4 ) + ( 4 - 1 ) ) < T0.size[2] ) ) && ( ( ( ( blockIdx.x * 64 ) + threadIdx.y ) / T0.size[1] ) < T0.size[0] ) ) ) {
      for(size_t i13 = 0; i13 < 4; ++i13 ) {
        T1[ i13 ]
           = __half2float(T0[ ( ( ( ( blockIdx.x * 64 ) + threadIdx.y ) / T0.size[1] ) * T0.stride[0] ) + ( ( ( ( blockIdx.x * 64 ) + threadIdx.y ) % T0.size[1] ) * T0.stride[1] ) + ( ( ( ( ( i11 * 4 ) + threadIdx.x ) * 4 ) + i13 ) * T0.stride[2] ) ]);
        T2[ i13 ]
           = T1[ i13 ]
           + float(1);
        T5[ 0 ]
           = T5[ 0 ]
           + T2[ i13 ];
      }
    } else {
      for(size_t i13 = 0; i13 < 4; ++i13 ) {
        if ( ( ( ( ( ( ( blockIdx.x * 64 ) + threadIdx.y ) / T0.size[1] ) < T0.size[0] ) && ( ( ( ( blockIdx.x * 64 ) + threadIdx.y ) % T0.size[1] ) < T0.size[1] ) ) && ( ( ( ( ( i11 * 4 ) + threadIdx.x ) * 4 ) + i13 ) < T0.size[2] ) ) ) {
          T1[ i13 ]
             = __half2float(T0[ ( ( ( ( blockIdx.x * 64 ) + threadIdx.y ) / T0.size[1] ) * T0.stride[0] ) + ( ( ( ( blockIdx.x * 64 ) + threadIdx.y ) % T0.size[1] ) * T0.stride[1] ) + ( ( ( ( ( i11 * 4 ) + threadIdx.x ) * 4 ) + i13 ) * T0.stride[2] ) ]);
        }
        if ( ( ( ( ( blockIdx.x * 64 ) + threadIdx.y ) < ( T0.size[0] * T0.size[1] ) ) && ( ( ( ( ( i11 * 4 ) + threadIdx.x ) * 4 ) + i13 ) < T0.size[2] ) ) ) {
          T2[ i13 ]
             = T1[ i13 ]
             + float(1);
        }
        if ( ( ( ( ( blockIdx.x * 64 ) + threadIdx.y ) < ( T0.size[0] * T0.size[1] ) ) && ( ( ( ( ( i11 * 4 ) + threadIdx.x ) * 4 ) + i13 ) < T0.size[2] ) ) ) {
          T5[ 0 ]
             = T5[ 0 ]
             + T2[ i13 ];
        }
      }
    }
  }
  if ( ( ( ( blockIdx.x * 64 ) + threadIdx.y ) < ( T0.size[0] * T0.size[1] ) ) ) {
    blockReduce< true, false, false > ( T3[ 0 ], T5[ 0 ], reduction_add_float, threadIdx, blockDim, static_cast<float*>(shared_mem));
  }
  for(size_t i42 = 0; i42 < T0.size[0]; ++i42 ) {
    for(size_t i43 = 0; i43 < T0.size[1]; ++i43 ) {
      if ( ( threadIdx.x == 0 ) ) {
        T4[ ( i42 * T4.stride[0] ) + i43 ]
           = __float2half(T3[ ( i42 * T0.size[1] ) + i43 ]);
      }
    }
  }
}

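Indexing for T3 & T4 looks wrong: T3 is declared as `float T3[1]` but is read as `T3[ ( i42 * T0.size[1] ) + i43 ]`, and the T4 write runs inside serial loops over `T0.size[0]` and `T0.size[1]` instead of using the `( blockIdx.x * 64 ) + threadIdx.y` index that the rest of the kernel uses.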

To Reproduce

The repro has been pushed to csarofeen/reduction_half_repro.
You should be able to build it and run:

./test_jit --gtest_filter="*GPU_FusionReductionHalfRepro*"

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions