@@ -9,67 +9,67 @@ namespace cv {
 namespace opt_AVX2
 {
 #if CV_TRY_AVX2
12- void convBlock_AVX2 (int k, const float *a, const float *b,
13- float *c, int ldc, const float *bias,
14- float minval, float maxval, bool ifActiv)
12+ void convBlock_AVX2 (int np, const float * a, const float * b, float * c, int ldc, bool init_c)
1513{
16- #if FAST_CONV_MR == 4 && FAST_CONV_NR == 24
17- __m256 vminval = _mm256_set1_ps (minval), vmaxval = _mm256_set1_ps (maxval);
18- __m256 c0 = _mm256_set1_ps (bias[0 ]), c1 = c0, c2 = c0;
19- __m256 c3 = _mm256_set1_ps (bias[1 ]), c4 = c3, c5 = c3;
20- __m256 c6 = _mm256_set1_ps (bias[2 ]), c7 = c6, c8 = c6;
21- __m256 c9 = _mm256_set1_ps (bias[3 ]), c10 = c9, c11 = c9;
14+ #if CONV_MR == 4 && CONV_NR == 24
15+ __m256 c00 = _mm256_set1_ps (0 .f ), c01 = c00, c02 = c00;
16+ __m256 c10 = c00, c11 = c00, c12 = c00;
17+ __m256 c20 = c00, c21 = c00, c22 = c00;
18+ __m256 c30 = c00, c31 = c00, c32 = c00;
2219
2320 __m256 a0 = _mm256_setzero_ps (), a1 = _mm256_setzero_ps ();
2421 __m256 b0 = _mm256_setzero_ps (), b1 = _mm256_setzero_ps (), b2 = _mm256_setzero_ps ();
2522
26- for (int p = 0 ; p < k ; p++, a += FAST_CONV_MR , b += FAST_CONV_NR )
23+ for (int p = 0 ; p < np ; p++, a += CONV_MR , b += CONV_NR )
2724 {
2825 a0 = _mm256_set1_ps (a[0 ]), a1 = _mm256_set1_ps (a[1 ]);
2926 b0 = _mm256_load_ps (b), b1 = _mm256_load_ps (b + 8 ), b2 = _mm256_load_ps (b + 16 );
3027
31- c0 = _mm256_fmadd_ps (b0, a0, c0 );
32- c1 = _mm256_fmadd_ps (b1, a0, c1 );
33- c2 = _mm256_fmadd_ps (b2, a0, c2 );
28+ c00 = _mm256_fmadd_ps (b0, a0, c00 );
29+ c01 = _mm256_fmadd_ps (b1, a0, c01 );
30+ c02 = _mm256_fmadd_ps (b2, a0, c02 );
3431
35- c3 = _mm256_fmadd_ps (b0, a1, c3);
36- a0 = _mm256_set1_ps (a[2 ]);
37- c4 = _mm256_fmadd_ps (b1, a1, c4);
38- c5 = _mm256_fmadd_ps (b2, a1, c5);
32+ c10 = _mm256_fmadd_ps (b0, a1, c10);
33+ c11 = _mm256_fmadd_ps (b1, a1, c11);
34+ c12 = _mm256_fmadd_ps (b2, a1, c12);
3935
40- c6 = _mm256_fmadd_ps (b0, a0, c6);
41- a1 = _mm256_set1_ps (a[3 ]);
42- c7 = _mm256_fmadd_ps (b1, a0, c7);
43- c8 = _mm256_fmadd_ps (b2, a0, c8);
36+ a0 = _mm256_set1_ps (a[2 ]), a1 = _mm256_set1_ps (a[3 ]);
4437
45- c9 = _mm256_fmadd_ps (b0, a1, c9);
46- c10 = _mm256_fmadd_ps (b1, a1, c10);
47- c11 = _mm256_fmadd_ps (b2, a1, c11);
38+ c20 = _mm256_fmadd_ps (b0, a0, c20);
39+ c21 = _mm256_fmadd_ps (b1, a0, c21);
40+ c22 = _mm256_fmadd_ps (b2, a0, c22);
41+
42+ c30 = _mm256_fmadd_ps (b0, a1, c30);
43+ c31 = _mm256_fmadd_ps (b1, a1, c31);
44+ c32 = _mm256_fmadd_ps (b2, a1, c32);
4845 }
4946
50- if (ifActiv )
47+ if (!init_c )
5148 {
52- c0 = _mm256_min_ps (_mm256_max_ps (c0, vminval), vmaxval);
53- c1 = _mm256_min_ps (_mm256_max_ps (c1, vminval), vmaxval);
54- c2 = _mm256_min_ps (_mm256_max_ps (c2, vminval), vmaxval);
55- c3 = _mm256_min_ps (_mm256_max_ps (c3, vminval), vmaxval);
56- c4 = _mm256_min_ps (_mm256_max_ps (c4, vminval), vmaxval);
57- c5 = _mm256_min_ps (_mm256_max_ps (c5, vminval), vmaxval);
58- c6 = _mm256_min_ps (_mm256_max_ps (c6, vminval), vmaxval);
59- c7 = _mm256_min_ps (_mm256_max_ps (c7, vminval), vmaxval);
60- c8 = _mm256_min_ps (_mm256_max_ps (c8, vminval), vmaxval);
61- c9 = _mm256_min_ps (_mm256_max_ps (c9, vminval), vmaxval);
62- c10 = _mm256_min_ps (_mm256_max_ps (c10, vminval), vmaxval);
63- c11 = _mm256_min_ps (_mm256_max_ps (c11, vminval), vmaxval);
49+ c00 = _mm256_add_ps (c00, _mm256_load_ps (c));
50+ c01 = _mm256_add_ps (c01, _mm256_load_ps (c + 8 ));
51+ c02 = _mm256_add_ps (c02, _mm256_load_ps (c + 16 ));
52+
53+ c10 = _mm256_add_ps (c10, _mm256_load_ps (c + ldc));
54+ c11 = _mm256_add_ps (c11, _mm256_load_ps (c + ldc + 8 ));
55+ c12 = _mm256_add_ps (c12, _mm256_load_ps (c + ldc + 16 ));
56+
57+ c20 = _mm256_add_ps (c20, _mm256_load_ps (c + ldc*2 ));
58+ c21 = _mm256_add_ps (c21, _mm256_load_ps (c + ldc*2 + 8 ));
59+ c22 = _mm256_add_ps (c22, _mm256_load_ps (c + ldc*2 + 16 ));
60+
61+ c30 = _mm256_add_ps (c30, _mm256_load_ps (c + ldc*3 ));
62+ c31 = _mm256_add_ps (c31, _mm256_load_ps (c + ldc*3 + 8 ));
63+ c32 = _mm256_add_ps (c32, _mm256_load_ps (c + ldc*3 + 16 ));
6464 }
6565
66- _mm256_storeu_ps (c, c0); _mm256_storeu_ps (c+8 , c1); _mm256_storeu_ps (c+16 , c2 );
67- _mm256_storeu_ps (c + ldc, c3); _mm256_storeu_ps (c + ldc + 8 , c4); _mm256_storeu_ps (c + ldc + 16 , c5 );
68- _mm256_storeu_ps (c + ldc*2 , c6); _mm256_storeu_ps (c + ldc*2 + 8 , c7); _mm256_storeu_ps (c + ldc*2 + 16 , c8 );
69- _mm256_storeu_ps (c + ldc*3 , c9); _mm256_storeu_ps (c + ldc*3 + 8 , c10); _mm256_storeu_ps (c + ldc*3 + 16 , c11 );
66+ _mm256_storeu_ps (c, c00), _mm256_storeu_ps (c+8 , c01), _mm256_storeu_ps (c+16 , c02 );
67+ _mm256_storeu_ps (c + ldc, c10), _mm256_storeu_ps (c + ldc + 8 , c11), _mm256_storeu_ps (c + ldc + 16 , c12 );
68+ _mm256_storeu_ps (c + ldc*2 , c20), _mm256_storeu_ps (c + ldc*2 + 8 , c21), _mm256_storeu_ps (c + ldc*2 + 16 , c22 );
69+ _mm256_storeu_ps (c + ldc*3 , c30), _mm256_storeu_ps (c + ldc*3 + 8 , c31), _mm256_storeu_ps (c + ldc*3 + 16 , c32 );
7070 _mm256_zeroupper ();
7171#else
72- #error "unsupported FAST_CONV_MR and/or FAST_CONV_NR in convBlock_AVX2."
72+ #error "unsupported CONV_MR and/or CONV_NR in convBlock_AVX2."
7373#endif
7474}
7575
@@ -78,7 +78,6 @@ void depthWiseBlock_AVX2(const float *inptr, float *outptr, const float *weights
                          int dilation_y, int stride_x, int stride_y, int inner_xleft, int inner_xright, int inner_ytop,
                          int inner_ybottom, bool ifMinMaxAct, bool useSIMD, bool is3x3)
 {
-    const int VECSZ = 8;
     __m256 vminval = _mm256_set1_ps(minval);
     __m256 vmaxval = _mm256_set1_ps(maxval);
@@ -175,7 +174,7 @@ void depthWiseBlock_AVX2(const float *inptr, float *outptr, const float *weights
         {
             if (dy0 == 3)
             {
-                for (; x0 <= x1 - VECSZ; x0 += VECSZ)
+                for (; x0 <= x1 - FAST_VEC_NLANES; x0 += FAST_VEC_NLANES)
                 {
                     int xi_ = x0 * stride_x - pad_left;
                     const float *inptr_xi = inptr + Wi * yi_ + xi_;
@@ -251,7 +250,7 @@ void depthWiseBlock_AVX2(const float *inptr, float *outptr, const float *weights
             }
             else
             {
-                for (; x0 <= x1 - VECSZ; x0 += VECSZ)
+                for (; x0 <= x1 - FAST_VEC_NLANES; x0 += FAST_VEC_NLANES)
                 {
                     int xi_ = x0 * stride_x - pad_left;
                     const float *inptr_xi = inptr + Wi * yi_ + xi_;
277276 }
278277 else
279278 {
280- for (; x0 <= x1 - VECSZ ; x0 += VECSZ )
279+ for (; x0 <= x1 - FAST_VEC_NLANES ; x0 += FAST_VEC_NLANES )
281280 {
282281 int xi_ = x0 * stride_x - pad_left, k = 0 ;
283282 const float *inptr_xi = inptr + Wi * yi_ + xi_;
0 commit comments