Skip to content

Commit f443e5a

Browse files
obdevfootka
authored and committed
[CP] fix ivf similarity lost data
Co-authored-by: footka <672528926@qq.com>
1 parent ae25697 commit f443e5a

7 files changed

Lines changed: 110 additions & 78 deletions

File tree

src/share/vector_type/ob_vector_common_util.h

Lines changed: 66 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "common/object/ob_object.h"
2828
#include "sql/session/ob_sql_session_mgr.h"
2929
#include "sql/engine/expr/ob_expr_vector.h"
30+
#include "sql/engine/expr/ob_expr_vector_similarity.h"
3031
#include "share/allocator/ob_tenant_vector_allocator.h"
3132

3233
namespace oceanbase {
@@ -356,8 +357,8 @@ class ObVectorCenterClusterHelper
356357
{
357358

358359
public:
359-
ObVectorCenterClusterHelper(ObIAllocator &allocator, const VEC_T *const_vec, oceanbase::sql::ObExprVectorDistance::ObVecDisType dis_type, int64_t dim, int64_t nprobe, float distance_threshold)
360-
: alloc_(allocator), const_vec_(const_vec), dis_type_(dis_type), dim_(dim), nprobe_(nprobe), compare_(dis_type), heap_(compare_), distance_threshold_(distance_threshold)
360+
// Constructs a helper that collects up to `nprobe` items in a heap ordered by
// a comparator built from `dis_type`.
//
// @param allocator             allocator stored by reference in alloc_
// @param const_vec             query vector pointer, stored as-is (not copied here)
// @param dis_type              distance type; also used to build compare_
// @param dim                   vector dimension
// @param nprobe                heap capacity checked by push_center()
// @param similarity_threshold  minimum similarity used by
//                              is_satify_similarity_threshold(); 0.0 makes that
//                              check accept every distance
//
// compare_ is initialized before heap_ because heap_ is constructed from it.
ObVectorCenterClusterHelper(ObIAllocator &allocator,
                            const VEC_T *const_vec,
                            oceanbase::sql::ObExprVectorDistance::ObVecDisType dis_type,
                            int64_t dim,
                            int64_t nprobe,
                            float similarity_threshold)
    : alloc_(allocator),
      const_vec_(const_vec),
      dis_type_(dis_type),
      dim_(dim),
      nprobe_(nprobe),
      compare_(dis_type),
      heap_(compare_),
      similarity_threshold_(similarity_threshold)
{
}
362363

363364
int push_center(const CENTER_T &center, VEC_T *center_vec, const int64_t dim, CenterSaveMode center_save_mode = NOT_SAVE_CENTER_VEC);
@@ -368,6 +369,7 @@ class ObVectorCenterClusterHelper
368369
int get_nearest_probe_centers(ObIArray<std::pair<CENTER_T, VEC_T *>> &center_ids);
369370
int get_nearest_probe_centers_vec_dist(ObIArray<std::pair<CENTER_T, VEC_T *>> &center_ids,
370371
ObIArray<float> &distances);
372+
bool is_satify_similarity_threshold(const double& distance);
371373
// Returns how many items heap_ currently holds.
int64_t get_center_count() const
{
  const int64_t held_cnt = heap_.count();
  return held_cnt;
}
@@ -422,7 +424,7 @@ class ObVectorCenterClusterHelper
422424
int64_t nprobe_;
423425
HeapCompare compare_;
424426
CenterHeap heap_;
425-
float distance_threshold_;
427+
float similarity_threshold_;
426428
};
427429

428430
// ------------------ ObCentersBuffer implement ------------------
@@ -569,58 +571,55 @@ int ObVectorCenterClusterHelper<VEC_T, CENTER_T>::push_center(
569571
VEC_T *center_vec /*= nullptr*/)
570572
{
571573
int ret = OB_SUCCESS;
572-
if (distance > distance_threshold_) {
574+
if (heap_.count() < nprobe_) {
575+
void *ptr = alloc_.alloc(sizeof(ObCenterWithBuf<CENTER_T>));
576+
if (NULL == ptr) {
577+
ret = common::OB_ALLOCATE_MEMORY_FAILED;
578+
SHARE_LOG(WARN, "no memory for table entity", K(ret));
579+
} else {
580+
ObCenterWithBuf<CENTER_T> *center_with_buf = new (ptr) ObCenterWithBuf<CENTER_T>(&alloc_);
581+
if (OB_ISNULL(center_with_buf)) {
582+
ret = OB_ERR_UNEXPECTED;
583+
SHARE_LOG(WARN, "center_entity is null", K(ret));
584+
} else if (OB_FAIL(center_with_buf->new_from_src(center))) {
585+
SHARE_LOG(WARN, "center_entity fail init", K(ret));
586+
} else {
587+
HeapCenterItemTemp item(distance, center_with_buf);
588+
if (center_save_mode == DEEP_COPY_CENTER_VEC && OB_FAIL(item.vec_dim_.new_from_src(alloc_, center_vec, dim_))) {
589+
SHARE_LOG(WARN, "failed to new from src", K(ret), K(center_vec));
590+
} else if (center_save_mode == SHALLOW_COPY_CENTER_VEC && OB_FALSE_IT(item.vec_dim_.vec_ = center_vec)) {
591+
}
592+
if (OB_FAIL(ret)) {
593+
} else if (OB_FAIL(heap_.push(item))) {
594+
SHARE_LOG(WARN, "failed to push center heap", K(ret), K(center), K(distance));
595+
}
596+
}
597+
}
573598
} else {
574-
if (heap_.count() < nprobe_) {
575-
void *ptr = alloc_.alloc(sizeof(ObCenterWithBuf<CENTER_T>));
576-
if (NULL == ptr) {
577-
ret = common::OB_ALLOCATE_MEMORY_FAILED;
578-
SHARE_LOG(WARN, "no memory for table entity", K(ret));
599+
const HeapCenterItemTemp &top = heap_.top();
600+
ObCenterWithBuf<CENTER_T> tmp_center_with_buf;
601+
HeapCenterItemTemp tmp(distance, &tmp_center_with_buf);
602+
if (compare_(tmp, top)) {
603+
ObCenterWithBuf<CENTER_T> *old_center_with_buf = top.center_with_buf_;
604+
if (OB_ISNULL(old_center_with_buf)) {
605+
ret = OB_ERR_UNEXPECTED;
606+
SHARE_LOG(WARN, "center_with_buf is null", K(ret));
607+
} else if (OB_FAIL(old_center_with_buf->new_from_src(center))) {
608+
SHARE_LOG(WARN, "failed to new from src", K(ret), K(center));
579609
} else {
580-
ObCenterWithBuf<CENTER_T> *center_with_buf = new (ptr) ObCenterWithBuf<CENTER_T>(&alloc_);
581-
if (OB_ISNULL(center_with_buf)) {
582-
ret = OB_ERR_UNEXPECTED;
583-
SHARE_LOG(WARN, "center_entity is null", K(ret));
584-
} else if (OB_FAIL(center_with_buf->new_from_src(center))) {
585-
SHARE_LOG(WARN, "center_entity fail init", K(ret));
586-
} else {
587-
HeapCenterItemTemp item(distance, center_with_buf);
588-
if (center_save_mode == DEEP_COPY_CENTER_VEC && OB_FAIL(item.vec_dim_.new_from_src(alloc_, center_vec, dim_))) {
610+
HeapCenterItemTemp new_top(distance, old_center_with_buf);
611+
if (center_save_mode == DEEP_COPY_CENTER_VEC) {
612+
new_top.set_vec_dim(top.vec_dim_);
613+
if (OB_FAIL(new_top.vec_dim_.reuse_from_src(center_vec, dim_))) {
589614
SHARE_LOG(WARN, "failed to new from src", K(ret), K(center_vec));
590-
} else if (center_save_mode == SHALLOW_COPY_CENTER_VEC && OB_FALSE_IT(item.vec_dim_.vec_ = center_vec)) {
591-
}
592-
if (OB_FAIL(ret)) {
593-
} else if (OB_FAIL(heap_.push(item))) {
594-
SHARE_LOG(WARN, "failed to push center heap", K(ret), K(center), K(distance));
595615
}
616+
} else if (center_save_mode == SHALLOW_COPY_CENTER_VEC) {
617+
new_top.set_vec_dim(top.vec_dim_);
618+
new_top.vec_dim_.vec_ = center_vec;
596619
}
597-
}
598-
} else {
599-
const HeapCenterItemTemp &top = heap_.top();
600-
ObCenterWithBuf<CENTER_T> tmp_center_with_buf;
601-
HeapCenterItemTemp tmp(distance, &tmp_center_with_buf);
602-
if (compare_(tmp, top)) {
603-
ObCenterWithBuf<CENTER_T> *old_center_with_buf = top.center_with_buf_;
604-
if (OB_ISNULL(old_center_with_buf)) {
605-
ret = OB_ERR_UNEXPECTED;
606-
SHARE_LOG(WARN, "center_with_buf is null", K(ret));
607-
} else if (OB_FAIL(old_center_with_buf->new_from_src(center))) {
608-
SHARE_LOG(WARN, "failed to new from src", K(ret), K(center));
609-
} else {
610-
HeapCenterItemTemp new_top(distance, old_center_with_buf);
611-
if (center_save_mode == DEEP_COPY_CENTER_VEC) {
612-
new_top.set_vec_dim(top.vec_dim_);
613-
if (OB_FAIL(new_top.vec_dim_.reuse_from_src(center_vec, dim_))) {
614-
SHARE_LOG(WARN, "failed to new from src", K(ret), K(center_vec));
615-
}
616-
} else if (center_save_mode == SHALLOW_COPY_CENTER_VEC) {
617-
new_top.set_vec_dim(top.vec_dim_);
618-
new_top.vec_dim_.vec_ = center_vec;
619-
}
620-
if (OB_FAIL(ret)) {
621-
} else if (OB_FAIL(heap_.replace_top(new_top))) {
622-
SHARE_LOG(WARN, "failed to replace top", K(ret), K(new_top));
623-
}
620+
if (OB_FAIL(ret)) {
621+
} else if (OB_FAIL(heap_.replace_top(new_top))) {
622+
SHARE_LOG(WARN, "failed to replace top", K(ret), K(new_top));
624623
}
625624
}
626625
}
@@ -636,12 +635,14 @@ int ObVectorCenterClusterHelper<VEC_T, CENTER_T>::get_nearest_probe_center_ids(O
636635
ret = OB_ERR_UNEXPECTED;
637636
SHARE_LOG(WARN, "max heap count is not equal to nprobe", K(ret), K(heap_.count()), K(nprobe_));
638637
}
638+
bool is_satify_distance_threshold = false;
639639
while(OB_SUCC(ret) && !heap_.empty()) {
640640
const HeapCenterItemTemp &cur_top = heap_.top();
641641
if (OB_ISNULL(cur_top.center_with_buf_)) {
642642
ret = OB_ERR_UNEXPECTED;
643643
SHARE_LOG(WARN, "center_with_buf is null", K(ret));
644-
} else if (OB_FAIL(center_ids.push_back(cur_top.center_with_buf_->get_center()))) {
644+
} else if (OB_FALSE_IT(is_satify_distance_threshold = is_satify_similarity_threshold(cur_top.distance_))) {
645+
} else if ( is_satify_distance_threshold && OB_FAIL(center_ids.push_back(cur_top.center_with_buf_->get_center()))) {
645646
ret = OB_ERR_UNEXPECTED;
646647
SHARE_LOG(WARN, "failed to push center id", K(ret), K(cur_top.center_with_buf_->get_center()));
647648
} else if (OB_FAIL(heap_.pop())) {
@@ -656,6 +657,20 @@ int ObVectorCenterClusterHelper<VEC_T, CENTER_T>::get_nearest_probe_center_ids(O
656657
return ret;
657658
}
658659

660+
template <typename VEC_T, typename CENTER_T>
661+
bool ObVectorCenterClusterHelper<VEC_T, CENTER_T>::is_satify_similarity_threshold(const double& distance)
662+
{
663+
bool is_satify = true;
664+
int ret = OB_SUCCESS;
665+
float similarity = 0.0;
666+
if (similarity_threshold_ != 0.0 && OB_FAIL(oceanbase::sql::ObExprVectorSimilarity::calc_similarity_from_distance(dis_type_, distance, similarity))){
667+
SHARE_LOG(WARN, "get similarity from distance fail", K(ret));
668+
} else if (similarity < similarity_threshold_){
669+
is_satify = false;
670+
}
671+
return is_satify;
672+
}
673+
659674
template <typename VEC_T, typename CENTER_T>
660675
int ObVectorCenterClusterHelper<VEC_T, CENTER_T>::get_nearest_probe_center_ids_dist(ObArrayWrap<bool> &nearest_cid_dist)
661676
{

src/sql/das/iter/ob_das_ivf_scan_iter.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -272,11 +272,11 @@ int ObDASIvfBaseScanIter::inner_init(ObDASIterParam &param)
272272
LOG_WARN("build search param fail", K(vec_aux_ctdef_->vector_index_param_), K(vec_aux_ctdef_->vec_query_param_));
273273
} else {
274274
LOG_TRACE("search param", K(vec_aux_ctdef_->vector_index_param_), K(vec_aux_ctdef_->vec_query_param_), K(search_param_));
275-
276-
if (search_param_.similarity_threshold_ != 0) {
277-
if (OB_FAIL(ObDasVecScanUtils::get_distance_threshold_ivf(
278-
*sort_ctdef_->sort_exprs_[0], search_param_.similarity_threshold_, distance_threshold_))) {
279-
LOG_WARN("get distance threshold fail", K(ret));
275+
if (search_param_.similarity_threshold_ > 0) {
276+
if (OB_FAIL(ObDasVecScanUtils::check_ivf_support_similarity_threshold(*sort_ctdef_->sort_exprs_[0]))) {
277+
LOG_WARN("check support similarity threshold fail", K(ret));
278+
} else {
279+
similarity_threshold_ = search_param_.similarity_threshold_;
280280
}
281281
}
282282
}
@@ -1230,7 +1230,7 @@ int ObDASIvfScanIter::get_nearest_probe_center_ids(bool is_vectorized)
12301230
{
12311231
int ret = OB_SUCCESS;
12321232
share::ObVectorCenterClusterHelper<float, ObCenterId> nearest_cid_heap(
1233-
mem_context_->get_arena_allocator(), reinterpret_cast<const float *>(real_search_vec_.ptr()), dis_type_, dim_, nprobes_, FLT_MAX);
1233+
mem_context_->get_arena_allocator(), reinterpret_cast<const float *>(real_search_vec_.ptr()), dis_type_, dim_, nprobes_, 0.0);
12341234
if (OB_FAIL(generate_nearest_cid_heap(is_vectorized, nearest_cid_heap))) {
12351235
LOG_WARN("failed to generate nearest cid heap", K(ret), K(nprobes_), K(dim_), K(real_search_vec_));
12361236
}
@@ -1390,7 +1390,7 @@ int ObDASIvfScanIter::get_nearest_limit_rowkeys_in_cids(bool is_vectorized, T *s
13901390

13911391
int64_t enlargement_factor = (selectivity_ != 0 && selectivity_ != 1) ? POST_ENLARGEMENT_FACTOR : 1;
13921392
share::ObVectorCenterClusterHelper<T, ObRowkey> nearest_rowkey_heap(
1393-
vec_op_alloc_, serch_vec, dis_type_, dim_, get_nprobe(limit_param_, enlargement_factor), distance_threshold_);
1393+
vec_op_alloc_, serch_vec, dis_type_, dim_, get_nprobe(limit_param_, enlargement_factor), similarity_threshold_);
13941394

13951395
const ObDASScanCtDef *cid_vec_ctdef = vec_aux_ctdef_->get_vec_aux_tbl_ctdef(
13961396
vec_aux_ctdef_->get_ivf_cid_vec_tbl_idx(), ObTSCIRScanType::OB_VEC_IVF_CID_VEC_SCAN);
@@ -1567,7 +1567,7 @@ int ObDASIvfScanIter::process_ivf_scan_pre(ObIAllocator &allocator, bool is_vect
15671567
} else if (pre_fileter_rowkeys_.count() < IVF_MAX_BRUTE_FORCE_SIZE) {
15681568
// do brute search
15691569
IvfRowkeyHeap nearest_rowkey_heap(vec_op_alloc_, search_vec, raw_dis_type, dim_,
1570-
get_nprobe(limit_param_, 1), distance_threshold_);
1570+
get_nprobe(limit_param_, 1), similarity_threshold_);
15711571
if (OB_FAIL(get_rowkey_brute_post(is_vectorized, nearest_rowkey_heap))) {
15721572
LOG_WARN("failed to get limit rowkey brute", K(ret));
15731573
} else if (OB_FAIL(nearest_rowkey_heap.get_nearest_probe_center_ids(saved_rowkeys_))) {
@@ -1580,7 +1580,7 @@ int ObDASIvfScanIter::process_ivf_scan_pre(ObIAllocator &allocator, bool is_vect
15801580
// cid_center table is empty, just do brute search
15811581
ret = OB_SUCCESS;
15821582
IvfRowkeyHeap nearest_rowkey_heap(vec_op_alloc_, search_vec /*unused*/, raw_dis_type, dim_,
1583-
get_nprobe(limit_param_, 1), distance_threshold_);
1583+
get_nprobe(limit_param_, 1), similarity_threshold_);
15841584
bool index_end = false;
15851585
while (OB_SUCC(ret) && !index_end) {
15861586
if (OB_FAIL(get_pre_filter_rowkey_batch(mem_context_->get_arena_allocator(), is_vectorized, batch_row_count,
@@ -2118,7 +2118,7 @@ int ObDASIvfPQScanIter::calc_nearest_limit_rowkeys_in_cids(
21182118
int64_t ksub = 1L << nbits_;
21192119
IvfRowkeyHeap nearest_rowkey_heap(
21202120
vec_op_alloc_, search_vec/*unused*/, dis_type_, sub_dim,
2121-
get_nprobe(limit_param_, 1), distance_threshold_); // pq do not need to
2121+
get_nprobe(limit_param_, 1), similarity_threshold_); // pq do not need to
21222122
const ObDASScanCtDef *cid_vec_ctdef = vec_aux_ctdef_->get_vec_aux_tbl_ctdef(
21232123
vec_aux_ctdef_->get_ivf_cid_vec_tbl_idx(), ObTSCIRScanType::OB_VEC_IVF_CID_VEC_SCAN);
21242124
ObDASScanRtDef *cid_vec_rtdef = vec_aux_rtdef_->get_vec_aux_tbl_rtdef(vec_aux_ctdef_->get_ivf_cid_vec_tbl_idx());
@@ -2314,7 +2314,7 @@ int ObDASIvfPQScanIter::get_nearest_probe_centers(bool is_vectorized)
23142314
{
23152315
int ret = OB_SUCCESS;
23162316
share::ObVectorCenterClusterHelper<float, ObCenterId> nearest_cid_heap(
2317-
mem_context_->get_arena_allocator(), reinterpret_cast<const float *>(real_search_vec_.ptr()), dis_type_, dim_, nprobes_, FLT_MAX);
2317+
mem_context_->get_arena_allocator(), reinterpret_cast<const float *>(real_search_vec_.ptr()), dis_type_, dim_, nprobes_, 0.0);
23182318
if (OB_FAIL(generate_nearest_cid_heap(is_vectorized, nearest_cid_heap, true/*save_center_vec*/))) {
23192319
LOG_WARN("failed to generate nearest cid heap", K(ret), K(nprobes_), K(dim_), K(real_search_vec_));
23202320
}
@@ -2527,7 +2527,7 @@ int ObDASIvfPQScanIter::process_ivf_scan_post(bool is_vectorized)
25272527
if (ret == OB_ENTRY_NOT_EXIST) {
25282528
float *search_vec = reinterpret_cast<float *>(real_search_vec_.ptr());
25292529
ObExprVectorDistance::ObVecDisType raw_dis_type = !need_norm_ ? dis_type_ : ObExprVectorDistance::ObVecDisType::COSINE;
2530-
IvfRowkeyHeap nearest_rowkey_heap(vec_op_alloc_, search_vec/*unused*/, raw_dis_type, dim_, get_nprobe(limit_param_, 1), distance_threshold_);
2530+
IvfRowkeyHeap nearest_rowkey_heap(vec_op_alloc_, search_vec/*unused*/, raw_dis_type, dim_, get_nprobe(limit_param_, 1), similarity_threshold_);
25312531
if (OB_FAIL(get_rowkey_brute_post(is_vectorized, nearest_rowkey_heap))) {
25322532
LOG_WARN("failed to get limit rowkey brute", K(ret));
25332533
} else if (OB_FAIL(nearest_rowkey_heap.get_nearest_probe_center_ids(saved_rowkeys_))) {
@@ -2733,7 +2733,7 @@ int ObDASIvfPQScanIter::process_ivf_scan_pre(ObIAllocator &allocator, bool is_ve
27332733
if (ret == OB_ENTRY_NOT_EXIST) {
27342734
// cid_center table is empty, just do brute search
27352735
ret = OB_SUCCESS;
2736-
IvfRowkeyHeap nearest_rowkey_heap(vec_op_alloc_, search_vec/*unused*/, raw_dis_type, dim_, get_nprobe(limit_param_, 1), distance_threshold_);
2736+
IvfRowkeyHeap nearest_rowkey_heap(vec_op_alloc_, search_vec/*unused*/, raw_dis_type, dim_, get_nprobe(limit_param_, 1), similarity_threshold_);
27372737
bool index_end = false;
27382738
while (OB_SUCC(ret) && !index_end) {
27392739
if (OB_FAIL(get_pre_filter_rowkey_batch(mem_context_->get_arena_allocator(), is_vectorized, batch_row_count,
@@ -2757,7 +2757,7 @@ int ObDASIvfPQScanIter::process_ivf_scan_pre(ObIAllocator &allocator, bool is_ve
27572757
LOG_WARN("failed to get rowkey pre filter", K(ret), K(is_vectorized));
27582758
} else if (pre_fileter_rowkeys_.count() < IVF_MAX_BRUTE_FORCE_SIZE) {
27592759
// do brute search
2760-
IvfRowkeyHeap nearest_rowkey_heap(vec_op_alloc_, search_vec/*unused*/, raw_dis_type, dim_, get_nprobe(limit_param_, 1), distance_threshold_);
2760+
IvfRowkeyHeap nearest_rowkey_heap(vec_op_alloc_, search_vec/*unused*/, raw_dis_type, dim_, get_nprobe(limit_param_, 1), similarity_threshold_);
27612761
if (OB_FAIL(get_rowkey_brute_post(is_vectorized, nearest_rowkey_heap))) {
27622762
LOG_WARN("failed to get limit rowkey brute", K(ret));
27632763
} else if (OB_FAIL(nearest_rowkey_heap.get_nearest_probe_center_ids(saved_rowkeys_))) {

src/sql/das/iter/ob_das_ivf_scan_iter.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ class ObDASIvfBaseScanIter : public ObDASIter
200200
vec_aux_rtdef_(nullptr),
201201
saved_rowkeys_itr_(nullptr),
202202
search_param_(),
203-
distance_threshold_(FLT_MAX)
203+
similarity_threshold_(0)
204204
{
205205
dis_type_ = ObExprVectorDistance::ObVecDisType::MAX_TYPE;
206206
saved_rowkeys_.set_attr(ObMemAttr(MTL_ID(), "VecIdxKeyRanges"));
@@ -379,7 +379,7 @@ class ObDASIvfBaseScanIter : public ObDASIter
379379
common::ObSEArray<common::ObRowkey, 16> saved_rowkeys_;
380380
common::ObSEArray<common::ObRowkey, 16> pre_fileter_rowkeys_;
381381
ObVectorIndexParam search_param_;
382-
float distance_threshold_;
382+
float similarity_threshold_;
383383
};
384384

385385
class ObDASIvfScanIter : public ObDASIvfBaseScanIter

src/sql/das/iter/ob_das_vec_scan_utils.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,23 +119,17 @@ int ObDasVecScanUtils::get_distance_threshold_hnsw(ObExpr &expr,
119119
return ret;
120120
}
121121

122-
int ObDasVecScanUtils::get_distance_threshold_ivf(ObExpr &expr,
123-
float &similarity_threshold,
124-
float &distance_threshold)
122+
int ObDasVecScanUtils::check_ivf_support_similarity_threshold(ObExpr &expr)
125123
{
126124
int ret = OB_SUCCESS;
127125

128126
switch (expr.type_) {
129127
case T_FUN_SYS_L2_DISTANCE:
130-
// l2_similarity = 1 / (1 + l2_square_distance), ob use l2_distance
131-
distance_threshold = sqrt(1 / similarity_threshold - 1);
132128
break;
133129
// currently we don't support ip similarity
134130
// case T_FUN_SYS_INNER_PRODUCT:
135131
// case T_FUN_SYS_NEGATIVE_INNER_PRODUCT:
136132
case T_FUN_SYS_COSINE_DISTANCE:
137-
// cosine_similarity = (1 + cosine) / 2, ob cosine_distance = 1 - cosine
138-
distance_threshold = 2 - 2 * similarity_threshold;
139133
break;
140134
default:
141135
ret = OB_NOT_SUPPORTED;

src/sql/das/iter/ob_das_vec_scan_utils.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,7 @@ class ObDasVecScanUtils
8585
static int get_distance_threshold_hnsw(ObExpr &expr,
8686
float &similarity_threshold,
8787
float &distance_threshold);
88-
static int get_distance_threshold_ivf(ObExpr &expr,
89-
float &similarity_threshold,
90-
float &distance_threshold);
88+
static int check_ivf_support_similarity_threshold(ObExpr &expr);
9189
};
9290

9391
} // namespace sql

src/sql/engine/expr/ob_expr_vector_similarity.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,5 +204,30 @@ int ObExprVectorIPSimilarity::calc_ip_similarity(const ObExpr &expr, ObEvalCtx &
204204
return ObExprVectorSimilarity::calc_similarity(expr, ctx, res_datum, ObVecSimilarityType::DOT);
205205
}
206206

207+
// Converts a distance value into a similarity score for the given distance
// type.
//
// @param dis_type    distance type that `distance` was computed with
// @param distance    input distance
// @param similarity  output; written only for supported distance types
// @return OB_SUCCESS on success, OB_NOT_SUPPORTED for any other distance type
//
// Formulas:
//   EUCLIDEAN: l2_similarity = 1 / (1 + l2_square_distance); the input is the
//              plain l2 distance, so it is squared here.
//   DOT:       similarity = (1 + distance) / 2.
//              NOTE(review): the original comment next to this branch says ip
//              similarity is currently not supported - confirm whether this
//              branch is meant to be reachable.
//   COSINE:    cosine_similarity = (1 + cosine) / 2 with
//              cosine_distance = 1 - cosine, i.e. (2 - distance) / 2.
int ObExprVectorSimilarity::calc_similarity_from_distance(const ObExprVectorDistance::ObVecDisType dis_type, const float &distance, float &similarity)
{
  int ret = OB_SUCCESS;
  if (ObExprVectorDistance::ObVecDisType::EUCLIDEAN == dis_type) {
    similarity = 1 / (1 + distance * distance);
  } else if (ObExprVectorDistance::ObVecDisType::DOT == dis_type) {
    similarity = (1 + distance) / 2;
  } else if (ObExprVectorDistance::ObVecDisType::COSINE == dis_type) {
    similarity = (2 - distance) / 2;
  } else {
    ret = OB_NOT_SUPPORTED;
    LOG_WARN("not support vector sort expr", K(ret), K(dis_type));
  }
  return ret;
}
231+
207232
} // sql
208233
} // oceanbase

0 commit comments

Comments
 (0)