Skip to content

[Bug] FP64 reduce #91

@soodoshll

Description

@soodoshll

See the snippet below:

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <hidet/runtime/cuda_context.h>
#include <hidet/runtime/cpu_context.h>
typedef float tfloat32_t;
#define __float_to_tf32(x) (x)
/*
Task(
  name: reduce_min
  parameters:
    x: tensor(float64, [1])
    y: tensor(float64, [])
  inputs: [x]
  outputs: [y]
  computations:
    x: tensor(float64, [1])
    y: float64[] where y[] = reduce([1], (v) => x[v], minreduce)
  attributes: {dims: [0], keep_dim: 0, reduce_type: min, accumulate_dtype: float32}
)
*/
extern "C" {

// Min-reduction kernel for a one-element float64 tensor: y[] = min over x[0..0].
//
// Launch layout: declared with __launch_bounds__(32) and written as a single
// warp-wide reduction (shuffle width 32); there is no cross-warp or
// cross-block combination step in this kernel.
//
// Parameters:
//   x - input buffer; only x[0] is read (by thread 0).
//   y - output buffer; only y[0] is written (by thread 0).
//
// NOTE(review): x and y are double, but the accumulator `rv` is float (the Task
// comment in this file lists accumulate_dtype: float32). The input is narrowed
// to float before reduction and widened back on store, so the result carries
// only float precision, and any input larger than the float initial value is
// replaced by that initial value.
__global__ void __launch_bounds__(32) hidet_reduce_min_grid(double * __restrict__ x, double * __restrict__ y) {
  // label: reduce schedule
  // Identity element for min: the largest finite float value.
  float rv = 3.4028234663852886e+38f;
  // Only thread 0 loads data; every other lane keeps the identity value.
  if ((int)threadIdx.x < 1) {
    // NOTE(review): fmin is called here with a float first argument and a double
    // second argument; the accompanying report states this mixed-type call
    // fails to compile in device code. The result is then narrowed to float.
    rv = ((float)(fmin(rv, x[(int)threadIdx.x])));
  }
  // NOTE(review): the mask is taken from __activemask() right after a divergent
  // branch, and is stored in a signed int32_t although the intrinsic returns an
  // unsigned value. The shuffles below assume all 32 lanes participate — confirm
  // that a fixed full-warp mask is not what is intended here.
  int32_t mask = __activemask();
  // Shuffle-down reduction with offsets 16, 8, 4, 2, 1: each step takes the min
  // of the lane's own value and the value held by the lane `offset` above it.
  rv = fminf(rv, __shfl_down_sync(mask, rv, 16, 32));
  rv = fminf(rv, __shfl_down_sync(mask, rv, 8, 32));
  rv = fminf(rv, __shfl_down_sync(mask, rv, 4, 32));
  rv = fminf(rv, __shfl_down_sync(mask, rv, 2, 32));
  rv = fminf(rv, __shfl_down_sync(mask, rv, 1, 32));
  // Broadcast lane 0's value to every lane.
  rv = __shfl_sync(mask, rv, 0, 32);
  // Self-assignment; has no effect.
  rv = rv;
  // The two nested conditions both select thread 0 only; the outer check is
  // redundant with the inner one.
  if ((int)threadIdx.x < 1) {
    if ((int)threadIdx.x == 0) {
      // Widen the float accumulator back to double for the output.
      y[0] = ((double)(rv));
    }
  }
}

// Host entry point that launches hidet_reduce_min_grid on type-erased arguments.
//
// Parameters:
//   num_args  - number of entries in arg_types / args; asserted to be 2.
//   arg_types - per-argument type codes; both entries are asserted to equal 3
//               (presumably the code for a tensor/pointer argument, going by
//               the assertion messages — confirm against the Hidet runtime).
//   args      - argument pointers; args[0] and args[1] are cast to double* and
//               passed to the kernel as x and y respectively.
//
// NOTE(review): validation relies solely on assert, which is compiled out when
// NDEBUG is defined. No header providing assert is included directly in this
// snippet; it presumably arrives through the hidet runtime headers — verify.
__host__ void hidet_reduce_min(int32_t num_args, int32_t * __restrict__ arg_types, void* * __restrict__ args) {
  assert(((void)"Expect 2 arguments", (num_args == 2)));
  assert(((void)"The 0-th argument should be tensor(float64, [1])", (arg_types[0] == 3)));
  assert(((void)"The 1-th argument should be tensor(float64, [])", (arg_types[1] == 3)));
  // One block of 32 threads, no dynamic shared memory, on the stream returned by
  // get_cuda_stream().
  // NOTE(review): the launch result is not checked (no cudaGetLastError() after
  // the launch), so a launch failure would go unreported here.
  hidet_reduce_min_grid<<<dim3(1, 1, 1), dim3(32, 1, 1), 0, (cudaStream_t)get_cuda_stream()>>>(((double*)(args[0])), ((double*)(args[1])));
}

}

Several problems here:

  • Why use fp32 accumulation in an fp64 operator? I guess we should not use fp32 as the default accumulation type (what if the dtype is an integer?)
  • The fp32 accumulator further leads to a call to fmin(float, double), which only exists as a host function, causing a compile error.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions