Skip to content

Commit 3fdd173

Browse files
pnunna93 authored and hubertlu-tw committed
replace n1 & n2 with M & N respectively in apex kernels
1 parent 7fa3203 commit 3fdd173

1 file changed

Lines changed: 39 additions & 39 deletions

File tree

aten/src/ATen/native/cuda/layer_norm_kernel.cu

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,7 @@ void cuLoadWriteStridedInputs(
795795
const T* input,
796796
const T* dout,
797797
const int i1_end,
798-
const int64_t n2,
798+
const int64_t N,
799799
const T_ACC* __restrict__ mean,
800800
const T_ACC* __restrict__ rstd)
801801
{
@@ -805,9 +805,9 @@ void cuLoadWriteStridedInputs(
805805
T curr_rstd = rstd[i1];
806806
for (int k = 0; k < blockDim.y; ++k) {
807807
int i2 = i2_off + k;
808-
int load_idx = i1*n2+i2;
808+
int load_idx = i1*N+i2;
809809
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
810-
if (i2<n2) {
810+
if (i2<N) {
811811
T curr_input = static_cast<T>(input[load_idx]);
812812
T curr_dout = static_cast<T>(dout[load_idx]);
813813
warp_buf1[write_idx] = curr_dout;
@@ -838,7 +838,7 @@ void cuLoadAddStridedInputs(
838838
const T* input,
839839
const T* dout,
840840
const int i1_end,
841-
const int64_t n2,
841+
const int64_t N,
842842
const T_ACC* __restrict__ mean,
843843
const T_ACC* __restrict__ rstd)
844844
{
@@ -848,9 +848,9 @@ void cuLoadAddStridedInputs(
848848
T_ACC curr_rstd = rstd[i1];
849849
for (int k = 0; k < blockDim.y; ++k) {
850850
int i2 = i2_off + k;
851-
int load_idx = i1*n2+i2;
851+
int load_idx = i1*N+i2;
852852
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
853-
if (i2<n2) {
853+
if (i2<N) {
854854
T_ACC curr_input = static_cast<T_ACC>(input[load_idx]);
855855
T_ACC curr_dout = static_cast<T_ACC>(dout[load_idx]);
856856
warp_buf1[write_idx] += curr_dout;
@@ -864,18 +864,18 @@ template<typename T, typename T_ACC> __global__
864864
void cuComputePartGradGammaBeta(
865865
const T* __restrict__ dout,
866866
const T* __restrict__ input,
867-
const int64_t n1,
868-
const int64_t n2,
867+
const int64_t M,
868+
const int64_t N,
869869
const T_ACC* __restrict__ mean,
870870
const T_ACC* __restrict__ rstd,
871871
T_ACC* part_grad_gamma,
872872
T_ACC* part_grad_beta)
873873
{
874-
const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
875-
const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
874+
const int numsegs_M = (M+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
875+
const int segs_per_block = (numsegs_M + gridDim.y - 1) / gridDim.y;
876876
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
877877
const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
878-
const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
878+
const int i1_end = i1_beg_plus_one < M ? i1_beg_plus_one : M;
879879
const int row_stride = blockDim.x+1;
880880
const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
881881
const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
@@ -886,9 +886,9 @@ void cuComputePartGradGammaBeta(
886886
T_ACC* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
887887
// compute partial sums from strided inputs
888888
// do this to increase number of loads in flight
889-
cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,rstd);
889+
cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd);
890890
for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) {
891-
cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,rstd);
891+
cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd);
892892
}
893893
__syncthreads();
894894
// inter-warp reductions
@@ -917,13 +917,13 @@ void cuComputePartGradGammaBeta(
917917
__syncthreads();
918918
}
919919
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
920-
if (threadIdx.y == 0 && i2 < n2) {
920+
if (threadIdx.y == 0 && i2 < N) {
921921
int row1 = threadIdx.y;
922922
int row2 = threadIdx.y + 1;
923923
int idx1 = row1*row_stride + threadIdx.x;
924924
int idx2 = row2*row_stride + threadIdx.x;
925-
part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
926-
part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];
925+
part_grad_beta[blockIdx.y*N+i2] = warp_buf1[idx1] + warp_buf1[idx2];
926+
part_grad_gamma[blockIdx.y*N+i2] = warp_buf2[idx1] + warp_buf2[idx2];
927927
}
928928
}
929929

@@ -932,25 +932,25 @@ void cuComputeGradGammaBeta(
932932
const T_ACC* part_grad_gamma,
933933
const T_ACC* part_grad_beta,
934934
const int part_size,
935-
const int64_t n1,
936-
const int64_t n2,
935+
const int64_t M,
936+
const int64_t N,
937937
T* grad_gamma,
938938
T* grad_beta)
939939
{
940940
// sum partial gradients for gamma and beta
941941
SharedMemory<T_ACC> shared;
942942
T_ACC* buf = shared.getPointer();
943943
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
944-
if (i2 < n2) {
944+
if (i2 < N) {
945945
// each warp does sequential reductions until reduced part_size is num_warps
946946
int num_warp_reductions = part_size / blockDim.y;
947947
T_ACC sum_gamma = T_ACC(0);
948948
T_ACC sum_beta = T_ACC(0);
949-
const T_ACC* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
950-
const T_ACC* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
949+
const T_ACC* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * N + i2;
950+
const T_ACC* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * N + i2;
951951
for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
952-
sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
953-
sum_beta += part_grad_beta_ptr[warp_offset*n2];
952+
sum_gamma += part_grad_gamma_ptr[warp_offset*N];
953+
sum_beta += part_grad_beta_ptr[warp_offset*N];
954954
}
955955
// inter-warp reductions
956956
const int nbsize3 = blockDim.x * blockDim.y / 2;
@@ -982,37 +982,37 @@ template<typename T, typename T_ACC> __global__
982982
void cuComputeGradInput(
983983
const T* __restrict__ dout,
984984
const T* __restrict__ input,
985-
const int64_t n1,
986-
const int64_t n2,
985+
const int64_t M,
986+
const int64_t N,
987987
const T_ACC* __restrict__ mean,
988988
const T_ACC* __restrict__ rstd,
989989
const T* gamma,
990990
T* grad_input)
991991
{
992-
for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
992+
for (int i1=blockIdx.y; i1 < M; i1 += gridDim.y) {
993993
T_ACC sum_loss1 = T_ACC(0);
994994
T_ACC sum_loss2 = T_ACC(0);
995995
T_ACC c_mean = mean[i1];
996996
const T_ACC c_rstd = rstd[i1];
997-
const T* k_input = input + i1*n2;
998-
const T* k_dout = dout + i1*n2;
997+
const T* k_input = input + i1*N;
998+
const T* k_dout = dout + i1*N;
999999
const int numx = blockDim.x * blockDim.y;
10001000
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
10011001
if (gamma != NULL) {
10021002
// Optimization for ROCm MI100
1003-
for( int l = 0; l < n2 ; l += numx) {
1003+
for( int l = 0; l < N ; l += numx) {
10041004
int idx = l + thrx;
1005-
const T_ACC gamma_idx = static_cast<T_ACC>((idx<n2) ? gamma[idx] : T(0));
1006-
const T_ACC c_h = static_cast<T_ACC>((idx<n2) ? k_input[idx] : T(0));
1007-
const T_ACC c_loss = static_cast<T_ACC>((idx<n2) ? k_dout[idx] : T(0));
1005+
const T_ACC gamma_idx = static_cast<T_ACC>((idx<N) ? gamma[idx] : T(0));
1006+
const T_ACC c_h = static_cast<T_ACC>((idx<N) ? k_input[idx] : T(0));
1007+
const T_ACC c_loss = static_cast<T_ACC>((idx<N) ? k_dout[idx] : T(0));
10081008
sum_loss1 += c_loss * gamma_idx;
10091009
sum_loss2 += c_loss * gamma_idx * (c_h - c_mean) * c_rstd;
10101010
}
10111011
} else {
1012-
for( int l = 0; l < n2 ; l += numx) {
1012+
for( int l = 0; l < N ; l += numx) {
10131013
int idx = l + thrx;
1014-
const T_ACC c_h = static_cast<T_ACC>((idx<n2) ? k_input[idx] : T(0));
1015-
const T_ACC c_loss = static_cast<T_ACC>((idx<n2) ? k_dout[idx] : T(0));
1014+
const T_ACC c_h = static_cast<T_ACC>((idx<N) ? k_input[idx] : T(0));
1015+
const T_ACC c_loss = static_cast<T_ACC>((idx<N) ? k_dout[idx] : T(0));
10161016
sum_loss1 += c_loss;
10171017
sum_loss2 += c_loss * (c_h - c_mean) * c_rstd;
10181018
}
@@ -1053,11 +1053,11 @@ void cuComputeGradInput(
10531053
}
10541054
}
10551055
// all threads now have the two sums over l
1056-
T_ACC fH = (T_ACC)n2;
1056+
T_ACC fH = (T_ACC)N;
10571057
T_ACC term1 = (T_ACC(1) / fH) * c_rstd;
1058-
T* k_grad_input = grad_input + i1*n2;
1058+
T* k_grad_input = grad_input + i1*N;
10591059
if (gamma != NULL) {
1060-
for (int l = thrx; l < n2; l+=numx) {
1060+
for (int l = thrx; l < N; l+=numx) {
10611061
const T_ACC c_h = static_cast<T_ACC>(k_input[l]);
10621062
const T_ACC c_loss = static_cast<T_ACC>(k_dout[l]);
10631063
T_ACC f_grad_input = fH * c_loss * gamma[l];
@@ -1067,7 +1067,7 @@ void cuComputeGradInput(
10671067
k_grad_input[l] = static_cast<T>(f_grad_input);
10681068
}
10691069
} else {
1070-
for (int l = thrx; l < n2; l+=numx) {
1070+
for (int l = thrx; l < N; l+=numx) {
10711071
const T_ACC c_h = static_cast<T_ACC>(k_input[l]);
10721072
const T_ACC c_loss = static_cast<T_ACC>(k_dout[l]);
10731073
T_ACC f_grad_input = fH * c_loss;

0 commit comments

Comments
 (0)