@@ -160,7 +160,6 @@ def retrieve(  # type: ignore[override]
             List of dictionaries containing document metadata and text.
             Each dictionary has keys "text", "id", and "title".
         """
-
         batch_size = question_hidden_states.shape[0]
 
         # Convert the question hidden states into a list of 1D query vectors.
@@ -284,30 +283,27 @@ def generate_answer(
             raise ValueError(
                 "`question_encoder` and `generator_model` must be provided to use `generate_answer`."
             )
-
-        # Convert query to hidden states format expected by retrieve
-        inputs = self.question_encoder_tokenizer(
-            query, return_tensors="pt", padding=True, truncation=True
-        ).to(self.question_encoder.device)  # type: ignore
-
-        question_hidden_states = self.question_encoder(**inputs).last_hidden_state
-
-        # Get documents using retrieve method
-        _, _, doc_dicts = self.retrieve(
-            question_hidden_states, n_docs=top_k, query=query
+        torch = get_torch()
+        inputs = self.question_encoder_tokenizer(query, return_tensors="pt").to(
+            self.question_encoder.device
+        )
+        question_embeddings = self.question_encoder(**inputs).pooler_output
+        question_embeddings = (
+            question_embeddings.detach().cpu().to(torch.float32).numpy()
         )
+        _, _, doc_batch = self.retrieve(question_embeddings, n_docs=top_k, query=query)
 
-        contexts = doc_dicts[0]["text"] if doc_dicts else []
+        contexts = doc_batch[0]["text"] if doc_batch else []
         context_str = "\n\n".join(filter(None, contexts))
 
         prompt = f"Context: {context_str}\n\nQuestion: {query}\n\nAnswer:"
 
         generator_inputs = self.generator_tokenizer(prompt, return_tensors="pt").to(
             self.generator_model.device
-        )  # type: ignore
+        )
         output_ids = self.generator_model.generate(
             **generator_inputs, max_new_tokens=max_new_tokens
-        )  # type: ignore
+        )
 
         return self.generator_tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
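For orientation, here is a minimal, hypothetical sketch of the per-query document batch that `retrieve` documents (only the keys "text", "id", and "title" come from the docstring; the values are invented) and of the prompt assembly the new `generate_answer` body performs on it:

# Hypothetical illustration only: keys come from the retrieve() docstring above,
# values are invented. The diff shows "text" holding a list of passages per
# query; "id" and "title" are assumed to follow the same shape.
doc_batch = [
    {
        "text": ["First retrieved passage ...", "Second retrieved passage ..."],
        "id": ["17", "42"],
        "title": ["Intro", "Background"],
    }
]

# Prompt assembly mirroring the new generate_answer body.
query = "What does the retriever return?"
contexts = doc_batch[0]["text"] if doc_batch else []
context_str = "\n\n".join(filter(None, contexts))
prompt = f"Context: {context_str}\n\nQuestion: {query}\n\nAnswer:"
print(prompt)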