pkd3 vs pkd3_pln3

Created Diff never expires
51 removals
188 lines
76 additions
204 lines
// JPEG compression distortion, packed (pkd3) source -> packed (pkd3) destination.
//
// Each thread handles 8 consecutive pixels of one row (id_x is scaled by 8) and the
// block cooperates through a 48 x 128 float shared-memory tile laid out as three
// planes of 16 rows (rows y, y + 16 and y + 32 for the thread's row y).
// NOTE(review): the tile indexing implies hipThreadIdx_x < 16 and hipThreadIdx_y < 16,
// i.e. presumably a 16 x 16 thread block with one batch image per grid z - confirm at
// the launch site.
//
// Parameters:
//   srcPtr          - source tensor; read at (id_z * srcStridesNH.x) + (row * srcStridesNH.y) + (col * 3).
//   srcStridesNH    - .x is multiplied by the batch index, .y by the row index.
//   dstPtr          - destination tensor, packed; written at (id_z * dstStridesNH.x) + (id_y * dstStridesNH.y) + (id_x * 3).
//   dstStridesNH    - .x is multiplied by the batch index, .y by the row index.
//   roiTensorPtrSrc - per-image ROI; xywhROI.xy and roiWidth / roiHeight are read for image id_z.
//   tableY          - luma quantization table; row (hipThreadIdx_y % 8) * 8 is handed to quantize().
//   tableCbCr       - chroma quantization table; indexed the same way.
//   qScale          - not referenced inside this kernel body.
//
// Threads outside the 16-aligned ROI return before the first barrier, exactly as in the
// original code.
template <typename T>
__global__ void jpeg_compression_distortion_pkd3_hip_tensor(T *srcPtr,
                                                            uint2 srcStridesNH,
                                                            T *dstPtr,
                                                            uint2 dstStridesNH,
                                                            RpptROIPtr roiTensorPtrSrc,
                                                            int *tableY,
                                                            int *tableCbCr,
                                                            float qScale)
{
    int id_x = (hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x) * 8;
    int id_y = hipBlockIdx_y * hipBlockDim_y + hipThreadIdx_y;
    int id_z = hipBlockIdx_z * hipBlockDim_z + hipThreadIdx_z;

    // Column offsets of this thread inside the shared tile: 8 floats per thread for
    // full-resolution data, 4 floats per thread for horizontally downsampled chroma.
    int hipThreadIdx_x8 = hipThreadIdx_x * 8;
    int hipThreadIdx_x4 = hipThreadIdx_x * 4;

    // ROI extent rounded up to the next multiple of 16
    int alignedWidth = (roiTensorPtrSrc[id_z].xywhROI.roiWidth + 15) & ~15;
    int alignedHeight = (roiTensorPtrSrc[id_z].xywhROI.roiHeight + 15) & ~15;

    // Boundary checks
    if((id_y >= alignedHeight) || (id_x >= alignedWidth))
        return;

    // ROI parameters
    int roiX = roiTensorPtrSrc[id_z].xywhROI.xy.x;
    int roiY = roiTensorPtrSrc[id_z].xywhROI.xy.y;
    int roiWidth = roiTensorPtrSrc[id_z].xywhROI.roiWidth;
    int roiHeight = roiTensorPtrSrc[id_z].xywhROI.roiHeight;

    // Shared tile: rows 0..15, 16..31 and 32..47 are the three channel planes
    __shared__ float src_smem[48][128];
    int3 hipThreadIdx_y_channel = {hipThreadIdx_y, hipThreadIdx_y + 16, hipThreadIdx_y + 32};

    float *src_smem_channel[3] = {
        &src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8],
        &src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8],
        &src_smem[hipThreadIdx_y_channel.z][hipThreadIdx_x8]
    };

    // ----------- Step 1: Load from Global Memory to Shared Memory -----------
    int srcIdx;
    int dstIdx = (id_z * dstStridesNH.x) + (id_y * dstStridesNH.y) + id_x * 3;

    // Check if we need special handling for image edges
    if(id_y < roiHeight)
        srcIdx = (id_z * srcStridesNH.x) + ((id_y + roiY) * srcStridesNH.y) + ((id_x + roiX) * 3);
    else // All out-of-bounds threads use the last valid row
        srcIdx = (id_z * srcStridesNH.x) + ((roiHeight - 1 + roiY) * srcStridesNH.y) + ((id_x + roiX) * 3);

    // True when this thread's 8-pixel span crosses the right edge of the ROI
    bool isEdge = ((id_x + 8) > roiWidth) && (id_x < alignedWidth);

    if (!isEdge)
    {
        rpp_hip_load24_pkd3_to_float24_pln3(srcPtr + srcIdx, src_smem_channel);
    }
    else
    {
        // Partial block load with edge pixel replication
        int validPixels = roiWidth - id_x;
        if (validPixels > 0)
        {
            for (int i = 0, idx = srcIdx; i < validPixels; i++, idx += 3)
            {
                src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8 + i] = srcPtr[idx];
                src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8 + i] = srcPtr[idx + 1];
                src_smem[hipThreadIdx_y_channel.z][hipThreadIdx_x8 + i] = srcPtr[idx + 2];
            }
        }

        // Pad by duplicating the last valid pixel of the row. lastValidIdx resolves to
        // column (roiX + roiWidth - 1) of the row regardless of the sign of validPixels;
        // when validPixels is negative the loop starts at a negative i and rewrites
        // columns belonging to the thread on the left with the same replicated pixel.
        int lastValidIdx = srcIdx + ((validPixels - 1) * 3);
        for (int i = validPixels; i < min(validPixels + 16, 8); i++)
        {
            src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8 + i] = srcPtr[lastValidIdx];
            src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8 + i] = srcPtr[lastValidIdx + 1];
            src_smem[hipThreadIdx_y_channel.z][hipThreadIdx_x8 + i] = srcPtr[lastValidIdx + 2];
        }
    }
    __syncthreads();

    // ----------- Step 2: RGB to YCbCr Conversion -----------
    d_float8 y_f8;
    d_float24 rgb_f24;
    rgb_f24.f8[0] = *((d_float8*)&src_smem[hipThreadIdx_y][hipThreadIdx_x8]);
    rgb_f24.f8[1] = *((d_float8*)&src_smem[hipThreadIdx_y + 16][hipThreadIdx_x8]);
    rgb_f24.f8[2] = *((d_float8*)&src_smem[hipThreadIdx_y + 32][hipThreadIdx_x8]);

    int cbcrY = hipThreadIdx_y * 2;
    y_hip_compute(srcPtr, rgb_f24, &y_f8);
    __syncthreads();

    // ----------- Step 3: Downsample CbCr -----------
    // The first 8 thread rows each consume two tile rows per channel. The results are
    // kept in registers until every thread has passed the barrier below: the Cb/Cr and
    // Y stores that follow overwrite tile rows that other threads are still reading as
    // RGB input. (Assumes downsample_cbcr_hip_compute only reads through its six row
    // pointers - confirm against its definition.)
    float4 cbDown_f4, crDown_f4;
    if (hipThreadIdx_y < 8)
    {
        downsample_cbcr_hip_compute(
            (d_float8*)&src_smem[cbcrY][hipThreadIdx_x8],
            (d_float8*)&src_smem[cbcrY + 1][hipThreadIdx_x8],
            (d_float8*)&src_smem[cbcrY + 16][hipThreadIdx_x8],
            (d_float8*)&src_smem[cbcrY + 17][hipThreadIdx_x8],
            (d_float8*)&src_smem[cbcrY + 32][hipThreadIdx_x8],
            (d_float8*)&src_smem[cbcrY + 33][hipThreadIdx_x8],
            &cbDown_f4, &crDown_f4);
    }
    __syncthreads();

    if (hipThreadIdx_y < 8)
    {
        // Cb goes to rows 16..23, Cr to rows 24..31 (each 8 x 64)
        *(float4*)&src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x4] = cbDown_f4;
        *(float4*)&src_smem[8 + hipThreadIdx_y_channel.y][hipThreadIdx_x4] = crDown_f4;
    }

    // Y replaces the first plane (rows 0..15)
    *(d_float8*)&src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8] = y_f8;
    __syncthreads();

    // ----------- Step 4: Clamp + Forward DCT -----------
    clamp_range((float*)&src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8]);
    clamp_range((float*)&src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8]);
    __syncthreads();

    // Row-wise forward DCT on the Y plane and the Cb/Cr plane
    dct_fwd_8x8_1d(&src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8], true);
    dct_fwd_8x8_1d(&src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8], true);
    __syncthreads();

    // ----------- Step5 Column-wise DCT -----------
    // Threads are remapped so that each one owns a whole tile column (rows 0..31)
    int col = (hipThreadIdx_x * 16) + hipThreadIdx_y;
    if ((col < 128) && (col < alignedWidth))
    {
        float colVec[32];
        for(int i = 0; i < 32; i++) colVec[i] = src_smem[i][col];

        dct_fwd_8x8_1d(&colVec[0], false);
        dct_fwd_8x8_1d(&colVec[8], false);
        dct_fwd_8x8_1d(&colVec[16], false);
        dct_fwd_8x8_1d(&colVec[24], false);

        for(int i = 0; i < 32; i++) src_smem[i][col] = colVec[i];
    }
    __syncthreads();

    // ----------- Step 6: Quantization -----------
    quantize(&src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8], &tableY[(hipThreadIdx_y % 8) * 8]);
    quantize(&src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8], &tableCbCr[(hipThreadIdx_y % 8) * 8]);
    __syncthreads();

    // ----------- Step 7: Inverse DCT -----------
    // Column-wise inverse DCT, same column ownership as Step 5
    if((col < 128) && (col < alignedWidth))
    {
        float colVec[32];
        for(int i = 0; i < 32; i++) colVec[i] = src_smem[i][col];

        dct_inv_8x8_1d(&colVec[0], false);
        dct_inv_8x8_1d(&colVec[8], false);
        dct_inv_8x8_1d(&colVec[16], false);
        dct_inv_8x8_1d(&colVec[24], false);

        for(int i = 0; i < 32; i++) src_smem[i][col] = colVec[i];
    }
    __syncthreads();

    // Row-wise inverse DCT on the Y plane and the Cb/Cr plane
    dct_inv_8x8_1d(&src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8], true);
    dct_inv_8x8_1d(&src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8], true);
    __syncthreads();

    // ----------- Step 8: Clamp & Upsample -----------
    clamp_range((float*)&src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8]);
    clamp_range((float*)&src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8]);
    __syncthreads();

    // Vertical Upsampling: two consecutive thread rows share one chroma row
    float4 cb_f4, cr_f4;
    cbcrY = hipThreadIdx_y / 2;
    cb_f4 = *(float4*)&src_smem[cbcrY + 16][hipThreadIdx_x4];
    cr_f4 = *(float4*)&src_smem[cbcrY + 24][hipThreadIdx_x4];
    __syncthreads();

    // Convert back to RGB
    upsample_and_RGB_hip_compute(cb_f4, cr_f4, (d_float8*)&src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8], (d_float8*)&src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8], (d_float8*)&src_smem[hipThreadIdx_y_channel.z][hipThreadIdx_x8]);
    __syncthreads();

    // ----------- Step 9: Final Clamp & Store -----------
    rpp_hip_adjust_range(dstPtr, (d_float8*)&src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8]);
    rpp_hip_adjust_range(dstPtr, (d_float8*)&src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8]);
    rpp_hip_adjust_range(dstPtr, (d_float8*)&src_smem[hipThreadIdx_y_channel.z][hipThreadIdx_x8]);

    clamp_range(srcPtr, (float*)&src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8]);
    clamp_range(srcPtr, (float*)&src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8]);
    clamp_range(srcPtr, (float*)&src_smem[hipThreadIdx_y_channel.z][hipThreadIdx_x8]);
    __syncthreads();

    // Only threads whose first pixel lies inside the ROI write to the destination
    if((id_x < roiWidth) && (id_y < roiHeight))
        rpp_hip_pack_float24_pln3_and_store24_pkd3(dstPtr + dstIdx, src_smem_channel);
}

