2727#include " common/object/ob_object.h"
2828#include " sql/session/ob_sql_session_mgr.h"
2929#include " sql/engine/expr/ob_expr_vector.h"
30+ #include " sql/engine/expr/ob_expr_vector_similarity.h"
3031#include " share/allocator/ob_tenant_vector_allocator.h"
3132
3233namespace oceanbase {
@@ -356,8 +357,8 @@ class ObVectorCenterClusterHelper
356357{
357358
358359public:
359- ObVectorCenterClusterHelper (ObIAllocator &allocator, const VEC_T *const_vec, oceanbase::sql::ObExprVectorDistance::ObVecDisType dis_type, int64_t dim, int64_t nprobe, float distance_threshold )
360- : alloc_(allocator), const_vec_(const_vec), dis_type_(dis_type), dim_(dim), nprobe_(nprobe), compare_(dis_type), heap_(compare_), distance_threshold_(distance_threshold )
360+ ObVectorCenterClusterHelper (ObIAllocator &allocator, const VEC_T *const_vec, oceanbase::sql::ObExprVectorDistance::ObVecDisType dis_type, int64_t dim, int64_t nprobe, float similarity_threshold )
361+ : alloc_(allocator), const_vec_(const_vec), dis_type_(dis_type), dim_(dim), nprobe_(nprobe), compare_(dis_type), heap_(compare_), similarity_threshold_(similarity_threshold )
361362 {}
362363
363364 int push_center (const CENTER_T ¢er, VEC_T *center_vec, const int64_t dim, CenterSaveMode center_save_mode = NOT_SAVE_CENTER_VEC );
@@ -368,6 +369,7 @@ class ObVectorCenterClusterHelper
368369 int get_nearest_probe_centers (ObIArray<std::pair<CENTER_T , VEC_T *>> ¢er_ids);
369370 int get_nearest_probe_centers_vec_dist (ObIArray<std::pair<CENTER_T , VEC_T *>> ¢er_ids,
370371 ObIArray<float > &distances);
372+ bool is_satify_similarity_threshold (const double & distance);
371373 int64_t get_center_count () const {
372374 return heap_.count ();
373375 }
@@ -422,7 +424,7 @@ class ObVectorCenterClusterHelper
422424 int64_t nprobe_;
423425 HeapCompare compare_;
424426 CenterHeap heap_;
425- float distance_threshold_ ;
427+ float similarity_threshold_ ;
426428};
427429
428430// ------------------ ObCentersBuffer implement ------------------
@@ -569,58 +571,55 @@ int ObVectorCenterClusterHelper<VEC_T, CENTER_T>::push_center(
569571 VEC_T *center_vec /* = nullptr*/ )
570572{
571573 int ret = OB_SUCCESS ;
572- if (distance > distance_threshold_) {
574+ if (heap_.count () < nprobe_) {
575+ void *ptr = alloc_.alloc (sizeof (ObCenterWithBuf<CENTER_T >));
576+ if (NULL == ptr) {
577+ ret = common::OB_ALLOCATE_MEMORY_FAILED ;
578+ SHARE_LOG (WARN , " no memory for table entity" , K (ret));
579+ } else {
580+ ObCenterWithBuf<CENTER_T > *center_with_buf = new (ptr) ObCenterWithBuf<CENTER_T >(&alloc_);
581+ if (OB_ISNULL (center_with_buf)) {
582+ ret = OB_ERR_UNEXPECTED ;
583+ SHARE_LOG (WARN , " center_entity is null" , K (ret));
584+ } else if (OB_FAIL (center_with_buf->new_from_src (center))) {
585+ SHARE_LOG (WARN , " center_entity fail init" , K (ret));
586+ } else {
587+ HeapCenterItemTemp item (distance, center_with_buf);
588+ if (center_save_mode == DEEP_COPY_CENTER_VEC && OB_FAIL (item.vec_dim_ .new_from_src (alloc_, center_vec, dim_))) {
589+ SHARE_LOG (WARN , " failed to new from src" , K (ret), K (center_vec));
590+ } else if (center_save_mode == SHALLOW_COPY_CENTER_VEC && OB_FALSE_IT (item.vec_dim_ .vec_ = center_vec)) {
591+ }
592+ if (OB_FAIL (ret)) {
593+ } else if (OB_FAIL (heap_.push (item))) {
594+ SHARE_LOG (WARN , " failed to push center heap" , K (ret), K (center), K (distance));
595+ }
596+ }
597+ }
573598 } else {
574- if (heap_.count () < nprobe_) {
575- void *ptr = alloc_.alloc (sizeof (ObCenterWithBuf<CENTER_T >));
576- if (NULL == ptr) {
577- ret = common::OB_ALLOCATE_MEMORY_FAILED ;
578- SHARE_LOG (WARN , " no memory for table entity" , K (ret));
599+ const HeapCenterItemTemp &top = heap_.top ();
600+ ObCenterWithBuf<CENTER_T > tmp_center_with_buf;
601+ HeapCenterItemTemp tmp (distance, &tmp_center_with_buf);
602+ if (compare_ (tmp, top)) {
603+ ObCenterWithBuf<CENTER_T > *old_center_with_buf = top.center_with_buf_ ;
604+ if (OB_ISNULL (old_center_with_buf)) {
605+ ret = OB_ERR_UNEXPECTED ;
606+ SHARE_LOG (WARN , " center_with_buf is null" , K (ret));
607+ } else if (OB_FAIL (old_center_with_buf->new_from_src (center))) {
608+ SHARE_LOG (WARN , " failed to new from src" , K (ret), K (center));
579609 } else {
580- ObCenterWithBuf<CENTER_T > *center_with_buf = new (ptr) ObCenterWithBuf<CENTER_T >(&alloc_);
581- if (OB_ISNULL (center_with_buf)) {
582- ret = OB_ERR_UNEXPECTED ;
583- SHARE_LOG (WARN , " center_entity is null" , K (ret));
584- } else if (OB_FAIL (center_with_buf->new_from_src (center))) {
585- SHARE_LOG (WARN , " center_entity fail init" , K (ret));
586- } else {
587- HeapCenterItemTemp item (distance, center_with_buf);
588- if (center_save_mode == DEEP_COPY_CENTER_VEC && OB_FAIL (item.vec_dim_ .new_from_src (alloc_, center_vec, dim_))) {
610+ HeapCenterItemTemp new_top (distance, old_center_with_buf);
611+ if (center_save_mode == DEEP_COPY_CENTER_VEC ) {
612+ new_top.set_vec_dim (top.vec_dim_ );
613+ if (OB_FAIL (new_top.vec_dim_ .reuse_from_src (center_vec, dim_))) {
589614 SHARE_LOG (WARN , " failed to new from src" , K (ret), K (center_vec));
590- } else if (center_save_mode == SHALLOW_COPY_CENTER_VEC && OB_FALSE_IT (item.vec_dim_ .vec_ = center_vec)) {
591- }
592- if (OB_FAIL (ret)) {
593- } else if (OB_FAIL (heap_.push (item))) {
594- SHARE_LOG (WARN , " failed to push center heap" , K (ret), K (center), K (distance));
595615 }
616+ } else if (center_save_mode == SHALLOW_COPY_CENTER_VEC ) {
617+ new_top.set_vec_dim (top.vec_dim_ );
618+ new_top.vec_dim_ .vec_ = center_vec;
596619 }
597- }
598- } else {
599- const HeapCenterItemTemp &top = heap_.top ();
600- ObCenterWithBuf<CENTER_T > tmp_center_with_buf;
601- HeapCenterItemTemp tmp (distance, &tmp_center_with_buf);
602- if (compare_ (tmp, top)) {
603- ObCenterWithBuf<CENTER_T > *old_center_with_buf = top.center_with_buf_ ;
604- if (OB_ISNULL (old_center_with_buf)) {
605- ret = OB_ERR_UNEXPECTED ;
606- SHARE_LOG (WARN , " center_with_buf is null" , K (ret));
607- } else if (OB_FAIL (old_center_with_buf->new_from_src (center))) {
608- SHARE_LOG (WARN , " failed to new from src" , K (ret), K (center));
609- } else {
610- HeapCenterItemTemp new_top (distance, old_center_with_buf);
611- if (center_save_mode == DEEP_COPY_CENTER_VEC ) {
612- new_top.set_vec_dim (top.vec_dim_ );
613- if (OB_FAIL (new_top.vec_dim_ .reuse_from_src (center_vec, dim_))) {
614- SHARE_LOG (WARN , " failed to new from src" , K (ret), K (center_vec));
615- }
616- } else if (center_save_mode == SHALLOW_COPY_CENTER_VEC ) {
617- new_top.set_vec_dim (top.vec_dim_ );
618- new_top.vec_dim_ .vec_ = center_vec;
619- }
620- if (OB_FAIL (ret)) {
621- } else if (OB_FAIL (heap_.replace_top (new_top))) {
622- SHARE_LOG (WARN , " failed to replace top" , K (ret), K (new_top));
623- }
620+ if (OB_FAIL (ret)) {
621+ } else if (OB_FAIL (heap_.replace_top (new_top))) {
622+ SHARE_LOG (WARN , " failed to replace top" , K (ret), K (new_top));
624623 }
625624 }
626625 }
@@ -636,12 +635,14 @@ int ObVectorCenterClusterHelper<VEC_T, CENTER_T>::get_nearest_probe_center_ids(O
636635 ret = OB_ERR_UNEXPECTED ;
637636 SHARE_LOG (WARN , " max heap count is not equal to nprobe" , K (ret), K (heap_.count ()), K (nprobe_));
638637 }
638+ bool is_satify_distance_threshold = false ;
639639 while (OB_SUCC (ret) && !heap_.empty ()) {
640640 const HeapCenterItemTemp &cur_top = heap_.top ();
641641 if (OB_ISNULL (cur_top.center_with_buf_ )) {
642642 ret = OB_ERR_UNEXPECTED ;
643643 SHARE_LOG (WARN , " center_with_buf is null" , K (ret));
644- } else if (OB_FAIL (center_ids.push_back (cur_top.center_with_buf_ ->get_center ()))) {
644+ } else if (OB_FALSE_IT (is_satify_distance_threshold = is_satify_similarity_threshold (cur_top.distance_ ))) {
645+ } else if ( is_satify_distance_threshold && OB_FAIL (center_ids.push_back (cur_top.center_with_buf_ ->get_center ()))) {
645646 ret = OB_ERR_UNEXPECTED ;
646647 SHARE_LOG (WARN , " failed to push center id" , K (ret), K (cur_top.center_with_buf_ ->get_center ()));
647648 } else if (OB_FAIL (heap_.pop ())) {
@@ -656,6 +657,20 @@ int ObVectorCenterClusterHelper<VEC_T, CENTER_T>::get_nearest_probe_center_ids(O
656657 return ret;
657658}
658659
660+ template <typename VEC_T , typename CENTER_T >
661+ bool ObVectorCenterClusterHelper<VEC_T , CENTER_T >::is_satify_similarity_threshold(const double & distance)
662+ {
663+ bool is_satify = true ;
664+ int ret = OB_SUCCESS ;
665+ float similarity = 0.0 ;
666+ if (similarity_threshold_ != 0.0 && OB_FAIL (oceanbase::sql::ObExprVectorSimilarity::calc_similarity_from_distance (dis_type_, distance, similarity))){
667+ SHARE_LOG (WARN , " get similarity from distance fail" , K (ret));
668+ } else if (similarity < similarity_threshold_){
669+ is_satify = false ;
670+ }
671+ return is_satify;
672+ }
673+
659674template <typename VEC_T , typename CENTER_T >
660675int ObVectorCenterClusterHelper<VEC_T , CENTER_T >::get_nearest_probe_center_ids_dist(ObArrayWrap<bool > &nearest_cid_dist)
661676{
0 commit comments