-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Expand file tree
/
Copy pathlocal_eval_service.py
More file actions
537 lines (477 loc) · 19.9 KB
/
local_eval_service.py
File metadata and controls
537 lines (477 loc) · 19.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import inspect
import logging
from typing import AsyncGenerator
from typing import Callable
from typing import Optional
import uuid
from typing_extensions import override
from ..agents.base_agent import BaseAgent
from ..artifacts.base_artifact_service import BaseArtifactService
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
from ..errors.not_found_error import NotFoundError
from ..memory.base_memory_service import BaseMemoryService
from ..sessions.base_session_service import BaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
from ..utils._client_labels_utils import client_label_context
from ..utils._client_labels_utils import EVAL_CLIENT_LABEL
from ..utils.feature_decorator import experimental
from .base_eval_service import BaseEvalService
from .base_eval_service import EvaluateConfig
from .base_eval_service import EvaluateRequest
from .base_eval_service import InferenceRequest
from .base_eval_service import InferenceResult
from .base_eval_service import InferenceStatus
from .eval_case import ConversationScenario
from .eval_case import Invocation
from .eval_metrics import EvalMetric
from .eval_metrics import EvalMetricResult
from .eval_metrics import EvalMetricResultDetails
from .eval_metrics import EvalMetricResultPerInvocation
from .eval_metrics import Rubric
from .eval_result import EvalCaseResult
from .eval_set import EvalCase
from .eval_set_results_manager import EvalSetResultsManager
from .eval_sets_manager import EvalSetsManager
from .evaluation_generator import EvaluationGenerator
from .evaluator import EvalStatus
from .evaluator import EvaluationResult
from .evaluator import PerInvocationResult
from .metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY
from .metric_evaluator_registry import MetricEvaluatorRegistry
from .simulation.user_simulator_provider import UserSimulatorProvider
logger = logging.getLogger('google_adk.' + __name__)
EVAL_SESSION_ID_PREFIX = '___eval___session___'
def _get_session_id() -> str:
return f'{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}'
def _add_rubrics_to_invocation(
invocation: Invocation, rubrics_to_add: list[Rubric]
):
"""Adds rubrics to invocation, throwing ValueError on duplicate rubric_id."""
if not invocation.rubrics:
invocation.rubrics = []
existing_ids = {r.rubric_id for r in invocation.rubrics}
for rubric in rubrics_to_add:
if rubric.rubric_id in existing_ids:
raise ValueError(
f"Rubric with rubric_id '{rubric.rubric_id}' already exists."
)
invocation.rubrics.append(rubric)
existing_ids.add(rubric.rubric_id)
def _copy_eval_case_rubrics_to_actual_invocations(
eval_case: EvalCase, actual_invocations: list[Invocation]
):
"""Copies EvalCase level rubrics to all actual invocations."""
if hasattr(eval_case, 'rubrics') and eval_case.rubrics:
for invocation in actual_invocations:
_add_rubrics_to_invocation(invocation, eval_case.rubrics)
def _copy_invocation_rubrics_to_actual_invocations(
expected_invocations: Optional[list[Invocation]],
actual_invocations: list[Invocation],
):
"""Copies invocation level rubrics to corresponding actual invocations."""
if expected_invocations:
for actual_invocation, expected_invocation in zip(
actual_invocations, expected_invocations
):
if expected_invocation.rubrics:
_add_rubrics_to_invocation(
actual_invocation, expected_invocation.rubrics
)
@experimental
class LocalEvalService(BaseEvalService):
"""An implementation of BaseEvalService, that runs the evals locally."""
def __init__(
self,
root_agent: BaseAgent,
eval_sets_manager: EvalSetsManager,
metric_evaluator_registry: Optional[MetricEvaluatorRegistry] = None,
session_service: Optional[BaseSessionService] = None,
artifact_service: Optional[BaseArtifactService] = None,
eval_set_results_manager: Optional[EvalSetResultsManager] = None,
session_id_supplier: Callable[[], str] = _get_session_id,
user_simulator_provider: UserSimulatorProvider = UserSimulatorProvider(),
memory_service: Optional[BaseMemoryService] = None,
):
self._root_agent = root_agent
self._eval_sets_manager = eval_sets_manager
metric_evaluator_registry = (
metric_evaluator_registry or DEFAULT_METRIC_EVALUATOR_REGISTRY
)
session_service = session_service or InMemorySessionService()
artifact_service = artifact_service or InMemoryArtifactService()
self._metric_evaluator_registry = metric_evaluator_registry
self._session_service = session_service
self._artifact_service = artifact_service
self._eval_set_results_manager = eval_set_results_manager
self._session_id_supplier = session_id_supplier
self._user_simulator_provider = user_simulator_provider
self._memory_service = memory_service
@override
async def perform_inference(
self,
inference_request: InferenceRequest,
) -> AsyncGenerator[InferenceResult, None]:
"""Returns InferenceResult obtained from the Agent as and when they are available.
Args:
inference_request: The request for generating inferences.
"""
# Get the eval set from the storage.
eval_set = self._eval_sets_manager.get_eval_set(
app_name=inference_request.app_name,
eval_set_id=inference_request.eval_set_id,
)
if not eval_set:
raise NotFoundError(
f'Eval set with id {inference_request.eval_set_id} not found for app'
f' {inference_request.app_name}'
)
# Select eval cases for which we need to run inferencing. If the inference
# request specified eval cases, then we use only those.
eval_cases = eval_set.eval_cases
if inference_request.eval_case_ids:
eval_cases = [
eval_case
for eval_case in eval_cases
if eval_case.eval_id in inference_request.eval_case_ids
]
semaphore = asyncio.Semaphore(
value=inference_request.inference_config.parallelism
)
async def run_inference(eval_case):
async with semaphore:
return await self._perform_inference_single_eval_item(
app_name=inference_request.app_name,
eval_set_id=inference_request.eval_set_id,
eval_case=eval_case,
root_agent=self._root_agent,
use_live=inference_request.inference_config.use_live,
live_timeout_seconds=inference_request.inference_config.live_timeout_seconds,
)
inference_results = [run_inference(eval_case) for eval_case in eval_cases]
for inference_result in asyncio.as_completed(inference_results):
yield await inference_result
@override
async def evaluate(
self,
evaluate_request: EvaluateRequest,
) -> AsyncGenerator[EvalCaseResult, None]:
"""Returns EvalCaseResult for each item as and when they are available.
Args:
evaluate_request: The request to perform metric evaluations on the
inferences.
"""
semaphore = asyncio.Semaphore(
value=evaluate_request.evaluate_config.parallelism
)
async def run_evaluation(inference_result):
async with semaphore:
return await self._evaluate_single_inference_result(
inference_result=inference_result,
evaluate_config=evaluate_request.evaluate_config,
)
evaluation_tasks = [
run_evaluation(inference_result)
for inference_result in evaluate_request.inference_results
]
results_by_set = {}
for evaluation_task in asyncio.as_completed(evaluation_tasks):
inference_result, eval_case_result = await evaluation_task
results_by_set.setdefault(inference_result.eval_set_id, []).append(
(inference_result.app_name, eval_case_result)
)
yield eval_case_result
if self._eval_set_results_manager:
for eval_set_id, results in results_by_set.items():
app_name = results[0][0]
cases = [r[1] for r in results]
self._eval_set_results_manager.save_eval_set_result(
app_name=app_name,
eval_set_id=eval_set_id,
eval_case_results=cases,
)
async def _evaluate_single_inference_result(
self, inference_result: InferenceResult, evaluate_config: EvaluateConfig
) -> tuple[InferenceResult, EvalCaseResult]:
"""Returns the inference result and its corresponding EvalCaseResult.
A single inference result can have multiple invocations. For each
invocation, this method evaluates the metrics present in evaluate config.
The EvalCaseResult contains scores for each metric per invocation and the
overall score.
"""
eval_case = self._eval_sets_manager.get_eval_case(
app_name=inference_result.app_name,
eval_set_id=inference_result.eval_set_id,
eval_case_id=inference_result.eval_case_id,
)
if eval_case is None:
raise NotFoundError(
f'Eval case with id {inference_result.eval_case_id} not found for'
f' app {inference_result.app_name} and eval set'
f' {inference_result.eval_set_id}.'
)
# Metric results for each invocation
eval_metric_result_per_invocation = []
# We also keep track of the overall score for a metric, derived from all
# invocation. For example, if we were keeping track the metric that compares
# how well is the final response as compared to a golden answer, then each
# invocation will have the value of this metric. We will also have an
# overall score using aggregation strategy across all invocations. This
# would be the score for the eval case.
overall_eval_metric_results = []
user_id = (
eval_case.session_input.user_id
if eval_case.session_input and eval_case.session_input.user_id
else 'test_user_id'
)
if eval_case.conversation_scenario is None and len(
inference_result.inferences
) != len(eval_case.conversation):
raise ValueError(
'Inferences should match conversations in eval case. Found'
f'{len(inference_result.inferences)} inferences '
f'{len(eval_case.conversation)} conversations in eval cases.'
)
# Pre-creating the EvalMetricResults entries for each invocation.
for idx, actual in enumerate(inference_result.inferences):
eval_metric_result_per_invocation.append(
EvalMetricResultPerInvocation(
actual_invocation=actual,
expected_invocation=eval_case.conversation[idx]
if eval_case.conversation
else None,
# We will fill this as we evaluate each metric per invocation.
eval_metric_results=[],
)
)
actual_invocations = inference_result.inferences
expected_invocations = eval_case.conversation
# 1. Copy EvalCase level rubrics to all actual invocations.
_copy_eval_case_rubrics_to_actual_invocations(eval_case, actual_invocations)
# 2. If expected invocations are present, copy invocation level
# rubrics to corresponding actual invocations.
_copy_invocation_rubrics_to_actual_invocations(
expected_invocations, actual_invocations
)
for eval_metric in evaluate_config.eval_metrics:
# Perform evaluation of the metric.
await self._evaluate_metric_for_eval_case(
eval_metric,
eval_case,
inference_result,
eval_metric_result_per_invocation,
overall_eval_metric_results,
)
final_eval_status = self._generate_final_eval_status(
overall_eval_metric_results
)
eval_case_result = EvalCaseResult(
eval_set_file=inference_result.eval_set_id,
eval_set_id=inference_result.eval_set_id,
eval_id=inference_result.eval_case_id,
final_eval_status=final_eval_status,
overall_eval_metric_results=overall_eval_metric_results,
eval_metric_result_per_invocation=eval_metric_result_per_invocation,
session_id=inference_result.session_id,
session_details=await self._session_service.get_session(
app_name=inference_result.app_name,
user_id=user_id,
session_id=inference_result.session_id,
),
user_id=user_id,
)
return (inference_result, eval_case_result)
async def _evaluate_metric_for_eval_case(
self,
eval_metric: EvalMetric,
eval_case: EvalCase,
inference_result: InferenceResult,
eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation],
overall_eval_metric_results: list[EvalMetricResult],
):
"""Performs evaluation of a metric for a given eval case and inference result."""
try:
with client_label_context(EVAL_CLIENT_LABEL):
evaluation_result = await self._evaluate_metric(
eval_metric=eval_metric,
actual_invocations=inference_result.inferences,
expected_invocations=eval_case.conversation,
conversation_scenario=eval_case.conversation_scenario,
)
except Exception as e:
# We intentionally catch the Exception as we don't want failures to
# affect other metric evaluation.
logger.error(
"Metric evaluation failed for metric `%s` for eval case id '%s'"
' with following error `%s`',
eval_metric.metric_name,
eval_case.eval_id,
e,
exc_info=True,
)
# We use an empty result.
evaluation_result = EvaluationResult(
overall_eval_status=EvalStatus.NOT_EVALUATED
)
# Track overall score across all invocations.
eval_metric_result_details = EvalMetricResultDetails(
rubric_scores=evaluation_result.overall_rubric_scores
)
overall_eval_metric_results.append(
EvalMetricResult(
score=evaluation_result.overall_score,
eval_status=evaluation_result.overall_eval_status,
details=eval_metric_result_details,
**eval_metric.model_dump(),
)
)
if (
evaluation_result.overall_eval_status != EvalStatus.NOT_EVALUATED
and len(evaluation_result.per_invocation_results)
!= len(eval_metric_result_per_invocation)
):
raise ValueError(
'Eval metric should return results for each invocation. Found '
f'{len(evaluation_result.per_invocation_results)} results for '
f'{len(eval_metric_result_per_invocation)} invocations.'
)
# Track score across individual invocations.
for idx, invocation in enumerate(eval_metric_result_per_invocation):
invocation_result = (
evaluation_result.per_invocation_results[idx]
if evaluation_result.overall_eval_status != EvalStatus.NOT_EVALUATED
else PerInvocationResult(
actual_invocation=invocation.actual_invocation
)
)
eval_metric_result_details = EvalMetricResultDetails(
rubric_scores=invocation_result.rubric_scores
)
invocation.eval_metric_results.append(
EvalMetricResult(
score=invocation_result.score,
eval_status=invocation_result.eval_status,
details=eval_metric_result_details,
**eval_metric.model_dump(),
)
)
async def _evaluate_metric(
self,
eval_metric: EvalMetric,
actual_invocations: list[Invocation],
expected_invocations: Optional[list[Invocation]],
conversation_scenario: Optional[ConversationScenario],
) -> EvaluationResult:
"""Returns EvaluationResult obtained from evaluating a metric using an Evaluator."""
# Get the metric evaluator from the registry.
metric_evaluator = self._metric_evaluator_registry.get_evaluator(
eval_metric=eval_metric
)
if inspect.iscoroutinefunction(metric_evaluator.evaluate_invocations):
# Some evaluators could be async, for example those that use llm as a
# judge, so we need to make sure that we wait on them.
return await metric_evaluator.evaluate_invocations(
actual_invocations=actual_invocations,
expected_invocations=expected_invocations,
conversation_scenario=conversation_scenario,
)
else:
# Metrics that perform computation synchronously, mostly these don't
# perform any i/o. An example of this would calculation of rouge_1 score.
return metric_evaluator.evaluate_invocations(
actual_invocations=actual_invocations,
expected_invocations=expected_invocations,
conversation_scenario=conversation_scenario,
)
def _generate_final_eval_status(
self, overall_eval_metric_results: list[EvalMetricResult]
) -> EvalStatus:
final_eval_status = EvalStatus.NOT_EVALUATED
# Go over all the eval statuses and mark the final eval status as
# passed if all of them pass; otherwise, mark the final eval status to
# failed.
for overall_eval_metric_result in overall_eval_metric_results:
overall_eval_status = overall_eval_metric_result.eval_status
if overall_eval_status == EvalStatus.PASSED:
final_eval_status = EvalStatus.PASSED
elif overall_eval_status == EvalStatus.NOT_EVALUATED:
continue
elif overall_eval_status == EvalStatus.FAILED:
final_eval_status = EvalStatus.FAILED
break
else:
raise ValueError(f'Unknown eval status: {overall_eval_status}.')
return final_eval_status
async def _perform_inference_single_eval_item(
self,
app_name: str,
eval_set_id: str,
eval_case: EvalCase,
root_agent: BaseAgent,
use_live: bool,
live_timeout_seconds: int,
) -> InferenceResult:
initial_session = eval_case.session_input
session_id = self._session_id_supplier()
inference_result = InferenceResult(
app_name=app_name,
eval_set_id=eval_set_id,
eval_case_id=eval_case.eval_id,
session_id=session_id,
)
try:
with client_label_context(EVAL_CLIENT_LABEL):
if use_live:
inferences = await EvaluationGenerator._generate_inferences_from_root_agent_live(
root_agent=root_agent,
user_simulator=self._user_simulator_provider.provide(eval_case),
initial_session=initial_session,
session_id=session_id,
session_service=self._session_service,
artifact_service=self._artifact_service,
memory_service=self._memory_service,
live_timeout_seconds=live_timeout_seconds,
)
else:
inferences = (
await EvaluationGenerator._generate_inferences_from_root_agent(
root_agent=root_agent,
user_simulator=self._user_simulator_provider.provide(
eval_case
),
initial_session=initial_session,
session_id=session_id,
session_service=self._session_service,
artifact_service=self._artifact_service,
memory_service=self._memory_service,
)
)
inference_result.inferences = inferences
inference_result.status = InferenceStatus.SUCCESS
return inference_result
except Exception as e:
# We intentionally catch the Exception as we don't failures to affect
# other inferences.
logger.error(
'Inference failed for eval case `%s` with error %s.',
eval_case.eval_id,
e,
exc_info=True,
)
inference_result.status = InferenceStatus.FAILURE
inference_result.error_message = str(e)
return inference_result