@@ -795,7 +795,7 @@ void cuLoadWriteStridedInputs(
795795 const T* input,
796796 const T* dout,
797797 const int i1_end,
798- const int64_t n2 ,
798+ const int64_t N ,
799799 const T_ACC* __restrict__ mean,
800800 const T_ACC* __restrict__ rstd)
801801{
@@ -805,9 +805,9 @@ void cuLoadWriteStridedInputs(
805805 T curr_rstd = rstd[i1];
806806 for (int k = 0 ; k < blockDim .y ; ++k) {
807807 int i2 = i2_off + k;
808- int load_idx = i1*n2 +i2;
808+ int load_idx = i1*N +i2;
809809 int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
810- if (i2<n2 ) {
810+ if (i2<N ) {
811811 T curr_input = static_cast <T>(input[load_idx]);
812812 T curr_dout = static_cast <T>(dout[load_idx]);
813813 warp_buf1[write_idx] = curr_dout;
@@ -838,7 +838,7 @@ void cuLoadAddStridedInputs(
838838 const T* input,
839839 const T* dout,
840840 const int i1_end,
841- const int64_t n2 ,
841+ const int64_t N ,
842842 const T_ACC* __restrict__ mean,
843843 const T_ACC* __restrict__ rstd)
844844{
@@ -848,9 +848,9 @@ void cuLoadAddStridedInputs(
848848 T_ACC curr_rstd = rstd[i1];
849849 for (int k = 0 ; k < blockDim .y ; ++k) {
850850 int i2 = i2_off + k;
851- int load_idx = i1*n2 +i2;
851+ int load_idx = i1*N +i2;
852852 int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
853- if (i2<n2 ) {
853+ if (i2<N ) {
854854 T_ACC curr_input = static_cast <T_ACC>(input[load_idx]);
855855 T_ACC curr_dout = static_cast <T_ACC>(dout[load_idx]);
856856 warp_buf1[write_idx] += curr_dout;
@@ -864,18 +864,18 @@ template<typename T, typename T_ACC> __global__
864864void cuComputePartGradGammaBeta (
865865 const T* __restrict__ dout,
866866 const T* __restrict__ input,
867- const int64_t n1 ,
868- const int64_t n2 ,
867+ const int64_t M ,
868+ const int64_t N ,
869869 const T_ACC* __restrict__ mean,
870870 const T_ACC* __restrict__ rstd,
871871 T_ACC* part_grad_gamma,
872872 T_ACC* part_grad_beta)
873873{
874- const int numsegs_n1 = (n1 +blockDim .y *blockDim .y -1 ) / (blockDim .y *blockDim .y );
875- const int segs_per_block = (numsegs_n1 + gridDim .y - 1 ) / gridDim .y ;
874+ const int numsegs_M = (M +blockDim .y *blockDim .y -1 ) / (blockDim .y *blockDim .y );
875+ const int segs_per_block = (numsegs_M + gridDim .y - 1 ) / gridDim .y ;
876876 const int i1_beg = blockIdx .y * segs_per_block * blockDim .y *blockDim .y ;
877877 const int i1_beg_plus_one = (blockIdx .y +1 ) * segs_per_block * blockDim .y *blockDim .y ;
878- const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1 ;
878+ const int i1_end = i1_beg_plus_one < M ? i1_beg_plus_one : M ;
879879 const int row_stride = blockDim .x +1 ;
880880 const int thr_load_col_off = (threadIdx .x *blockDim .y )&(blockDim .x -1 );
881881 const int thr_load_row_off = (threadIdx .x *blockDim .y )/blockDim .x + threadIdx .y *blockDim .y ;
@@ -886,9 +886,9 @@ void cuComputePartGradGammaBeta(
886886 T_ACC* warp_buf2 = warp_buf1 + blockDim .y * blockDim .y * row_stride;
887887 // compute partial sums from strided inputs
888888 // do this to increase number of loads in flight
889- cuLoadWriteStridedInputs (i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2 ,mean,rstd);
889+ cuLoadWriteStridedInputs (i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N ,mean,rstd);
890890 for (int i1_block = i1_beg+blockDim .y *blockDim .y ; i1_block < i1_end; i1_block+=blockDim .y *blockDim .y ) {
891- cuLoadAddStridedInputs (i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2 ,mean,rstd);
891+ cuLoadAddStridedInputs (i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N ,mean,rstd);
892892 }
893893 __syncthreads ();
894894 // inter-warp reductions
@@ -917,13 +917,13 @@ void cuComputePartGradGammaBeta(
917917 __syncthreads ();
918918 }
919919 int i2 = blockIdx .x * blockDim .x + threadIdx .x ;
920- if (threadIdx .y == 0 && i2 < n2 ) {
920+ if (threadIdx .y == 0 && i2 < N ) {
921921 int row1 = threadIdx .y ;
922922 int row2 = threadIdx .y + 1 ;
923923 int idx1 = row1*row_stride + threadIdx .x ;
924924 int idx2 = row2*row_stride + threadIdx .x ;
925- part_grad_beta[blockIdx .y *n2 +i2] = warp_buf1[idx1] + warp_buf1[idx2];
926- part_grad_gamma[blockIdx .y *n2 +i2] = warp_buf2[idx1] + warp_buf2[idx2];
925+ part_grad_beta[blockIdx .y *N +i2] = warp_buf1[idx1] + warp_buf1[idx2];
926+ part_grad_gamma[blockIdx .y *N +i2] = warp_buf2[idx1] + warp_buf2[idx2];
927927 }
928928}
929929
@@ -932,25 +932,25 @@ void cuComputeGradGammaBeta(
932932 const T_ACC* part_grad_gamma,
933933 const T_ACC* part_grad_beta,
934934 const int part_size,
935- const int64_t n1 ,
936- const int64_t n2 ,
935+ const int64_t M ,
936+ const int64_t N ,
937937 T* grad_gamma,
938938 T* grad_beta)
939939{
940940 // sum partial gradients for gamma and beta
941941 SharedMemory<T_ACC> shared;
942942 T_ACC* buf = shared.getPointer ();
943943 int i2 = blockIdx .x * blockDim .x + threadIdx .x ;
944- if (i2 < n2 ) {
944+ if (i2 < N ) {
945945 // each warp does sequential reductions until reduced part_size is num_warps
946946 int num_warp_reductions = part_size / blockDim .y ;
947947 T_ACC sum_gamma = T_ACC (0 );
948948 T_ACC sum_beta = T_ACC (0 );
949- const T_ACC* part_grad_gamma_ptr = part_grad_gamma + threadIdx .y * num_warp_reductions * n2 + i2;
950- const T_ACC* part_grad_beta_ptr = part_grad_beta + threadIdx .y * num_warp_reductions * n2 + i2;
949+ const T_ACC* part_grad_gamma_ptr = part_grad_gamma + threadIdx .y * num_warp_reductions * N + i2;
950+ const T_ACC* part_grad_beta_ptr = part_grad_beta + threadIdx .y * num_warp_reductions * N + i2;
951951 for (int warp_offset = 0 ; warp_offset < num_warp_reductions; ++warp_offset) {
952- sum_gamma += part_grad_gamma_ptr[warp_offset*n2 ];
953- sum_beta += part_grad_beta_ptr[warp_offset*n2 ];
952+ sum_gamma += part_grad_gamma_ptr[warp_offset*N ];
953+ sum_beta += part_grad_beta_ptr[warp_offset*N ];
954954 }
955955 // inter-warp reductions
956956 const int nbsize3 = blockDim .x * blockDim .y / 2 ;
@@ -982,37 +982,37 @@ template<typename T, typename T_ACC> __global__
982982void cuComputeGradInput (
983983 const T* __restrict__ dout,
984984 const T* __restrict__ input,
985- const int64_t n1 ,
986- const int64_t n2 ,
985+ const int64_t M ,
986+ const int64_t N ,
987987 const T_ACC* __restrict__ mean,
988988 const T_ACC* __restrict__ rstd,
989989 const T* gamma,
990990 T* grad_input)
991991{
992- for (int i1=blockIdx .y ; i1 < n1 ; i1 += gridDim .y ) {
992+ for (int i1=blockIdx .y ; i1 < M ; i1 += gridDim .y ) {
993993 T_ACC sum_loss1 = T_ACC (0 );
994994 T_ACC sum_loss2 = T_ACC (0 );
995995 T_ACC c_mean = mean[i1];
996996 const T_ACC c_rstd = rstd[i1];
997- const T* k_input = input + i1*n2 ;
998- const T* k_dout = dout + i1*n2 ;
997+ const T* k_input = input + i1*N ;
998+ const T* k_dout = dout + i1*N ;
999999 const int numx = blockDim .x * blockDim .y ;
10001000 const int thrx = threadIdx .x + threadIdx .y * blockDim .x ;
10011001 if (gamma != NULL ) {
10021002 // Optimization for ROCm MI100
1003- for ( int l = 0 ; l < n2 ; l += numx) {
1003+ for ( int l = 0 ; l < N ; l += numx) {
10041004 int idx = l + thrx;
1005- const T_ACC gamma_idx = static_cast <T_ACC>((idx<n2 ) ? gamma[idx] : T (0 ));
1006- const T_ACC c_h = static_cast <T_ACC>((idx<n2 ) ? k_input[idx] : T (0 ));
1007- const T_ACC c_loss = static_cast <T_ACC>((idx<n2 ) ? k_dout[idx] : T (0 ));
1005+ const T_ACC gamma_idx = static_cast <T_ACC>((idx<N ) ? gamma[idx] : T (0 ));
1006+ const T_ACC c_h = static_cast <T_ACC>((idx<N ) ? k_input[idx] : T (0 ));
1007+ const T_ACC c_loss = static_cast <T_ACC>((idx<N ) ? k_dout[idx] : T (0 ));
10081008 sum_loss1 += c_loss * gamma_idx;
10091009 sum_loss2 += c_loss * gamma_idx * (c_h - c_mean) * c_rstd;
10101010 }
10111011 } else {
1012- for ( int l = 0 ; l < n2 ; l += numx) {
1012+ for ( int l = 0 ; l < N ; l += numx) {
10131013 int idx = l + thrx;
1014- const T_ACC c_h = static_cast <T_ACC>((idx<n2 ) ? k_input[idx] : T (0 ));
1015- const T_ACC c_loss = static_cast <T_ACC>((idx<n2 ) ? k_dout[idx] : T (0 ));
1014+ const T_ACC c_h = static_cast <T_ACC>((idx<N ) ? k_input[idx] : T (0 ));
1015+ const T_ACC c_loss = static_cast <T_ACC>((idx<N ) ? k_dout[idx] : T (0 ));
10161016 sum_loss1 += c_loss;
10171017 sum_loss2 += c_loss * (c_h - c_mean) * c_rstd;
10181018 }
@@ -1053,11 +1053,11 @@ void cuComputeGradInput(
10531053 }
10541054 }
10551055 // all threads now have the two sums over l
1056- T_ACC fH = (T_ACC)n2 ;
1056+ T_ACC fH = (T_ACC)N ;
10571057 T_ACC term1 = (T_ACC (1 ) / fH ) * c_rstd;
1058- T* k_grad_input = grad_input + i1*n2 ;
1058+ T* k_grad_input = grad_input + i1*N ;
10591059 if (gamma != NULL ) {
1060- for (int l = thrx; l < n2 ; l+=numx) {
1060+ for (int l = thrx; l < N ; l+=numx) {
10611061 const T_ACC c_h = static_cast <T_ACC>(k_input[l]);
10621062 const T_ACC c_loss = static_cast <T_ACC>(k_dout[l]);
10631063 T_ACC f_grad_input = fH * c_loss * gamma[l];
@@ -1067,7 +1067,7 @@ void cuComputeGradInput(
10671067 k_grad_input[l] = static_cast <T>(f_grad_input);
10681068 }
10691069 } else {
1070- for (int l = thrx; l < n2 ; l+=numx) {
1070+ for (int l = thrx; l < N ; l+=numx) {
10711071 const T_ACC c_h = static_cast <T_ACC>(k_input[l]);
10721072 const T_ACC c_loss = static_cast <T_ACC>(k_dout[l]);
10731073 T_ACC f_grad_input = fH * c_loss;
0 commit comments