Skip to content

Commit 8b07ab1

Browse files
hnwyllmmobdevchen20183000177
authored
[CP] fix: resolve 4016 error in semantic index refresh due to empty model_name
Co-authored-by: obdev <obdev@oceanbase.com> Co-authored-by: chen20183000177 <861521087@qq.com>
1 parent 443c03a commit 8b07ab1

3 files changed

Lines changed: 65 additions & 47 deletions

File tree

src/share/vector_index/ob_hybrid_vector_refresh_task.cpp

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "src/storage/ls/ob_ls.h"
2222
#include "src/storage/tx/ob_trans_service.h"
2323
#include "sql/das/ob_das_dml_vec_iter.h"
24+
#include "sql/engine/expr/ob_expr_ai/ob_ai_func_utils.h"
2425

2526
namespace oceanbase
2627
{
@@ -220,7 +221,7 @@ int ObHybridVectorRefreshTask::do_work()
220221
exec_finish = true;
221222
break;
222223
}
223-
default :
224+
default :
224225
ret = OB_ERR_UNEXPECTED;
225226
LOG_WARN("unexpected task status", K(ret), K(current_status()), KPC(get_task_ctx()));
226227
break;
@@ -435,9 +436,9 @@ int ObHybridVectorRefreshTask::get_embedded_table_column_ids(ObPluginVectorIndex
435436
return ret;
436437
}
437438