// JPEG compression distortion, packed (pkd3) source -> planar (pln3) destination.
//
// Same processing steps and shared-memory layout as
// jpeg_compression_distortion_pkd3_hip_tensor; only the destination addressing and the
// final store differ (three per-channel 8-pixel stores instead of one packed store).
// NOTE(review): the tile indexing implies hipThreadIdx_x < 16 and hipThreadIdx_y < 16,
// i.e. presumably a 16 x 16 thread block with one batch image per grid z - confirm at
// the launch site.
//
// Parameters:
//   srcPtr          - source tensor; read at (id_z * srcStridesNH.x) + (row * srcStridesNH.y) + (col * 3).
//   srcStridesNH    - .x is multiplied by the batch index, .y by the row index.
//   dstPtr          - destination tensor, planar.
//   dstStridesNCH   - .x is multiplied by the batch index, .y is the offset between
//                     consecutive channel planes, .z is multiplied by the row index.
//   roiTensorPtrSrc - per-image ROI; xywhROI.xy and roiWidth / roiHeight are read for image id_z.
//   tableY          - luma quantization table; row (hipThreadIdx_y % 8) * 8 is handed to quantize().
//   tableCbCr       - chroma quantization table; indexed the same way.
//   qScale          - not referenced inside this kernel body.
template <typename T>
__global__ void jpeg_compression_distortion_pkd3_pln3_hip_tensor( T *srcPtr,
                                                                  uint2 srcStridesNH,
                                                                  T *dstPtr,
                                                                  uint3 dstStridesNCH,
                                                                  RpptROIPtr roiTensorPtrSrc,
                                                                  int *tableY,
                                                                  int *tableCbCr,
                                                                  float qScale)
{
    int id_x = (hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x) * 8;
    int id_y = hipBlockIdx_y * hipBlockDim_y + hipThreadIdx_y;
    int id_z = hipBlockIdx_z * hipBlockDim_z + hipThreadIdx_z;

    // Column offsets of this thread inside the shared tile: 8 floats per thread for
    // full-resolution data, 4 floats per thread for horizontally downsampled chroma.
    int hipThreadIdx_x8 = hipThreadIdx_x * 8;
    int hipThreadIdx_x4 = hipThreadIdx_x * 4;

    // ROI extent rounded up to the next multiple of 16
    int alignedWidth = (roiTensorPtrSrc[id_z].xywhROI.roiWidth + 15) & ~15;
    int alignedHeight = (roiTensorPtrSrc[id_z].xywhROI.roiHeight + 15) & ~15;

    // Boundary checks
    if((id_y >= alignedHeight) || (id_x >= alignedWidth))
        return;

    // ROI parameters
    int roiX = roiTensorPtrSrc[id_z].xywhROI.xy.x;
    int roiY = roiTensorPtrSrc[id_z].xywhROI.xy.y;
    int roiWidth = roiTensorPtrSrc[id_z].xywhROI.roiWidth;
    int roiHeight = roiTensorPtrSrc[id_z].xywhROI.roiHeight;

    // Shared memory declaration: rows 0..15, 16..31 and 32..47 are the three channel planes
    __shared__ float src_smem[48][128];
    int3 hipThreadIdx_y_channel;
    hipThreadIdx_y_channel.x = hipThreadIdx_y;
    hipThreadIdx_y_channel.y = hipThreadIdx_y + 16;
    hipThreadIdx_y_channel.z = hipThreadIdx_y + 32;

    float *src_smem_channel[3];
    src_smem_channel[0] = &src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8];
    src_smem_channel[1] = &src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8];
    src_smem_channel[2] = &src_smem[hipThreadIdx_y_channel.z][hipThreadIdx_x8];

    // ----------- Step 1: Load from Global Memory to Shared Memory -----------
    int srcIdx;
    // One destination offset per channel plane
    uint3 dstIdx;
    dstIdx.x = (id_z * dstStridesNCH.x) + (id_y * dstStridesNCH.z) + id_x;
    dstIdx.y = dstIdx.x + dstStridesNCH.y;
    dstIdx.z = dstIdx.y + dstStridesNCH.y;

    // Check if we need special handling for image edges
    if(id_y < roiHeight)
        srcIdx = (id_z * srcStridesNH.x) + ((id_y + roiY) * srcStridesNH.y) + ((id_x + roiX) * 3);
    else // All out-of-bounds threads use the last valid row
        srcIdx = (id_z * srcStridesNH.x) + ((roiHeight - 1 + roiY) * srcStridesNH.y) + ((id_x + roiX) * 3);

    // True when this thread's 8-pixel span crosses the right edge of the ROI
    bool isEdge = ((id_x + 8) > roiWidth) && (id_x < alignedWidth);
    if(!isEdge)
        rpp_hip_load24_pkd3_to_float24_pln3(srcPtr + srcIdx, src_smem_channel);
    else
    {
        // Partial block load with edge pixel replication
        int validPixels = roiWidth - id_x;
        // Load valid pixels (only if id_x is within valid range)
        if(validPixels > 0)
        {
            for(int i = 0, idx = srcIdx; i < validPixels; i++, idx += 3)
            {
                src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8 + i] = srcPtr[idx];
                src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8 + i] = srcPtr[idx + 1];
                src_smem[hipThreadIdx_y_channel.z][hipThreadIdx_x8 + i] = srcPtr[idx + 2];
            }
        }

        // Pad by duplicating the last valid pixel of the row. lastValidIdx resolves to
        // column (roiX + roiWidth - 1) of the row regardless of the sign of validPixels;
        // when validPixels is negative the loop starts at a negative i and rewrites
        // columns belonging to the thread on the left with the same replicated pixel.
        int lastValidIdx = srcIdx + ((validPixels - 1) * 3);
        for(int i = validPixels; i < min(validPixels + 16, 8); i++)
        {
            src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8 + i] = srcPtr[lastValidIdx];
            src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8 + i] = srcPtr[lastValidIdx + 1];
            src_smem[hipThreadIdx_y_channel.z][hipThreadIdx_x8 + i] = srcPtr[lastValidIdx + 2];
        }
    }
    __syncthreads();

    // ----------- Step 2: RGB to YCbCr Conversion -----------
    d_float8 y_f8;
    d_float24 rgb_f24;
    rgb_f24.f8[0] = *((d_float8*)&src_smem[hipThreadIdx_y][hipThreadIdx_x8]);
    rgb_f24.f8[1] = *((d_float8*)&src_smem[hipThreadIdx_y + 16][hipThreadIdx_x8]);
    rgb_f24.f8[2] = *((d_float8*)&src_smem[hipThreadIdx_y + 32][hipThreadIdx_x8]);

    int cbcrY = hipThreadIdx_y * 2;
    y_hip_compute(srcPtr, rgb_f24, &y_f8);
    __syncthreads();

    // ----------- Step 3: Downsample CbCr -----------
    // The first 8 thread rows each consume two tile rows per channel. The results are
    // kept in registers until every thread has passed the barrier below: the Cb/Cr and
    // Y stores that follow overwrite tile rows that other threads are still reading as
    // RGB input. (Assumes downsample_cbcr_hip_compute only reads through its six row
    // pointers - confirm against its definition.)
    float4 cbDown_f4, crDown_f4;
    if(hipThreadIdx_y < 8)
    {
        // Downsample RGB and convert to CbCr
        downsample_cbcr_hip_compute((d_float8*)&src_smem[cbcrY][hipThreadIdx_x8], (d_float8*)&src_smem[cbcrY + 1][hipThreadIdx_x8], (d_float8*)&src_smem[cbcrY + 16][hipThreadIdx_x8], (d_float8*)&src_smem[cbcrY + 17][hipThreadIdx_x8], (d_float8*)&src_smem[cbcrY + 32][hipThreadIdx_x8], (d_float8*)&src_smem[cbcrY + 33][hipThreadIdx_x8], &cbDown_f4, &crDown_f4);
    }
    __syncthreads();

    if(hipThreadIdx_y < 8)
    {
        // Store downsampled Cb in rows 16..23
        *(float4*)&src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x4] = cbDown_f4;
        // Storing Cr below Cb (8 x 64)
        *(float4*)&src_smem[8 + hipThreadIdx_y_channel.y][hipThreadIdx_x4] = crDown_f4;
    }

    // Y replaces the first plane (rows 0..15)
    *(d_float8*)&src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8] = y_f8;
    __syncthreads();

    // ----------- Step 4: Clamp + Forward DCT -----------
    clamp_range((float*)&src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8]);
    clamp_range((float*)&src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8]);
    __syncthreads();

    // Doing -128 as part of DCT,
    // 1D row wise FWD DCT for Y Cb and Cr channels
    dct_fwd_8x8_1d(&src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8], true);
    dct_fwd_8x8_1d(&src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8], true);
    __syncthreads();

    // ----------- Step5 Column-wise DCT -----------
    // Threads are remapped so that each one owns a whole tile column (rows 0..31)
    int col = (hipThreadIdx_x * 16) + hipThreadIdx_y;
    // Process all 128 columns
    if((col < 128) && (col < alignedWidth))
    {
        // Load column into temporary array
        float colVec[32];

        for(int i = 0; i < 32; i++)
            colVec[i] = src_smem[i][col];

        dct_fwd_8x8_1d(&colVec[0], false);
        dct_fwd_8x8_1d(&colVec[8], false);
        dct_fwd_8x8_1d(&colVec[16], false);
        dct_fwd_8x8_1d(&colVec[24], false);

        for(int i = 0; i < 32; i++)
            src_smem[i][col] = colVec[i];
    }
    __syncthreads();

    // ----------- Step 6: Quantization -----------
    quantize(&src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8], &tableY[(hipThreadIdx_y % 8) * 8]);
    quantize(&src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8], &tableCbCr[(hipThreadIdx_y % 8) * 8]);
    __syncthreads();

    // ----------- Step 7: Inverse DCT -----------
    // 1D column wise IDCT for Y Cb and Cr channels
    if((col < 128) && (col < alignedWidth))
    {
        // Load column into temporary array
        float colVec[32];

        for(int i = 0; i < 32; i++)
            colVec[i] = src_smem[i][col];

        dct_inv_8x8_1d(&colVec[0], false);
        dct_inv_8x8_1d(&colVec[8], false);
        dct_inv_8x8_1d(&colVec[16], false);
        dct_inv_8x8_1d(&colVec[24], false);

        for(int i = 0; i < 32; i++)
            src_smem[i][col] = colVec[i];
    }
    __syncthreads();

    // 1D row wise IDCT for Y Cb and Cr channels
    // Adding back 128 as part of INV DCT
    dct_inv_8x8_1d(&src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8], true);
    dct_inv_8x8_1d(&src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8], true);
    __syncthreads();

    // ----------- Step 8: Clamp & Upsample -----------
    clamp_range((float*)&src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8]);
    clamp_range((float*)&src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8]);
    __syncthreads();

    // Vertical Upsampling: two consecutive thread rows share one chroma row
    float4 cb_f4, cr_f4;
    cbcrY = hipThreadIdx_y / 2;
    cb_f4 = *(float4*)&src_smem[cbcrY + 16][hipThreadIdx_x4];
    cr_f4 = *(float4*)&src_smem[cbcrY + 24][hipThreadIdx_x4];
    __syncthreads();

    // YCbCr to RGB
    upsample_and_RGB_hip_compute(cb_f4, cr_f4, (d_float8*)&src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8], (d_float8*)&src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8], (d_float8*)&src_smem[hipThreadIdx_y_channel.z][hipThreadIdx_x8]);
    __syncthreads();

    // Clamp values and store results
    rpp_hip_adjust_range(dstPtr, (d_float8*)&src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8]);
    rpp_hip_adjust_range(dstPtr, (d_float8*)&src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8]);
    rpp_hip_adjust_range(dstPtr, (d_float8*)&src_smem[hipThreadIdx_y_channel.z][hipThreadIdx_x8]);

    // ----------- Step 9: Final Clamp & Store -----------
    clamp_range(srcPtr, (float*)&src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8]);
    clamp_range(srcPtr, (float*)&src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8]);
    clamp_range(srcPtr, (float*)&src_smem[hipThreadIdx_y_channel.z][hipThreadIdx_x8]);
    __syncthreads();

    // Only threads whose first pixel lies inside the ROI write to the destination,
    // one 8-pixel store per channel plane
    if((id_x < roiWidth) && (id_y < roiHeight))
    {
        rpp_hip_pack_float8_and_store8(dstPtr + dstIdx.x, (d_float8 *)&src_smem[hipThreadIdx_y_channel.x][hipThreadIdx_x8]);
        rpp_hip_pack_float8_and_store8(dstPtr + dstIdx.y, (d_float8 *)&src_smem[hipThreadIdx_y_channel.y][hipThreadIdx_x8]);
        rpp_hip_pack_float8_and_store8(dstPtr + dstIdx.z, (d_float8 *)&src_smem[hipThreadIdx_y_channel.z][hipThreadIdx_x8]);
    }
}