Skip to content

Commit 9b82fa0

Browse files
authored
Merge branch 'main' into ci-144167
2 parents 0ab2195 + d0aea54 commit 9b82fa0

11 files changed

Lines changed: 906 additions & 33 deletions

File tree

docs/changelog/144649.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
area: Vector Search
2+
issues: []
3+
pr: 144649
4+
summary: "[Native] `int4` x86 SIMD optimizations"
5+
type: enhancement

libs/native/libraries/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ configurations {
1919
}
2020

2121
var zstdVersion = "1.5.7"
22-
var vecVersion = "1.0.58"
22+
var vecVersion = "1.0.62"
2323

2424
repositories {
2525
exclusiveContent {

libs/simdvec/native/publish_vec_binaries.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then
2020
exit 1;
2121
fi
2222

23-
VERSION="1.0.58"
23+
VERSION="1.0.62"
2424
ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}"
2525
TEMP=$(mktemp -d)
2626

libs/simdvec/native/src/vec/c/amd64/vec_i4_1.cpp

Lines changed: 60 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,31 +22,46 @@
2222
static inline int32_t doti4_inner(const int8_t* query, const int8_t* doc, int32_t packed_len) {
2323
const __m256i mask_half_byte = _mm256_set1_epi8(0x0F);
2424
const __m256i ones = _mm256_set1_epi16(1);
25-
__m256i acc_high = _mm256_setzero_si256();
26-
__m256i acc_low = _mm256_setzero_si256();
25+
__m256i acc = _mm256_setzero_si256();
2726

2827
constexpr int stride = sizeof(__m256i);
2928
const int blk = packed_len & ~(stride - 1);
3029

31-
for (int i = 0; i < blk; i += stride) {
32-
__m256i doc_bytes = _mm256_loadu_si256((const __m256i*)(doc + i));
30+
// maddubs with int4 values produces at most 15*15+15*15 = 450 per 16-bit lane.
31+
// Safe to accumulate floor(32767/450) = 72 iterations before signed 16-bit overflow.
32+
constexpr int chunk = 64 * stride;
3333

34-
// Extract nibbles at 256-bit width.
35-
// _mm256_srli_epi16 shifts 16-bit lanes; the 0x0F mask cleans cross-byte leakage.
36-
__m256i doc_high = _mm256_and_si256(_mm256_srli_epi16(doc_bytes, 4), mask_half_byte);
37-
__m256i doc_low = _mm256_and_si256(doc_bytes, mask_half_byte);
34+
int i = 0;
35+
while (i < blk) {
36+
__m256i acc_high16 = _mm256_setzero_si256();
37+
__m256i acc_low16 = _mm256_setzero_si256();
38+
const int end = std::min(i + chunk, blk);
3839

39-
__m256i query_high = _mm256_loadu_si256((const __m256i*)(query + i));
40-
__m256i query_low = _mm256_loadu_si256((const __m256i*)(query + i + packed_len));
40+
for (; i < end; i += stride) {
41+
__m256i doc_bytes = _mm256_loadu_si256((const __m256i*)(doc + i));
4142

42-
// _mm256_maddubs_epi16 multiplies unsigned*signed byte pairs and horizontally
43-
// adds adjacent products into 16-bit results. Both operands are in [0,15] so
44-
// signedness doesn't matter. _mm256_madd_epi16 with ones reduces to 32-bit.
45-
acc_high = _mm256_add_epi32(acc_high, _mm256_madd_epi16(ones, _mm256_maddubs_epi16(doc_high, query_high)));
46-
acc_low = _mm256_add_epi32(acc_low, _mm256_madd_epi16(ones, _mm256_maddubs_epi16(doc_low, query_low)));
43+
// Extract nibbles at 256-bit width.
44+
// _mm256_srli_epi16 shifts 16-bit lanes; the 0x0F mask cleans cross-byte leakage.
45+
__m256i doc_high = _mm256_and_si256(_mm256_srli_epi16(doc_bytes, 4), mask_half_byte);
46+
__m256i doc_low = _mm256_and_si256(doc_bytes, mask_half_byte);
47+
48+
__m256i query_high = _mm256_loadu_si256((const __m256i*)(query + i));
49+
__m256i query_low = _mm256_loadu_si256((const __m256i*)(query + i + packed_len));
50+
51+
// _mm256_maddubs_epi16 multiplies unsigned*signed byte pairs and horizontally
52+
// adds adjacent products into 16-bit results. Both operands are in [0,15] so
53+
// signedness doesn't matter. Accumulate in 16-bit; widen to 32-bit after the chunk.
54+
acc_high16 = _mm256_add_epi16(acc_high16, _mm256_maddubs_epi16(doc_high, query_high));
55+
acc_low16 = _mm256_add_epi16(acc_low16, _mm256_maddubs_epi16(doc_low, query_low));
56+
}
57+
58+
// Widen 16→32 bit: _mm256_madd_epi16 with ones horizontally adds pairs of
59+
// signed 16-bit values into 32-bit results, then accumulate into the 32-bit total.
60+
acc = _mm256_add_epi32(acc, _mm256_madd_epi16(ones, acc_high16));
61+
acc = _mm256_add_epi32(acc, _mm256_madd_epi16(ones, acc_low16));
4762
}
4863

49-
int32_t total = mm256_reduce_epi32<_mm_add_epi32>(_mm256_add_epi32(acc_high, acc_low));
64+
int32_t total = mm256_reduce_epi32<_mm_add_epi32>(acc);
5065

5166
for (int i = blk; i < packed_len; i++) {
5267
uint8_t doc_byte = (uint8_t)doc[i];
@@ -74,6 +89,7 @@ static inline void doti4_bulk_impl(
7489
const __m256i ones = _mm256_set1_epi16(1);
7590
constexpr int stride = sizeof(__m256i);
7691
const int blk = packed_len & ~(stride - 1);
92+
constexpr int chunk = 64 * stride;
7793
const int lines_to_fetch = packed_len / CACHE_LINE_SIZE + 1;
7894

7995
int c = 0;
@@ -88,31 +104,45 @@ static inline void doti4_bulk_impl(
88104
prefetch(next_doc_ptrs[I], lines_to_fetch);
89105
});
90106

91-
__m256i acc_high[batches];
92-
__m256i acc_low[batches];
107+
__m256i acc32[batches];
93108
apply_indexed<batches>([&](auto I) {
94-
acc_high[I] = _mm256_setzero_si256();
95-
acc_low[I] = _mm256_setzero_si256();
109+
acc32[I] = _mm256_setzero_si256();
96110
});
97111

98112
int i = 0;
99-
for (; i < blk; i += stride) {
100-
__m256i query_high = _mm256_loadu_si256((const __m256i*)(query + i));
101-
__m256i query_low = _mm256_loadu_si256((const __m256i*)(query + i + packed_len));
102-
113+
while (i < blk) {
114+
__m256i acc_high16[batches];
115+
__m256i acc_low16[batches];
103116
apply_indexed<batches>([&](auto I) {
104-
__m256i doc_bytes = _mm256_loadu_si256((const __m256i*)(current_doc_ptrs[I] + i));
105-
__m256i doc_high = _mm256_and_si256(_mm256_srli_epi16(doc_bytes, 4), mask_half_byte);
106-
__m256i doc_low = _mm256_and_si256(doc_bytes, mask_half_byte);
117+
acc_high16[I] = _mm256_setzero_si256();
118+
acc_low16[I] = _mm256_setzero_si256();
119+
});
120+
121+
const int end = std::min(i + chunk, blk);
122+
123+
for (; i < end; i += stride) {
124+
__m256i query_high = _mm256_loadu_si256((const __m256i*)(query + i));
125+
__m256i query_low = _mm256_loadu_si256((const __m256i*)(query + i + packed_len));
107126

108-
acc_high[I] = _mm256_add_epi32(acc_high[I], _mm256_madd_epi16(ones, _mm256_maddubs_epi16(doc_high, query_high)));
109-
acc_low[I] = _mm256_add_epi32(acc_low[I], _mm256_madd_epi16(ones, _mm256_maddubs_epi16(doc_low, query_low)));
127+
apply_indexed<batches>([&](auto I) {
128+
__m256i doc_bytes = _mm256_loadu_si256((const __m256i*)(current_doc_ptrs[I] + i));
129+
__m256i doc_high = _mm256_and_si256(_mm256_srli_epi16(doc_bytes, 4), mask_half_byte);
130+
__m256i doc_low = _mm256_and_si256(doc_bytes, mask_half_byte);
131+
132+
acc_high16[I] = _mm256_add_epi16(acc_high16[I], _mm256_maddubs_epi16(doc_high, query_high));
133+
acc_low16[I] = _mm256_add_epi16(acc_low16[I], _mm256_maddubs_epi16(doc_low, query_low));
134+
});
135+
}
136+
137+
apply_indexed<batches>([&](auto I) {
138+
acc32[I] = _mm256_add_epi32(acc32[I], _mm256_madd_epi16(ones, acc_high16[I]));
139+
acc32[I] = _mm256_add_epi32(acc32[I], _mm256_madd_epi16(ones, acc_low16[I]));
110140
});
111141
}
112142

113143
int32_t res[batches];
114144
apply_indexed<batches>([&](auto I) {
115-
res[I] = mm256_reduce_epi32<_mm_add_epi32>(_mm256_add_epi32(acc_high[I], acc_low[I]));
145+
res[I] = mm256_reduce_epi32<_mm_add_epi32>(acc32[I]);
116146
});
117147

118148
for (; i < packed_len; i++) {
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
// AVX-512 vectorized int4 packed-nibble vector operations.
11+
// The "unpacked" vector has 2*packed_len bytes (high nibbles in [0..packed_len),
12+
// low nibbles in [packed_len..2*packed_len)). The "packed" vector has packed_len
13+
// bytes, each holding two 4-bit values.
14+
15+
#include <stddef.h>
16+
#include <stdint.h>
17+
18+
#ifdef __clang__
19+
#pragma clang attribute push(__attribute__((target("arch=icelake-client"))), apply_to=function)
20+
#elif __GNUC__
21+
#pragma GCC push_options
22+
#pragma GCC target ("arch=icelake-client")
23+
#endif
24+
25+
#include "vec.h"
26+
#include "vec_common.h"
27+
#include "amd64/amd64_vec_common.h"
28+
29+
// Dot product of an "unpacked" int4 query against a "packed" int4 doc vector
// (layouts described in the file header: query high nibbles in [0, packed_len),
// query low nibbles in [packed_len, 2*packed_len), doc holding two 4-bit values
// per byte). Returns the int32 sum over all nibble-pair products.
static inline int32_t doti4_inner_avx512(const int8_t* query, const int8_t* doc, int32_t packed_len) {
    const __m512i mask_half_byte = _mm512_set1_epi8(0x0F);
    const __m512i ones = _mm512_set1_epi16(1);
    __m512i acc = _mm512_setzero_si512();  // per-lane 32-bit running total

    constexpr int stride = sizeof(__m512i);        // 64 doc bytes per iteration
    const int blk = packed_len & ~(stride - 1);    // full-register prefix length

    // maddubs with int4 values produces at most 15*15+15*15 = 450 per 16-bit lane.
    // Safe to accumulate floor(32767/450) = 72 iterations before signed 16-bit overflow.
    // chunk = 64 * stride bytes => 64 inner iterations per chunk, within that bound.
    constexpr int chunk = 64 * stride;

    int i = 0;
    while (i < blk) {
        // Fresh 16-bit accumulators per chunk so they can never overflow.
        __m512i acc_high16 = _mm512_setzero_si512();
        __m512i acc_low16 = _mm512_setzero_si512();
        const int end_raw = i + chunk;
        // TODO: replace with std::min when we solve the gcc #pragma target inline bug
        const int end = end_raw < blk ? end_raw : blk;

        for (; i < end; i += stride) {
            __m512i doc_bytes = _mm512_loadu_si512((const __m512i*)(doc + i));

            // Split each doc byte into its two nibbles. _mm512_srli_epi16 shifts
            // 16-bit lanes, so the 0x0F mask removes the cross-byte leakage.
            __m512i doc_high = _mm512_and_si512(_mm512_srli_epi16(doc_bytes, 4), mask_half_byte);
            __m512i doc_low = _mm512_and_si512(doc_bytes, mask_half_byte);

            __m512i query_high = _mm512_loadu_si512((const __m512i*)(query + i));
            __m512i query_low = _mm512_loadu_si512((const __m512i*)(query + i + packed_len));

            // maddubs multiplies unsigned*signed bytes and adds adjacent products
            // into 16-bit lanes; both operands are nibbles in [0,15], so the
            // signed/unsigned distinction cannot change the result.
            acc_high16 = _mm512_add_epi16(acc_high16, _mm512_maddubs_epi16(doc_high, query_high));
            acc_low16 = _mm512_add_epi16(acc_low16, _mm512_maddubs_epi16(doc_low, query_low));
        }

        // Widen 16 -> 32 bit: madd with ones pair-sums signed 16-bit lanes into
        // 32-bit lanes, which are then folded into the overflow-safe total.
        acc = _mm512_add_epi32(acc, _mm512_madd_epi16(ones, acc_high16));
        acc = _mm512_add_epi32(acc, _mm512_madd_epi16(ones, acc_low16));
    }

    int32_t total = _mm512_reduce_add_epi32(acc);

    // Masked tail: handle remaining bytes that don't fill a full 512-bit register.
    // Masked-off lanes load as zero, contributing nothing to the dot product.
    const int rem = packed_len - blk;
    if (rem > 0) {
        // rem is in [1, 63] here (blk is packed_len rounded down to a multiple
        // of 64), so the shift below is well-defined.
        __mmask64 mask = (__mmask64)((1ULL << rem) - 1);

        __m512i doc_bytes = _mm512_maskz_loadu_epi8(mask, doc + blk);
        __m512i doc_high = _mm512_and_si512(_mm512_srli_epi16(doc_bytes, 4), mask_half_byte);
        __m512i doc_low = _mm512_and_si512(doc_bytes, mask_half_byte);

        __m512i query_high = _mm512_maskz_loadu_epi8(mask, query + blk);
        __m512i query_low = _mm512_maskz_loadu_epi8(mask, query + blk + packed_len);

        // Single iteration, so no 16-bit accumulation is needed: widen to
        // 32-bit immediately and fold into the scalar total.
        __m512i wide = _mm512_add_epi32(
            _mm512_madd_epi16(ones, _mm512_maddubs_epi16(doc_high, query_high)),
            _mm512_madd_epi16(ones, _mm512_maddubs_epi16(doc_low, query_low))
        );
        total += _mm512_reduce_add_epi32(wide);
    }

    return total;
}
90+
91+
// Exported single-vector entry point: int4 dot product of one unpacked query
// against one packed doc, delegating to the AVX-512 kernel above.
EXPORT int32_t vec_doti4_2(const int8_t* query, const int8_t* doc, int32_t packed_len) {
    const int32_t total = doti4_inner_avx512(query, doc, packed_len);
    return total;
}
94+
95+
// batches=2 rather than 4: most CPUs have only 1 port for 512-bit integer
// multiply (vpmaddubsw zmm), so batches>2 saturates that port without
// increasing per-doc throughput, while adding instruction overhead.
//
// Computes int4 dot products of one unpacked query against `count` packed docs,
// writing each result (cast to f32_t) into results[0..count). `mapper` resolves
// a doc index to its base pointer from docs/offsets/pitch (sequential or
// offset-based addressing — see the exported wrappers below).
template <const int8_t*(*mapper)(const int8_t*, const int32_t, const int32_t*, const int32_t), int batches = 2>
static inline void doti4_bulk_impl_avx512(
    const int8_t* docs,
    const int8_t* query,
    int32_t packed_len,
    int32_t pitch,
    const int32_t* offsets,
    int32_t count,
    f32_t* results
) {
    const __m512i mask_half_byte = _mm512_set1_epi8(0x0F);
    const __m512i ones = _mm512_set1_epi16(1);
    constexpr int stride = sizeof(__m512i);
    const int blk = packed_len & ~(stride - 1);
    // Same 16-bit overflow bound as doti4_inner_avx512: 64 iterations per
    // chunk, below the floor(32767/450) = 72 safe limit for int4 maddubs.
    constexpr int chunk = 64 * stride;
    const int lines_to_fetch = packed_len / CACHE_LINE_SIZE + 1;

    // Tail mask is loop-invariant; hoist it. rem <= 63, so the shift is defined.
    const int rem = packed_len - blk;
    const __mmask64 tail_mask = rem > 0 ? (__mmask64)((1ULL << rem) - 1) : 0;

    int c = 0;

    const int8_t* current_doc_ptrs[batches];
    init_pointers<batches, int8_t, int8_t, mapper>(current_doc_ptrs, docs, pitch, offsets, 0, count);

    // Process docs `batches` at a time while a full batch remains.
    for (; c + batches - 1 < count; c += batches) {
        // Prefetch the NEXT batch's docs so they are cache-resident by the time
        // the current batch finishes.
        const int8_t* next_doc_ptrs[batches];
        const bool has_next = c + 2 * batches - 1 < count;
        if (has_next) {
            apply_indexed<batches>([&](auto I) {
                next_doc_ptrs[I] = mapper(docs, c + batches + I, offsets, pitch);
                prefetch(next_doc_ptrs[I], lines_to_fetch);
            });
        }

        __m512i acc32[batches];
        apply_indexed<batches>([&](auto I) {
            acc32[I] = _mm512_setzero_si512();
        });

        int i = 0;
        while (i < blk) {
            // Fresh 16-bit accumulators per chunk so they can never overflow.
            __m512i acc_high16[batches];
            __m512i acc_low16[batches];
            apply_indexed<batches>([&](auto I) {
                acc_high16[I] = _mm512_setzero_si512();
                acc_low16[I] = _mm512_setzero_si512();
            });

            const int end_raw = i + chunk;
            // TODO: replace with std::min when we solve the gcc #pragma target inline bug
            const int end = end_raw < blk ? end_raw : blk;

            for (; i < end; i += stride) {
                // The query loads are shared by every doc in the batch.
                __m512i query_high = _mm512_loadu_si512((const __m512i*)(query + i));
                __m512i query_low = _mm512_loadu_si512((const __m512i*)(query + i + packed_len));

                apply_indexed<batches>([&](auto I) {
                    __m512i doc_bytes = _mm512_loadu_si512((const __m512i*)(current_doc_ptrs[I] + i));
                    // srli works on 16-bit lanes; the 0x0F mask removes cross-byte leakage.
                    __m512i doc_high = _mm512_and_si512(_mm512_srli_epi16(doc_bytes, 4), mask_half_byte);
                    __m512i doc_low = _mm512_and_si512(doc_bytes, mask_half_byte);

                    acc_high16[I] = _mm512_add_epi16(acc_high16[I], _mm512_maddubs_epi16(doc_high, query_high));
                    acc_low16[I] = _mm512_add_epi16(acc_low16[I], _mm512_maddubs_epi16(doc_low, query_low));
                });
            }

            // Widen the chunk's 16-bit sums to 32-bit and fold into the totals.
            apply_indexed<batches>([&](auto I) {
                acc32[I] = _mm512_add_epi32(acc32[I], _mm512_madd_epi16(ones, acc_high16[I]));
                acc32[I] = _mm512_add_epi32(acc32[I], _mm512_madd_epi16(ones, acc_low16[I]));
            });
        }

        int32_t res[batches];
        apply_indexed<batches>([&](auto I) {
            res[I] = _mm512_reduce_add_epi32(acc32[I]);
        });

        // Masked tail for the bytes past `blk`; masked-off lanes read as zero.
        if (tail_mask) {
            __m512i query_high = _mm512_maskz_loadu_epi8(tail_mask, query + blk);
            __m512i query_low = _mm512_maskz_loadu_epi8(tail_mask, query + blk + packed_len);

            apply_indexed<batches>([&](auto I) {
                __m512i doc_bytes = _mm512_maskz_loadu_epi8(tail_mask, current_doc_ptrs[I] + blk);
                __m512i doc_high = _mm512_and_si512(_mm512_srli_epi16(doc_bytes, 4), mask_half_byte);
                __m512i doc_low = _mm512_and_si512(doc_bytes, mask_half_byte);

                __m512i wide = _mm512_add_epi32(
                    _mm512_madd_epi16(ones, _mm512_maddubs_epi16(doc_high, query_high)),
                    _mm512_madd_epi16(ones, _mm512_maddubs_epi16(doc_low, query_low))
                );
                res[I] += _mm512_reduce_add_epi32(wide);
            });
        }

        // TODO: consider replacing with std::copy_n when we solve the gcc #pragma target inline bug
        apply_indexed<batches>([&](auto I) {
            results[c + I] = (f32_t)res[I];
        });
        // Rotate the prefetched pointers in for the next batch.
        if (has_next) {
            apply_indexed<batches>([&](auto I) {
                current_doc_ptrs[I] = next_doc_ptrs[I];
            });
        }
    }

    // Leftover docs (count not a multiple of batches): single-doc kernel.
    for (; c < count; c++) {
        const int8_t* doc = mapper(docs, c, offsets, pitch);
        results[c] = (f32_t)doti4_inner_avx512(query, doc, packed_len);
    }
}
209+
210+
// Exported bulk entry point for sequentially laid-out docs: doc i starts at
// docs + i * packed_len, so the pitch equals packed_len and no offsets are used.
EXPORT void vec_doti4_bulk_2(
    const int8_t* docs,
    const int8_t* query,
    int32_t packed_len,
    int32_t count,
    f32_t* results
) {
    doti4_bulk_impl_avx512<sequential_mapper>(docs, query, packed_len, packed_len, NULL, count, results);
}
213+
214+
// Exported bulk entry point for offset-addressed docs: each doc's location is
// resolved through the offsets table (and pitch) by offsets_mapper.
EXPORT void vec_doti4_bulk_offsets_2(
    const int8_t* docs, const int8_t* query, int32_t packed_len,
    int32_t pitch, const int32_t* offsets, int32_t count, f32_t* results
) {
    doti4_bulk_impl_avx512<offsets_mapper>(docs, query, packed_len, pitch, offsets, count, results);
}
225+
226+
#ifdef __clang__
227+
#pragma clang attribute pop
228+
#elif __GNUC__
229+
#pragma GCC pop_options
230+
#endif

0 commit comments

Comments
 (0)