-
Notifications
You must be signed in to change notification settings - Fork 68
[Bug] FP64 reduce #91
Copy link
Copy link
Closed
Description
See the snippet below:
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <hidet/runtime/cuda_context.h>
#include <hidet/runtime/cpu_context.h>
typedef float tfloat32_t;
#define __float_to_tf32(x) (x)
/*
Task(
name: reduce_min
parameters:
x: tensor(float64, [1])
y: tensor(float64, [])
inputs: [x]
outputs: [y]
computations:
x: tensor(float64, [1])
y: float64[] where y[] = reduce([1], (v) => x[v], minreduce)
attributes: {dims: [0], keep_dim: 0, reduce_type: min, accumulate_dtype: float32}
)
*/
extern "C" {
__global__ void __launch_bounds__(32) hidet_reduce_min_grid(double * __restrict__ x, double * __restrict__ y) {
// label: reduce schedule
float rv = 3.4028234663852886e+38f;
if ((int)threadIdx.x < 1) {
rv = ((float)(fmin(rv, x[(int)threadIdx.x])));
}
int32_t mask = __activemask();
rv = fminf(rv, __shfl_down_sync(mask, rv, 16, 32));
rv = fminf(rv, __shfl_down_sync(mask, rv, 8, 32));
rv = fminf(rv, __shfl_down_sync(mask, rv, 4, 32));
rv = fminf(rv, __shfl_down_sync(mask, rv, 2, 32));
rv = fminf(rv, __shfl_down_sync(mask, rv, 1, 32));
rv = __shfl_sync(mask, rv, 0, 32);
rv = rv;
if ((int)threadIdx.x < 1) {
if ((int)threadIdx.x == 0) {
y[0] = ((double)(rv));
}
}
}
__host__ void hidet_reduce_min(int32_t num_args, int32_t * __restrict__ arg_types, void* * __restrict__ args) {
assert(((void)"Expect 2 arguments", (num_args == 2)));
assert(((void)"The 0-th argument should be tensor(float64, [1])", (arg_types[0] == 3)));
assert(((void)"The 1-th argument should be tensor(float64, [])", (arg_types[1] == 3)));
hidet_reduce_min_grid<<<dim3(1, 1, 1), dim3(32, 1, 1), 0, (cudaStream_t)get_cuda_stream()>>>(((double*)(args[0])), ((double*)(args[1])));
}
}
Several problems here:
- Why use fp32 accumulation in a fp64 operator? I guess we should not use fp32 for the default value of accumulation type (what if the dtype is interger?)
- fp32 accumulator further leads to the call to
fmin(float, double), which only exists as a host function, causing a compile error.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels