2525#endif
2626
2727#include < c10/cuda/CUDAMathCompat.h>
28+ #include < c10/util/env.h>
2829
2930namespace {
3031// This is the un-specialized struct. Note that we prevent instantiation of this
@@ -796,7 +797,7 @@ void cuLoadWriteStridedInputs(
796797 const int i1_end,
797798 const int64_t n2,
798799 const T_ACC* __restrict__ mean,
799- const T_ACC* __restrict__ invvar ,
800+ const T_ACC* __restrict__ rstd ,
800801 bool rms_only
801802 )
802803{
@@ -806,7 +807,7 @@ void cuLoadWriteStridedInputs(
806807 if (!rms_only) {
807808 curr_mean = mean[i1];
808809 }
809- T curr_invvar = invvar [i1];
810+ T curr_rstd = rstd [i1];
810811 for (int k = 0 ; k < blockDim .y ; ++k) {
811812 int i2 = i2_off + k;
812813 int load_idx = i1*n2+i2;
@@ -816,9 +817,9 @@ void cuLoadWriteStridedInputs(
816817 T curr_dout = static_cast <T>(dout[load_idx]);
817818 if (!rms_only) {
818819 warp_buf1[write_idx] = curr_dout;
819- warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar ;
820+ warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_rstd ;
820821 } else {
821- warp_buf2[write_idx] = curr_dout * (curr_input) * curr_invvar ;
822+ warp_buf2[write_idx] = curr_dout * (curr_input) * curr_rstd ;
822823 }
823824 } else {
824825 if (!rms_only) {
@@ -852,7 +853,7 @@ void cuLoadAddStridedInputs(
852853 const int i1_end,
853854 const int64_t n2,
854855 const T_ACC* __restrict__ mean,
855- const T_ACC* __restrict__ invvar ,
856+ const T_ACC* __restrict__ rstd ,
856857 bool rms_only
857858 )
858859{
@@ -862,7 +863,7 @@ void cuLoadAddStridedInputs(
862863 if (!rms_only) {
863864 curr_mean = mean[i1];
864865 }
865- T_ACC curr_invvar = invvar [i1];
866+ T_ACC curr_rstd = rstd [i1];
866867 for (int k = 0 ; k < blockDim .y ; ++k) {
867868 int i2 = i2_off + k;
868869 int load_idx = i1*n2+i2;
@@ -872,9 +873,9 @@ void cuLoadAddStridedInputs(
872873 T_ACC curr_dout = static_cast <T_ACC>(dout[load_idx]);
873874 if (!rms_only) {
874875 warp_buf1[write_idx] += curr_dout;
875- warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar ;
876+ warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_rstd ;
876877 } else {
877- warp_buf2[write_idx] += curr_dout * (curr_input) * curr_invvar ;
878+ warp_buf2[write_idx] += curr_dout * (curr_input) * curr_rstd ;
878879 }
879880 }
880881 }
@@ -888,7 +889,7 @@ void cuComputePartGradGammaBeta(
888889 const int64_t n1,
889890 const int64_t n2,
890891 const T_ACC* __restrict__ mean,
891- const T_ACC* __restrict__ invvar ,
892+ const T_ACC* __restrict__ rstd ,
892893 T_ACC* part_grad_gamma,
893894 T_ACC* part_grad_beta,
894895 bool rms_only)
@@ -908,9 +909,9 @@ void cuComputePartGradGammaBeta(
908909 T_ACC* warp_buf2 = warp_buf1 + blockDim .y * blockDim .y * row_stride;
909910 // compute partial sums from strided inputs
910911 // do this to increase number of loads in flight
911- cuLoadWriteStridedInputs (i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar , rms_only);
912+ cuLoadWriteStridedInputs (i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,rstd , rms_only);
912913 for (int i1_block = i1_beg+blockDim .y *blockDim .y ; i1_block < i1_end; i1_block+=blockDim .y *blockDim .y ) {
913- cuLoadAddStridedInputs (i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar , rms_only);
914+ cuLoadAddStridedInputs (i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,rstd , rms_only);
914915 }
915916 __syncthreads ();
916917 // inter-warp reductions
@@ -1024,7 +1025,7 @@ void cuComputeGradInput(
10241025 const int64_t n1,
10251026 const int64_t n2,
10261027 const T_ACC* __restrict__ mean,
1027- const T_ACC* __restrict__ invvar ,
1028+ const T_ACC* __restrict__ rstd ,
10281029 const T* gamma,
10291030 T* grad_input,
10301031 bool rms_only)
@@ -1036,7 +1037,7 @@ void cuComputeGradInput(
10361037 if (!rms_only) {
10371038 c_mean = mean[i1];
10381039 }
1039- const T_ACC c_invvar = invvar [i1];
1040+ const T_ACC c_rstd = rstd [i1];
10401041 const T* k_input = input + i1*n2;
10411042 const T* k_dout = dout + i1*n2;
10421043 const int numx = blockDim .x * blockDim .y ;
@@ -1050,9 +1051,9 @@ void cuComputeGradInput(
10501051 const T_ACC c_loss = static_cast <T_ACC>((idx<n2) ? k_dout[idx] : T (0 ));
10511052 if (!rms_only) {
10521053 sum_loss1 += c_loss * gamma_idx;
1053- sum_loss2 += c_loss * gamma_idx * (c_h - c_mean) * c_invvar ;
1054+ sum_loss2 += c_loss * gamma_idx * (c_h - c_mean) * c_rstd ;
10541055 } else {
1055- sum_loss2 += c_loss * gamma_idx * (c_h) * c_invvar ;
1056+ sum_loss2 += c_loss * gamma_idx * (c_h) * c_rstd ;
10561057 }
10571058 }
10581059 } else {
@@ -1062,9 +1063,9 @@ void cuComputeGradInput(
10621063 const T_ACC c_loss = static_cast <T_ACC>((idx<n2) ? k_dout[idx] : T (0 ));
10631064 if (!rms_only) {
10641065 sum_loss1 += c_loss;
1065- sum_loss2 += c_loss * (c_h - c_mean) * c_invvar ;
1066+ sum_loss2 += c_loss * (c_h - c_mean) * c_rstd ;
10661067 } else {
1067- sum_loss2 += c_loss * (c_h) * c_invvar ;
1068+ sum_loss2 += c_loss * (c_h) * c_rstd ;
10681069 }
10691070 }
10701071 }
@@ -1115,7 +1116,7 @@ void cuComputeGradInput(
11151116 }
11161117 // all threads now have the two sums over l
11171118 T_ACC fH = (T_ACC)n2;
1118- T_ACC term1 = (T_ACC (1 ) / fH ) * c_invvar ;
1119+ T_ACC term1 = (T_ACC (1 ) / fH ) * c_rstd ;
11191120 T* k_grad_input = grad_input + i1*n2;
11201121 if (gamma != NULL ) {
11211122 for (int l = thrx; l < n2; l+=numx) {
@@ -1124,9 +1125,9 @@ void cuComputeGradInput(
11241125 T_ACC f_grad_input = fH * c_loss * gamma[l];
11251126 if (!rms_only) {
11261127 f_grad_input -= sum_loss1;
1127- f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
1128+ f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2;
11281129 } else {
1129- f_grad_input -= (c_h) * c_invvar * sum_loss2;
1130+ f_grad_input -= (c_h) * c_rstd * sum_loss2;
11301131 }
11311132 f_grad_input *= term1;
11321133 k_grad_input[l] = static_cast <T>(f_grad_input);
@@ -1138,9 +1139,9 @@ void cuComputeGradInput(
11381139 T_ACC f_grad_input = fH * c_loss;
11391140 if (!rms_only) {
11401141 f_grad_input -= sum_loss1;
1141- f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
1142+ f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2;
11421143 } else {
1143- f_grad_input -= (c_h) * c_invvar * sum_loss2;
1144+ f_grad_input -= (c_h) * c_rstd * sum_loss2;
11441145 }
11451146 f_grad_input *= term1;
11461147 k_grad_input[l] = static_cast <T>(f_grad_input);
0 commit comments