Skip to content

Commit 352484e

Browse files
MB-65473: Batch converter for vector to cluster IDs (#49)
- An API was introduced in #48 that returns the cluster containing a given vector in an IVF index.
- This PR modifies that API to support batch processing: the caller provides multiple vector IDs and retrieves their corresponding cluster (list) IDs from an IVF index in a single call. The function maps each vector ID to the inverted list/cluster that stores it.
- Fix the parameter types of `faiss_Search_closest_eligible_centroids` (`n` and `k` are now `idx_t`, and `query` is now `const float*`).
- Ensure errors are correctly handled in the introduced API: it now returns an `int` status code and writes its results to an output array.
1 parent 14a4a60 commit 352484e

File tree

4 files changed

+30
-22
lines changed

4 files changed

+30
-22
lines changed

c_api/IndexIVF_c_ex.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ int faiss_SearchParametersIVF_new_with_sel(
3636

3737
int faiss_Search_closest_eligible_centroids(
3838
FaissIndex* index,
39-
int n,
40-
float* query,
41-
int k,
39+
idx_t n,
40+
const float* query,
41+
idx_t k,
4242
float* centroid_distances,
4343
idx_t* centroid_ids,
4444
const FaissSearchParameters* params) {
@@ -52,11 +52,13 @@ int faiss_Search_closest_eligible_centroids(
5252
CATCH_AND_HANDLE
5353
}
5454

55-
idx_t faiss_get_list_for_key(
56-
FaissIndexIVF* index,
57-
idx_t key) {
55+
int faiss_get_lists_for_keys(
56+
FaissIndexIVF* index,
57+
idx_t* keys,
58+
size_t n_keys,
59+
idx_t* lists) {
5860
try {
59-
return reinterpret_cast<IndexIVF*>(index)->get_list_for_key(key);
61+
reinterpret_cast<IndexIVF*>(index)->get_lists_for_keys(keys, n_keys, lists);
6062
}
6163
CATCH_AND_HANDLE
6264
}

c_api/IndexIVF_c_ex.h

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ int faiss_SearchParametersIVF_new_with_sel(
3737
*/
3838
int faiss_Search_closest_eligible_centroids(
3939
FaissIndex* index,
40-
int n,
41-
float* query,
42-
int k,
40+
idx_t n,
41+
const float* query,
42+
idx_t k,
4343
float* centroid_distances,
4444
idx_t* centroid_ids,
4545
const FaissSearchParameters* params
@@ -91,17 +91,22 @@ int faiss_IndexIVF_compute_distance_to_codes_for_list(
9191
const uint8_t* codes,
9292
float* dists);
9393

94-
/*
95-
Given a query vector ID `key`, return the list number
96-
where the vector is stored. In the context of an IVF index,
97-
this corresponds to the cluster that contains the vector.
98-
99-
@param key - vector ID
94+
/*
95+
Given multiple vector IDs, retrieve the corresponding list (cluster) IDs
96+
from an IVF index. This function maps vector IDs to their
97+
respective inverted lists/clusters in a batch operation.
98+
99+
@param index - Pointer to the Faiss IVF index
100+
@param keys - Input array of vector IDs (keys)
101+
@param n_keys - Number of vector keys in the input array
102+
@param lists - Output array (at least n_keys entries) where the corresponding cluster (list) IDs are stored
100103
*/
101104

102-
idx_t faiss_get_list_for_key(
105+
int faiss_get_lists_for_keys(
103106
FaissIndexIVF* index,
104-
idx_t key);
107+
idx_t* keys,
108+
size_t n_keys,
109+
idx_t* lists);
105110

106111
#ifdef __cplusplus
107112
}

faiss/IndexIVF.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -922,9 +922,10 @@ void IndexIVF::reconstruct(idx_t key, float* recons) const {
922922
reconstruct_from_offset(lo_listno(lo), lo_offset(lo), recons);
923923
}
924924

925-
idx_t IndexIVF::get_list_for_key(idx_t key) {
926-
idx_t lo = direct_map.get(key);
927-
return lo_listno(lo);
925+
void IndexIVF::get_lists_for_keys(idx_t* keys, size_t n_keys, idx_t* lists) {
926+
for (int i = 0; i < n_keys; i++) {
927+
lists[i] = lo_listno(direct_map.get(keys[i]));
928+
}
928929
}
929930

930931
void IndexIVF::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {

faiss/IndexIVF.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ struct IndexIVF : Index, IndexIVFInterface {
343343
*/
344344
void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
345345

346-
idx_t get_list_for_key(idx_t key);
346+
void get_lists_for_keys(idx_t* keys, size_t n_keys, idx_t* lists);
347347

348348
/** Similar to search, but also reconstructs the stored vectors (or an
349349
* approximation in the case of lossy coding) for the search results.

0 commit comments

Comments
 (0)