Skip to content

Commit 730cef1

Browse files
pnunna93 and hubertlu-tw
authored and committed
Add header file for c10::utils::check_env and replace invvar with rstd
1 parent a329b96 commit 730cef1

1 file changed

Lines changed: 23 additions & 22 deletions

File tree

aten/src/ATen/native/cuda/layer_norm_kernel.cu

Lines changed: 23 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
#endif
2626

2727
#include <c10/cuda/CUDAMathCompat.h>
28+
#include <c10/util/env.h>
2829

2930
namespace {
3031
// This is the un-specialized struct. Note that we prevent instantiation of this
@@ -796,7 +797,7 @@ void cuLoadWriteStridedInputs(
796797
const int i1_end,
797798
const int64_t n2,
798799
const T_ACC* __restrict__ mean,
799-
const T_ACC* __restrict__ invvar,
800+
const T_ACC* __restrict__ rstd,
800801
bool rms_only
801802
)
802803
{
@@ -806,7 +807,7 @@ void cuLoadWriteStridedInputs(
806807
if (!rms_only) {
807808
curr_mean = mean[i1];
808809
}
809-
T curr_invvar = invvar[i1];
810+
T curr_rstd = rstd[i1];
810811
for (int k = 0; k < blockDim.y; ++k) {
811812
int i2 = i2_off + k;
812813
int load_idx = i1*n2+i2;
@@ -816,9 +817,9 @@ void cuLoadWriteStridedInputs(
816817
T curr_dout = static_cast<T>(dout[load_idx]);
817818
if (!rms_only) {
818819
warp_buf1[write_idx] = curr_dout;
819-
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
820+
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_rstd;
820821
} else {
821-
warp_buf2[write_idx] = curr_dout * (curr_input) * curr_invvar;
822+
warp_buf2[write_idx] = curr_dout * (curr_input) * curr_rstd;
822823
}
823824
} else {
824825
if (!rms_only) {
@@ -852,7 +853,7 @@ void cuLoadAddStridedInputs(
852853
const int i1_end,
853854
const int64_t n2,
854855
const T_ACC* __restrict__ mean,
855-
const T_ACC* __restrict__ invvar,
856+
const T_ACC* __restrict__ rstd,
856857
bool rms_only
857858
)
858859
{
@@ -862,7 +863,7 @@ void cuLoadAddStridedInputs(
862863
if (!rms_only) {
863864
curr_mean = mean[i1];
864865
}
865-
T_ACC curr_invvar = invvar[i1];
866+
T_ACC curr_rstd = rstd[i1];
866867
for (int k = 0; k < blockDim.y; ++k) {
867868
int i2 = i2_off + k;
868869
int load_idx = i1*n2+i2;
@@ -872,9 +873,9 @@ void cuLoadAddStridedInputs(
872873
T_ACC curr_dout = static_cast<T_ACC>(dout[load_idx]);
873874
if (!rms_only) {
874875
warp_buf1[write_idx] += curr_dout;
875-
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
876+
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_rstd;
876877
} else {
877-
warp_buf2[write_idx] += curr_dout * (curr_input) * curr_invvar;
878+
warp_buf2[write_idx] += curr_dout * (curr_input) * curr_rstd;
878879
}
879880
}
880881
}
@@ -888,7 +889,7 @@ void cuComputePartGradGammaBeta(
888889
const int64_t n1,
889890
const int64_t n2,
890891
const T_ACC* __restrict__ mean,
891-
const T_ACC* __restrict__ invvar,
892+
const T_ACC* __restrict__ rstd,
892893
T_ACC* part_grad_gamma,
893894
T_ACC* part_grad_beta,
894895
bool rms_only)
@@ -908,9 +909,9 @@ void cuComputePartGradGammaBeta(
908909
T_ACC* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
909910
// compute partial sums from strided inputs
910911
// do this to increase number of loads in flight
911-
cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar, rms_only);
912+
cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,rstd, rms_only);
912913
for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) {
913-
cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar, rms_only);
914+
cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,rstd, rms_only);
914915
}
915916
__syncthreads();
916917
// inter-warp reductions
@@ -1024,7 +1025,7 @@ void cuComputeGradInput(
10241025
const int64_t n1,
10251026
const int64_t n2,
10261027
const T_ACC* __restrict__ mean,
1027-
const T_ACC* __restrict__ invvar,
1028+
const T_ACC* __restrict__ rstd,
10281029
const T* gamma,
10291030
T* grad_input,
10301031
bool rms_only)
@@ -1036,7 +1037,7 @@ void cuComputeGradInput(
10361037
if (!rms_only) {
10371038
c_mean = mean[i1];
10381039
}
1039-
const T_ACC c_invvar = invvar[i1];
1040+
const T_ACC c_rstd = rstd[i1];
10401041
const T* k_input = input + i1*n2;
10411042
const T* k_dout = dout + i1*n2;
10421043
const int numx = blockDim.x * blockDim.y;
@@ -1050,9 +1051,9 @@ void cuComputeGradInput(
10501051
const T_ACC c_loss = static_cast<T_ACC>((idx<n2) ? k_dout[idx] : T(0));
10511052
if (!rms_only) {
10521053
sum_loss1 += c_loss * gamma_idx;
1053-
sum_loss2 += c_loss * gamma_idx * (c_h - c_mean) * c_invvar;
1054+
sum_loss2 += c_loss * gamma_idx * (c_h - c_mean) * c_rstd;
10541055
} else {
1055-
sum_loss2 += c_loss * gamma_idx * (c_h) * c_invvar;
1056+
sum_loss2 += c_loss * gamma_idx * (c_h) * c_rstd;
10561057
}
10571058
}
10581059
} else {
@@ -1062,9 +1063,9 @@ void cuComputeGradInput(
10621063
const T_ACC c_loss = static_cast<T_ACC>((idx<n2) ? k_dout[idx] : T(0));
10631064
if (!rms_only) {
10641065
sum_loss1 += c_loss;
1065-
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
1066+
sum_loss2 += c_loss * (c_h - c_mean) * c_rstd;
10661067
} else {
1067-
sum_loss2 += c_loss * (c_h) * c_invvar;
1068+
sum_loss2 += c_loss * (c_h) * c_rstd;
10681069
}
10691070
}
10701071
}
@@ -1115,7 +1116,7 @@ void cuComputeGradInput(
11151116
}
11161117
// all threads now have the two sums over l
11171118
T_ACC fH = (T_ACC)n2;
1118-
T_ACC term1 = (T_ACC(1) / fH) * c_invvar;
1119+
T_ACC term1 = (T_ACC(1) / fH) * c_rstd;
11191120
T* k_grad_input = grad_input + i1*n2;
11201121
if (gamma != NULL) {
11211122
for (int l = thrx; l < n2; l+=numx) {
@@ -1124,9 +1125,9 @@ void cuComputeGradInput(
11241125
T_ACC f_grad_input = fH * c_loss * gamma[l];
11251126
if (!rms_only) {
11261127
f_grad_input -= sum_loss1;
1127-
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
1128+
f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2;
11281129
} else {
1129-
f_grad_input -= (c_h) * c_invvar * sum_loss2;
1130+
f_grad_input -= (c_h) * c_rstd * sum_loss2;
11301131
}
11311132
f_grad_input *= term1;
11321133
k_grad_input[l] = static_cast<T>(f_grad_input);
@@ -1138,9 +1139,9 @@ void cuComputeGradInput(
11381139
T_ACC f_grad_input = fH * c_loss;
11391140
if (!rms_only) {
11401141
f_grad_input -= sum_loss1;
1141-
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
1142+
f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2;
11421143
} else {
1143-
f_grad_input -= (c_h) * c_invvar * sum_loss2;
1144+
f_grad_input -= (c_h) * c_rstd * sum_loss2;
11441145
}
11451146
f_grad_input *= term1;
11461147
k_grad_input[l] = static_cast<T>(f_grad_input);

0 commit comments

Comments (0)