Skip to content

Commit 70656cd

Browse files
JoanFMgithub-actions[bot]
authored andcommitted
[MOD-12647] fix: handle the case in Coordinator when SCORE is sent alone without extra fields. (#7492)
* fix: Do not send score alone if Expired flag is set * test: add test_expire_ft_hybrid * improve test * do not send NULL, simply avoid serialize result if expired * improving test * handle more complete hybrid expire test * small change in test * revert changes in hybrid exec * remove assertion that cannot be guaranteed, and handle no extra_attributes in response from shard * handle both resp protocols (cherry picked from commit b916f07)
1 parent a4993de commit 70656cd

3 files changed

Lines changed: 90 additions & 10 deletions

File tree

src/coord/rpnet.c

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -347,15 +347,18 @@ int rpnetNext(ResultProcessor *self, SearchResult *r) {
347347

348348
MRReply *score = NULL;
349349
MRReply *fields = MRReply_ArrayElement(rows, nc->curIdx++);
350+
bool has_fields = false;
350351
if (resp3) {
351352
RS_LOG_ASSERT(fields && MRReply_Type(fields) == MR_REPLY_MAP, "invalid result record");
352353
// extract score if it exists, WITHSCORES was specified
353354
score = MRReply_MapElement(fields, "score");
354355
fields = MRReply_MapElement(fields, "extra_attributes");
355-
RS_LOG_ASSERT(fields && MRReply_Type(fields) == MR_REPLY_MAP, "invalid fields record");
356+
// It could happen if Result_ExpiredDoc is set by the Loader on the shard, that no extra attributes is returned. In that case
357+
// we do not have keys to return.
358+
has_fields = fields && MRReply_Type(fields) == MR_REPLY_MAP;
356359
} else {
357-
RS_LOG_ASSERT(fields && MRReply_Type(fields) == MR_REPLY_ARRAY, "invalid result record");
358-
RS_LOG_ASSERT(MRReply_Length(fields) % 2 == 0, "invalid fields record");
360+
has_fields = fields && MRReply_Type(fields) == MR_REPLY_ARRAY;
361+
RS_LOG_ASSERT(!has_fields || has_fields && MRReply_Length(fields) % 2 == 0, "invalid fields record");
359362
}
360363

361364
// The score is optional, in hybrid we need the score for the sorter and hybrid merger
@@ -366,12 +369,14 @@ int rpnetNext(ResultProcessor *self, SearchResult *r) {
366369
SearchResult_SetScore(r, MRReply_Double(score));
367370
}
368371

369-
for (size_t i = 0; i < MRReply_Length(fields); i += 2) {
370-
size_t len;
371-
const char *field = MRReply_String(MRReply_ArrayElement(fields, i), &len);
372-
MRReply *val = MRReply_ArrayElement(fields, i + 1);
373-
RSValue *v = MRReply_ToValue(val);
374-
RLookup_WriteOwnKeyByName(nc->lookup, field, len, SearchResult_GetRowDataMut(r), v);
372+
if (has_fields) {
373+
for (size_t i = 0; i < MRReply_Length(fields); i += 2) {
374+
size_t len;
375+
const char *field = MRReply_String(MRReply_ArrayElement(fields, i), &len);
376+
MRReply *val = MRReply_ArrayElement(fields, i + 1);
377+
RSValue *v = MRReply_ToValue(val);
378+
RLookup_WriteOwnKeyByName(nc->lookup, field, len, SearchResult_GetRowDataMut(r), v);
379+
}
375380
}
376381
return RS_RESULT_OK;
377382
}

tests/pytests/common.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1001,7 +1001,16 @@ def get_results_from_hybrid_response(response) -> Dict[str, Dict[str, any]]:
10011001
Dict mapping key -> dict of all fields from the results list
10021002
Example: {'doc:1': {'__score': '0.5', 'vector_distance': '0.3'}}
10031003
"""
1004-
# return dict mapping key -> all fields from the results list
1004+
# Handle RESP3 format (dict)
1005+
if isinstance(response, dict):
1006+
results = {}
1007+
for result in response.get('results', []):
1008+
if '__key' in result:
1009+
key = result['__key']
1010+
results[key] = result
1011+
total_results = response.get('total_results', 0)
1012+
return results, total_results
1013+
10051014
res_results_index = recursive_index(response, 'results')
10061015
res_count_index = recursive_index(response, 'total_results')
10071016
res_results_index[-1] += 1

tests/pytests/test_expire.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,72 @@ def test_expire_aggregate(env):
253253
# The result count is not accurate in aggregation, for now we compare res to the expected results with the wrong count
254254
env.assertEqual(res, [1, ['t', 'arr'], ['t', 'bar']])
255255

256+
257+
def expire_ft_hybrid_test(protocol):
258+
env = Env(protocol=protocol)
259+
# Use "lazy" expire (expire only when key is accessed) on all shards
260+
env.cmd('DEBUG', 'SET-ACTIVE-EXPIRE', '0')
261+
262+
# Create index with text, vector, and numeric fields
263+
env.expect('FT.CREATE', 'idx', 'SCHEMA', 't', 'TEXT', 'n', 'NUMERIC', 'v', 'VECTOR', 'FLAT', '6', 'TYPE', 'FLOAT32', 'DIM', '2', 'DISTANCE_METRIC', 'L2').ok()
264+
265+
# Create test vectors (2-dimensional float32)
266+
import numpy as np
267+
query_vector = np.array([0.5, 0.5]).astype(np.float32).tobytes()
268+
269+
# Use cluster-aware connection for data insertion
270+
with env.getClusterConnectionIfNeeded() as conn:
271+
# Create 1000 documents
272+
for i in range(1000):
273+
# Create a unique vector for each document
274+
vector = np.array([float(i % 100) / 100.0, float((i + 1) % 100) / 100.0]).astype(np.float32).tobytes()
275+
doc_key = f'doc{i}'
276+
text_value = f'text{i}'
277+
numeric_value = str(i)
278+
279+
conn.execute_command('HSET', doc_key, 't', text_value, 'n', numeric_value, 'v', vector)
280+
281+
# Expire the first 990 documents (doc0 to doc989)
282+
if i < 990:
283+
conn.execute_command('PEXPIRE', doc_key, 1)
284+
285+
# Ensure expiration before query
286+
time.sleep(0.01)
287+
288+
# Test FT.HYBRID requesting 1000 results but expecting only 10 (non-expired documents)
289+
hybrid_query = ['FT.HYBRID', 'idx', 'SEARCH', '*', 'VSIM', '@v', query_vector, 'LIMIT', '0', '1000', 'COMBINE', 'RRF', '2', 'CONSTANT', '60', 'LOAD', '4', '@__key', '@__score', '@t', '@n']
290+
291+
# Execute query using cluster-aware command to get expected results
292+
actual_res = env.cmd(*hybrid_query)
293+
from common import get_results_from_hybrid_response
294+
actual_results_dict, actual_total_results = get_results_from_hybrid_response(actual_res)
295+
296+
# Validate that only 10 documents are returned (doc990 to doc999)
297+
env.assertEqual(actual_total_results, 10)
298+
299+
# Verify that only non-expired documents are present
300+
expected_doc_keys = {f'doc{i}' for i in range(990, 1000)}
301+
actual_doc_keys = set(actual_results_dict.keys())
302+
env.assertEqual(actual_doc_keys, expected_doc_keys)
303+
304+
# Verify that each returned document has the correct attributes
305+
for doc_key in actual_results_dict:
306+
doc_num = int(doc_key[3:]) # Extract number from 'docXXX'
307+
env.assertTrue('__key' in actual_results_dict[doc_key])
308+
env.assertTrue('__score' in actual_results_dict[doc_key])
309+
env.assertTrue('t' in actual_results_dict[doc_key])
310+
env.assertTrue('n' in actual_results_dict[doc_key])
311+
env.assertEqual(actual_results_dict[doc_key]['__key'], doc_key)
312+
env.assertEqual(actual_results_dict[doc_key]['t'], f'text{doc_num}')
313+
env.assertEqual(actual_results_dict[doc_key]['n'], str(doc_num))
314+
env.assertTrue(float(actual_results_dict[doc_key]['__score']) >= 0)
315+
316+
def test_expire_ft_hybrid_resp2():
317+
expire_ft_hybrid_test(protocol=2)
318+
319+
def test_expire_ft_hybrid_resp3():
320+
expire_ft_hybrid_test(protocol=3)
321+
256322
def createTextualSchema(field_to_additional_schema_keywords):
257323
schema = []
258324
for field, additional_schema_words in field_to_additional_schema_keywords.items():

0 commit comments

Comments
 (0)