
Commit dc5b2af

Fiona-Waters authored and ntkathole committed
fix: Update generate_answer function to provide correct parameter format to retrieve function
Signed-off-by: Fiona Waters <fiwaters6@gmail.com>
Co-authored-by: Esa Fazal <efazal@redhat.com>
1 parent 40d25c6 commit dc5b2af

File tree

1 file changed: +11 -15 lines changed

sdk/python/feast/rag_retriever.py

Lines changed: 11 additions & 15 deletions
@@ -160,7 +160,6 @@ def retrieve(  # type: ignore [override]
             List of dictionaries containing document metadata and text.
             Each dictionary has keys "text", "id", and "title".
         """
-
         batch_size = question_hidden_states.shape[0]
 
         # Convert the question hidden states into a list of 1D query vectors.
@@ -284,30 +283,27 @@ def generate_answer(
             raise ValueError(
                 "`question_encoder` and `generator_model` must be provided to use `generate_answer`."
             )
-
-        # Convert query to hidden states format expected by retrieve
-        inputs = self.question_encoder_tokenizer(
-            query, return_tensors="pt", padding=True, truncation=True
-        ).to(self.question_encoder.device)  # type: ignore
-
-        question_hidden_states = self.question_encoder(**inputs).last_hidden_state
-
-        # Get documents using retrieve method
-        _, _, doc_dicts = self.retrieve(
-            question_hidden_states, n_docs=top_k, query=query
+        torch = get_torch()
+        inputs = self.question_encoder_tokenizer(query, return_tensors="pt").to(
+            self.question_encoder.device
+        )
+        question_embeddings = self.question_encoder(**inputs).pooler_output
+        question_embeddings = (
+            question_embeddings.detach().cpu().to(torch.float32).numpy()
         )
+        _, _, doc_batch = self.retrieve(question_embeddings, n_docs=top_k, query=query)
 
-        contexts = doc_dicts[0]["text"] if doc_dicts else []
+        contexts = doc_batch[0]["text"] if doc_batch else []
         context_str = "\n\n".join(filter(None, contexts))
 
         prompt = f"Context: {context_str}\n\nQuestion: {query}\n\nAnswer:"
 
         generator_inputs = self.generator_tokenizer(prompt, return_tensors="pt").to(
             self.generator_model.device
-        )  # type: ignore
+        )
         output_ids = self.generator_model.generate(
             **generator_inputs, max_new_tokens=max_new_tokens
-        )  # type: ignore
+        )
 
         return self.generator_tokenizer.decode(output_ids[0], skip_special_tokens=True)

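The substantive change is in the second hunk: retrieve() expects a 2D array of query vectors, one row per question, but the old code passed the encoder's last_hidden_state, a 3D per-token tensor, and passed it as a torch tensor rather than a NumPy array. The new code takes the pooled sentence embedding and converts it to a float32 NumPy array before calling retrieve(). A minimal sketch of that conversion, assuming a Hugging Face DPR-style question encoder (the model name is illustrative and not part of this commit):

import torch
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

# Illustrative encoder; in rag_retriever.py the encoder is supplied by the caller.
model_name = "facebook/dpr-question_encoder-single-nq-base"
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(model_name)
encoder = DPRQuestionEncoder.from_pretrained(model_name)

inputs = tokenizer("what is feast?", return_tensors="pt")

# pooler_output is one vector per question: shape (batch, hidden_dim), the 2D
# format retrieve() can slice into a list of 1D query vectors. A per-token
# last_hidden_state would be (batch, seq_len, hidden_dim), the wrong format.
embeddings = encoder(**inputs).pooler_output

# Detach from the autograd graph, move to CPU, and force float32 before
# .numpy(); bfloat16 tensors cannot be converted to NumPy directly.
query_vectors = embeddings.detach().cpu().to(torch.float32).numpy()
print(query_vectors.shape)  # e.g. (1, 768)

Dropping padding=True and truncation=True is incidental: a single query string needs no padding, and the removed # type: ignore comments are no longer necessary once the intermediate values have concrete types.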
0 commit comments