438-
int ObHybridVectorRefreshTask::init_dml_param(uint64_t table_id,
439-
ObDMLBaseParam &dml_param,
440-
share::schema::ObTableDMLParam &table_param,
439+
int ObHybridVectorRefreshTask::init_dml_param(uint64_t table_id,
440+
ObDMLBaseParam &dml_param,
441+
share::schema::ObTableDMLParam &table_param,
441442
ObIArray<uint64_t> &dml_column_ids,
442443
transaction::ObTxDesc *tx_desc,
443444
oceanbase::transaction::ObTxReadSnapshot &snapshot,
@@ -486,6 +487,8 @@ int ObHybridVectorRefreshTask::init_dml_param(uint64_t table_id,
486487
int ObHybridVectorRefreshTask::init_endpoint(ObPluginVectorIndexAdaptor &adaptor)
487488
{
488489
int ret = OB_SUCCESS;
490+
bool use_request_model_name = false;
491+
ObAIFuncExprInfo *ai_func_info = nullptr;
489492
omt::ObTenantAiService *ai_service = MTL(omt::ObTenantAiService *);
490493
ObHybridVectorRefreshTaskCtx *task_ctx = static_cast<ObHybridVectorRefreshTaskCtx *>(get_task_ctx());
491494
if (OB_ISNULL(ai_service) || OB_ISNULL(task_ctx)) {
@@ -495,6 +498,16 @@ int ObHybridVectorRefreshTask::init_endpoint(ObPluginVectorIndexAdaptor &adaptor
495498
LOG_WARN("failed to get ai service guard", K(ret), KPC(task_ctx));
496499
} else if (OB_FAIL(task_ctx->ai_service_.get_ai_endpoint_by_ai_model_name(adaptor.get_endpoint(), task_ctx->endpoint_, false /*need_check*/))) {
497500
LOG_WARN("failed to get endpoint info", K(ret), K(adaptor));
501+
} else if (OB_FALSE_IT(use_request_model_name = !task_ctx->endpoint_->get_request_model_name().empty())) {
502+
} else if (use_request_model_name && OB_FAIL(ob_write_string(task_ctx->allocator_, task_ctx->endpoint_->get_request_model_name(), task_ctx->request_model_name_))) {
503+
LOG_WARN("failed to copy request_model_name", K(ret));
504+
} else if (!use_request_model_name && OB_FAIL(ObAIFuncUtils::get_ai_func_info(task_ctx->allocator_, adaptor.get_endpoint(), ai_func_info))) {
505+
LOG_WARN("failed to get ai func info", K(ret), K(adaptor.get_endpoint()));
506+
} else if (!use_request_model_name && OB_ISNULL(ai_func_info)) {
507+
ret = OB_ERR_UNEXPECTED;
508+
LOG_WARN("ai func info is null", K(ret));
509+
} else if (!use_request_model_name && OB_FAIL(ob_write_string(task_ctx->allocator_, ai_func_info->model_, task_ctx->request_model_name_))) {
510+
LOG_WARN("failed to copy model_name from ai func info", K(ret));
498511
}
499512
return ret;
500513
}
@@ -544,7 +557,7 @@ int ObHybridVectorRefreshTask::prepare_for_embedding(ObPluginVectorIndexAdaptor
544557
} else if (FALSE_IT(table_param = new(table_param)schema::ObTableParam(task_ctx->allocator_))) {
545558
} else if (FALSE_IT(ctx_->task_status_.target_scn_.convert_from_ts(ObTimeUtility::current_time()))) {
546559
} else if (OB_FAIL(ObPluginVectorIndexUtils::read_local_tablet(ls_id_,
547-
&adaptor,
560+
&adaptor,
548561
ctx_->task_status_.target_scn_,
549562
INDEX_TYPE_VEC_DELTA_BUFFER_LOCAL,
550563
task_ctx->allocator_,
@@ -570,7 +583,7 @@ int ObHybridVectorRefreshTask::prepare_for_embedding(ObPluginVectorIndexAdaptor
570583
task_ctx->batch_cnt_ = MAX(task_ctx->batch_cnt_ / 4, ObHybridVectorRefreshTaskCtx::MIN_BATCH_CNT);
571584
}
572585
}
573-
586+
574587
int cur_row_count = 0;
575588
ObSEArray<ObString, 4> chunk_array;
576589
ObSEArray<ObString, 4> tmp_chunk_array;
@@ -663,7 +676,7 @@ int ObHybridVectorRefreshTask::prepare_for_embedding(ObPluginVectorIndexAdaptor
663676
LOG_WARN("failed to get access key", K(ret));
664677
} else if (OB_FAIL(ob_write_string(task_ctx->allocator_, endpoint->get_url(), url, true))) {
665678
LOG_WARN("fail to write string", K(ret));
666-
} else if (OB_FAIL(task_ctx->embedding_task_->init(url, endpoint->get_request_model_name(),
679+
} else if (OB_FAIL(task_ctx->embedding_task_->init(url, task_ctx->request_model_name_,
667680
endpoint->get_provider(), access_key, chunk_array, dim, http_timeout_us, http_max_retries))) {
668681
LOG_WARN("failed to init embedding task", K(ret), KPC(endpoint));
669682
} else {
@@ -703,10 +716,10 @@ int ObHybridVectorRefreshTask::check_embedding_finish(bool &finish)
703716
}
704717

705718
int ObHybridVectorRefreshTask::do_refresh_only(
706-
ObPluginVectorIndexAdaptor &adaptor,
707-
transaction::ObTxDesc *tx_desc,
708-
oceanbase::transaction::ObTxReadSnapshot &snapshot,
709-
storage::ObStoreCtxGuard &store_ctx_guard,
719+
ObPluginVectorIndexAdaptor &adaptor,
720+
transaction::ObTxDesc *tx_desc,
721+
oceanbase::transaction::ObTxReadSnapshot &snapshot,
722+
storage::ObStoreCtxGuard &store_ctx_guard,
710723
storage::ObValueRowIterator &index_id_iter,
711724
storage::ObValueRowIterator &delta_delete_iter)
712725
{
@@ -734,7 +747,7 @@ int ObHybridVectorRefreshTask::do_refresh_only(
734747
LOG_WARN("failed to insert rows to index id table", K(ret), K(adaptor.get_vbitmap_table_id()));
735748
}
736749
store_ctx_guard.reset();
737-
750+
738751
// delete from 3 table.
739752
affected_rows = 0;
740753
if (OB_FAIL(ret)) {
@@ -843,7 +856,7 @@ int ObHybridVectorRefreshTask::delete_embedded_table(ObPluginVectorIndexAdaptor
843856
common::ObNewRowIterator *scan_iter = nullptr;
844857
ObStorageDatumUtils util;
845858
ObArenaAllocator scan_allocator("VecEmbedding", OB_MALLOC_NORMAL_BLOCK_SIZE, MTL_ID());
846-
if (OB_FAIL(ObPluginVectorIndexUtils::read_local_tablet(ls_id_,
859+
if (OB_FAIL(ObPluginVectorIndexUtils::read_local_tablet(ls_id_,
847860
&adaptor,
848861
snapshot.version(),
849862
INDEX_TYPE_HYBRID_INDEX_EMBEDDED_LOCAL,
@@ -975,7 +988,7 @@ int ObHybridVectorRefreshTask::after_embedding(ObPluginVectorIndexAdaptor &adapt
975988
int64_t loop_cnt = 0;
976989
if (OB_FAIL(new_row.init(task_ctx->embedded_table_column_ids_.count()))) {
977990
LOG_WARN("fail to init datum row", K(ret), K(task_ctx->embedded_table_column_ids_), K(new_row));
978-
} else if (adaptor.get_is_need_vid() && OB_FAIL(ObPluginVectorIndexUtils::read_local_tablet(ls_id_,
991+
} else if (adaptor.get_is_need_vid() && OB_FAIL(ObPluginVectorIndexUtils::read_local_tablet(ls_id_,
979992
&adaptor,
980993
ctx_->task_status_.target_scn_,
981994
INDEX_TYPE_VEC_VID_ROWKEY_LOCAL,
@@ -1053,7 +1066,7 @@ int ObHybridVectorRefreshTask::after_embedding(ObPluginVectorIndexAdaptor &adapt
10531066
ObTableScanIterator *embedded_table_scan_iter = nullptr;
10541067
ObArenaAllocator embedde_scan_allocator("VecEmbedding", OB_MALLOC_NORMAL_BLOCK_SIZE, MTL_ID());
10551068
ObRowkey rowkey(obj_ptr, embedded_rowkey_count);
1056-
if (OB_FAIL(ObPluginVectorIndexUtils::read_local_tablet(ls_id_,
1069+
if (OB_FAIL(ObPluginVectorIndexUtils::read_local_tablet(ls_id_,
10571070
&adaptor,
10581071
snapshot.version(),
10591072
INDEX_TYPE_HYBRID_INDEX_EMBEDDED_LOCAL,
@@ -1098,7 +1111,7 @@ int ObHybridVectorRefreshTask::after_embedding(ObPluginVectorIndexAdaptor &adapt
10981111
}
10991112
}
11001113
}
1101-
1114+
11021115
CHECK_TASK_CANCELLED_IN_PROCESS(ret, loop_cnt, ctx_);
11031116
}
11041117
if (OB_NOT_NULL(tsc_service)) {

src/share/vector_index/ob_hybrid_vector_refresh_task.h

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
#ifndef OCEANBASE_OBSERVER_TABLE_OB_HYBRID_VECTOR_REFRESH_TASK_H_
1818
#define OCEANBASE_OBSERVER_TABLE_OB_HYBRID_VECTOR_REFRESH_TASK_H_
19-
19+
2020
#include "lib/string/ob_string.h"
2121
#include "share/scn.h"
2222
#include "lib/thread/thread_mgr_interface.h"
@@ -29,7 +29,7 @@
2929
#include "share/ai_service/ob_ai_service_struct.h"
3030
#include "storage/ob_value_row_iterator.h"
3131
#include "src/observer/omt/ob_tenant_ai_service.h"
32-
32+
3333
namespace oceanbase
3434
{
3535
namespace share
@@ -44,14 +44,14 @@ enum ObHybridVectorRefreshTaskStatus
4444
TASK_FINISH = 4,
4545
};
4646

47-
class ObVecEmbeddingAsyncTaskExecutor final : public ObVecAsyncTaskExector
47+
class ObVecEmbeddingAsyncTaskExecutor final : public ObVecAsyncTaskExector
4848
{
49-
public:
49+
public:
5050
ObVecEmbeddingAsyncTaskExecutor() : ObVecAsyncTaskExector()
5151
{}
5252
virtual ~ObVecEmbeddingAsyncTaskExecutor() {}
5353
virtual int load_task(uint64_t &task_trace_base_num) override;
54-
private:
54+
private:
5555
bool check_operation_allow() override;
5656
};
5757

@@ -104,6 +104,7 @@ struct ObHybridVectorRefreshTaskCtx : public ObVecIndexAsyncTaskCtx
104104
ObSEArray<uint64_t, 4> embedded_table_update_ids_;
105105
omt::ObAiServiceGuard ai_service_;
106106
const ObAiModelEndpointInfo *endpoint_;
107+
ObString request_model_name_;
107108
ObPluginVectorIndexAdapterGuard adp_guard_;
108109
bool task_started_;
109110
uint32_t part_key_num_; // is part key but rowkey
@@ -123,7 +124,7 @@ class ObHybridVectorRefreshTask : public ObVecIndexIAsyncTask
123124
}
124125
}
125126
virtual int do_work() override;
126-
ObHybridVectorRefreshTaskStatus current_status() {
127+
ObHybridVectorRefreshTaskStatus current_status() {
127128
ObHybridVectorRefreshTaskStatus status = ObHybridVectorRefreshTaskStatus::INVALID_TASK_STATUS;
128129
ObHybridVectorRefreshTaskCtx *ctx = static_cast<ObHybridVectorRefreshTaskCtx *>(get_task_ctx());
129130
if (OB_NOT_NULL(ctx)) {
@@ -146,21 +147,21 @@ class ObHybridVectorRefreshTask : public ObVecIndexIAsyncTask
146147
int get_index_id_column_ids(ObPluginVectorIndexAdaptor &adaptor);
147148
int get_embedded_table_column_ids(ObPluginVectorIndexAdaptor &adaptor);
148149
int init_dml_param(
149-
uint64_t table_id,
150-
ObDMLBaseParam &dml_param,
151-
share::schema::ObTableDMLParam &table_param,
152-
ObIArray<uint64_t> &dml_column_ids,
153-
transaction::ObTxDesc *tx_desc,
154-
oceanbase::transaction::ObTxReadSnapshot &snapshot,
150+
uint64_t table_id,
151+
ObDMLBaseParam &dml_param,
152+
share::schema::ObTableDMLParam &table_param,
153+
ObIArray<uint64_t> &dml_column_ids,
154+
transaction::ObTxDesc *tx_desc,
155+
oceanbase::transaction::ObTxReadSnapshot &snapshot,
155156
storage::ObStoreCtxGuard &store_ctx_guard);
156157
int init_endpoint(ObPluginVectorIndexAdaptor &adaptor);
157158
int prepare_for_embedding(ObPluginVectorIndexAdaptor &adaptor);
158159
int prepare_index_id_data(storage::ObValueRowIterator &index_id_iter, storage::ObValueRowIterator &delta_delete_iter);
159160
int do_refresh_only(
160-
ObPluginVectorIndexAdaptor &adaptor,
161-
transaction::ObTxDesc *tx_desc,
162-
oceanbase::transaction::ObTxReadSnapshot &snapshot,
163-
storage::ObStoreCtxGuard &store_ctx_guard,
161+
ObPluginVectorIndexAdaptor &adaptor,
162+
transaction::ObTxDesc *tx_desc,
163+
oceanbase::transaction::ObTxReadSnapshot &snapshot,
164+
storage::ObStoreCtxGuard &store_ctx_guard,
164165
storage::ObValueRowIterator &index_id_iter,
165166
storage::ObValueRowIterator &delta_delete_iter);
166167
int delete_embedded_table(ObPluginVectorIndexAdaptor &adaptor, transaction::ObTxDesc *tx_desc, oceanbase::transaction::ObTxReadSnapshot &snapshot, storage::ObStoreCtxGuard &store_ctx_guard);

src/storage/ddl/ob_hnsw_embedmgr.cpp

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ bool ObEmbeddingConfig::is_valid() const
3535
}
3636
return is_valid;
3737
}
38-
void ObEmbeddingConfig::set_config(const ObString &model_url, const ObString &model_name,
38+
void ObEmbeddingConfig::set_config(const ObString &model_url, const ObString &model_name,
3939
const ObString &user_key, const ObString &provider)
4040
{
4141
model_url_ = model_url;
@@ -210,7 +210,7 @@ int ObTaskBatchInfo::init(const int64_t batch_size, const int64_t vec_dim)
210210
batch_size_ = batch_size;
211211
vec_dim_ = vec_dim;
212212
current_count_ = 0;
213-
213+
214214
if (OB_FAIL(results_.reserve(batch_size))) {
215215
LOG_WARN("reserve results array failed", K(ret), K(batch_size));
216216
} else {
@@ -255,7 +255,7 @@ int ObTaskBatchInfo::add_item(const blocksstable::ObStorageDatum &text,
255255
need_embedding_count_++;
256256
}
257257
}
258-
258+
259259
// pre-allocate space (will be filled by embedding task)
260260
if (OB_SUCC(ret)) {
261261
float *vec_buf = static_cast<float*>(allocator_.alloc(vec_dim_ * sizeof(float)));
@@ -309,7 +309,7 @@ int ObTaskSlotRing::init(const int64_t capacity)
309309
{
310310
int ret = OB_SUCCESS;
311311
ObSpinLockGuard guard(lock_);
312-
312+
313313
if (capacity <= 0) {
314314
ret = OB_INVALID_ARGUMENT;
315315
LOG_WARN("invalid capacity", K(ret), K(capacity));
@@ -359,7 +359,7 @@ int ObTaskSlotRing::mark_ready(const int64_t slot_idx, const int ret_code)
359359
{
360360
int ret = OB_SUCCESS;
361361
ObSpinLockGuard guard(lock_);
362-
362+
363363
if (slot_idx < 0 || slot_idx >= slots_.count()) {
364364
ret = OB_INVALID_ARGUMENT;
365365
LOG_WARN("invalid slot idx", K(ret), K(slot_idx), K(slots_.count()));
@@ -375,7 +375,7 @@ int ObTaskSlotRing::pop_ready_in_order(ObTaskBatchInfo *&batch_info, int &ret_co
375375
int ret = OB_SUCCESS;
376376
batch_info = nullptr;
377377
ObSpinLockGuard guard(lock_);
378-
378+
379379
if (head_idx_ != next_idx_ && slots_.at(head_idx_).ready_) {
380380
Slot &slot = slots_.at(head_idx_);
381381
if (!slot.ready_) {
@@ -396,7 +396,7 @@ int ObTaskSlotRing::pop_ready_in_order(ObTaskBatchInfo *&batch_info, int &ret_co
396396
slot.task_->release_if_managed();
397397
slot.task_ = nullptr;
398398
}
399-
399+
400400
if (OB_SUCC(ret) || OB_NOT_NULL(batch_info)) {
401401
slot.ready_ = false;
402402
head_idx_ = (head_idx_ + 1) % slots_.count();
@@ -462,7 +462,7 @@ int ObTaskSlotRing::wait_for_head_completion()
462462
{
463463
int ret = OB_SUCCESS;
464464
share::ObEmbeddingTask *task_to_wait = nullptr;
465-
465+
466466
{
467467
ObSpinLockGuard guard(lock_);
468468
if (head_idx_ != next_idx_) {
@@ -472,7 +472,7 @@ int ObTaskSlotRing::wait_for_head_completion()
472472
}
473473
}
474474
}
475-
475+
476476
if (OB_NOT_NULL(task_to_wait)) {
477477
if (OB_FAIL(task_to_wait->wait_for_completion())) {
478478
LOG_WARN("wait for head embedding task completion failed", K(ret));
@@ -643,7 +643,7 @@ int ObEmbeddingTaskMgr::submit_batch_info(ObTaskBatchInfo *&batch_info)
643643
}
644644
}
645645
}
646-
646+
647647
if (OB_SUCC(ret) && embedding_count > 0) {
648648
// Only create embedding task if there are items to embed
649649
void *cb_buf = ob_malloc(sizeof(ObEmbeddingIOCallback), ObMemAttr(MTL_ID(), "EmbedCb"));
@@ -654,7 +654,7 @@ int ObEmbeddingTaskMgr::submit_batch_info(ObTaskBatchInfo *&batch_info)
654654
ObEmbeddingIOCallback *cb = new (cb_buf) ObEmbeddingIOCallback();
655655
ObEmbeddingIOCallbackHandle *cb_handle = nullptr;
656656
share::ObEmbeddingTask *task = nullptr;
657-
657+
658658
if (OB_ISNULL(cb_handle = ObEmbeddingIOCallbackHandle::create(cb))) {
659659
ret = OB_ALLOCATE_MEMORY_FAILED;
660660
LOG_WARN("create callback handle failed", K(ret));
@@ -666,13 +666,13 @@ int ObEmbeddingTaskMgr::submit_batch_info(ObTaskBatchInfo *&batch_info)
666666
} else {
667667
task = new (task_mem) share::ObEmbeddingTask();
668668
const int64_t vec_dim = results.at(0)->get_vector_dim();
669-
if (OB_FAIL(task->init(cfg_.model_url_, cfg_.model_name_, cfg_.provider_,
669+
if (OB_FAIL(task->init(cfg_.model_url_, cfg_.model_name_, cfg_.provider_,
670670
cfg_.user_key_, texts, vec_dim, model_request_timeout_us_, model_max_retries_, cb_handle))) {
671671
LOG_WARN("failed to initialize EmbeddingTask", K(ret));
672672
}
673673
}
674674
}
675-
675+
676676
if (OB_SUCC(ret) && OB_NOT_NULL(task)) {
677677
if (OB_FAIL(cb->init(this, slot_idx, batch_info, task, results.at(0)->get_vector_dim()))) {
678678
LOG_WARN("init callback failed", K(ret));
@@ -684,7 +684,7 @@ int ObEmbeddingTaskMgr::submit_batch_info(ObTaskBatchInfo *&batch_info)
684684
}
685685
}
686686
}
687-
687+
688688
if (OB_SUCC(ret)) {
689689
slot_ring_.set_task(slot_idx, task);
690690
slot_ring_.set_batch_info(slot_idx, batch_info); // Take ownership
@@ -750,6 +750,7 @@ int ObEmbeddingTaskMgr::get_ai_config(const common::ObString &model_id)
750750
const share::ObAiModelEndpointInfo *endpoint_info = nullptr;
751751
omt::ObAiServiceGuard ai_service_guard;
752752
omt::ObTenantAiService *ai_service = MTL(omt::ObTenantAiService*);
753+
bool use_request_model_name = false;
753754
if (OB_FAIL(ObAIFuncUtils::get_ai_func_info(allocator_, const_cast<common::ObString&>(model_id), info))) {
754755
LOG_WARN("failed to get ai func info", K(ret), K(model_id));
755756
} else if (OB_ISNULL(info)) {
@@ -762,9 +763,12 @@ int ObEmbeddingTaskMgr::get_ai_config(const common::ObString &model_id)
762763
LOG_WARN("failed to get ai service guard", K(ret));
763764
} else if (OB_FAIL(ai_service_guard.get_ai_endpoint_by_ai_model_name(model_id, endpoint_info))) {
764765
LOG_WARN("failed to get endpoint info", K(ret), K(model_id));
766+
} else if (OB_FALSE_IT(use_request_model_name = !endpoint_info->get_request_model_name().empty())) {
765767
} else if (OB_FAIL(ob_write_string(allocator_, endpoint_info->get_url(), cfg_.model_url_))) {
766768
LOG_WARN("failed to copy model_url", K(ret));
767-
} else if (OB_FAIL(ob_write_string(allocator_, info->model_, cfg_.model_name_))) {
769+
} else if (use_request_model_name && OB_FAIL(ob_write_string(allocator_, endpoint_info->get_request_model_name(), cfg_.model_name_))) {
770+
LOG_WARN("failed to copy model_name", K(ret));
771+
} else if (!use_request_model_name && OB_FAIL(ob_write_string(allocator_, info->model_, cfg_.model_name_))) {
768772
LOG_WARN("failed to copy model_name", K(ret));
769773
} else if (OB_FAIL(endpoint_info->get_unencrypted_access_key(allocator_, cfg_.user_key_))) {
770774
LOG_WARN("failed to copy user_key", K(ret));

0 commit comments

Comments
 (0)