-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathmodel_client.py
More file actions
1644 lines (1432 loc) · 65.6 KB
/
model_client.py
File metadata and controls
1644 lines (1432 loc) · 65.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""LiteLLM-backed model helpers for the measurements pipeline.
Import side effects (intentional for the ``assert-ai`` CLI):
* ``_normalize_azure_api_base()`` rewrites ``AZURE_API_BASE`` at import
time to strip any ``/openai/...`` path suffix, logging at INFO when
it does so. Library users importing this module will see the env
var mutated for the rest of the process.
* The first activation of the Chat Completions fallback (either via
``ASSERT_PREFER_CHAT_COMPLETIONS=1`` at import time, an explicit
API preference, or a reactive recovery from a region error) calls
``_install_responses_api_guard()``, which monkey-patches
``litellm.main.responses_api_bridge_check`` for the rest of the
process. The patch is idempotent and forwards all positional and
keyword arguments to the original function so it survives minor
LiteLLM upgrades that add new kwargs.
"""
from __future__ import annotations
import asyncio
import contextlib
import contextvars
import importlib
import json
import logging
import os
import random
import time
from dataclasses import asdict, dataclass, field, is_dataclass, replace
from typing import Any, Iterator, Mapping, Sequence
log = logging.getLogger(__name__)
# ── Types ──────────────────────────────────────────────────────
# Plain-string aliases (not evaluated types) used inside annotations such as
# ``Sequence[MessageLike]``. Each names "the normalized dataclass or a raw
# mapping"; the module's ``from __future__ import annotations`` keeps the
# annotations that mention them from being evaluated at definition time.
MessageLike = "Message | Mapping[str, Any]"
ToolCallLike = "ToolCall | Mapping[str, Any]"
@dataclass(slots=True)
class ToolCall:
"""Normalized tool call representation."""
name: str
arguments: dict[str, Any] = field(default_factory=dict)
call_id: str | None = None
raw_arguments: str | None = None
type: str = "function"
raw: Any = None
@property
def id(self) -> str | None:
return self.call_id
@property
def function(self) -> str:
return self.name
def to_openai_dict(self) -> dict[str, Any]:
payload = {
"type": self.type,
"function": {
"name": self.name,
"arguments": self.raw_arguments or json.dumps(self.arguments),
},
}
if self.call_id:
payload["id"] = self.call_id
return payload
@dataclass(slots=True)
class Message:
"""OpenAI-style message shape."""
role: str
content: Any
name: str | None = None
tool_call_id: str | None = None
tool_calls: list[ToolCall] = field(default_factory=list)
@property
def text(self) -> str:
if isinstance(self.content, str):
return self.content
if isinstance(self.content, list):
parts: list[str] = []
for item in self.content:
if isinstance(item, dict) and isinstance(item.get("text"), str):
parts.append(item["text"])
elif isinstance(item, str):
parts.append(item)
return "\n".join(parts)
return str(self.content or "")
def to_openai_dict(self) -> dict[str, Any]:
payload = {
"role": self.role,
"content": self.content,
}
if self.name:
payload["name"] = self.name
if self.tool_call_id:
payload["tool_call_id"] = self.tool_call_id
if self.tool_calls:
payload["tool_calls"] = [tool_call.to_openai_dict() for tool_call in self.tool_calls]
return payload
@dataclass(slots=True)
class UsageStats:
    """Normalized token accounting.

    ``cached_input_tokens`` is the count of input tokens served from the
    provider's prompt cache (a subset of ``prompt_tokens``). It surfaces
    OpenAI/Azure ``prompt_tokens_details.cached_tokens`` and
    Anthropic ``cache_read_input_tokens`` under one field so callers can
    measure prefix-cache effectiveness without branching on the provider.

    ``cache_creation_input_tokens`` is Anthropic's "wrote new entries to
    the cache" counter; it has no OpenAI/Azure equivalent because their
    cache is implicit and free to write.

    Every counter defaults to ``None``; ``UsageAccumulator.add`` treats a
    ``None`` (or otherwise falsy) counter as zero when summing.
    """

    prompt_tokens: int | None = None
    completion_tokens: int | None = None
    total_tokens: int | None = None
    cached_input_tokens: int | None = None
    cache_creation_input_tokens: int | None = None
    # NOTE(review): presumably the provider's original usage payload this
    # object was normalized from; confirm against ``_normalize_usage``.
    raw: Any = None
@dataclass(slots=True)
class UsageAccumulator:
"""Aggregate token usage and call counts across many ``generate*`` calls.
Created by :func:`track_usage` and populated by the model client itself, so
callers don't have to thread per-call usage objects through their code.
Per-model breakdowns are tracked under ``per_model`` so a single stage that
invokes more than one model (e.g. test_set + stratification) can be inspected later.
"""
calls: int = 0
input_tokens: int = 0
output_tokens: int = 0
cached_input_tokens: int = 0
cache_creation_input_tokens: int = 0
per_model: dict[str, dict[str, int]] = field(default_factory=dict)
def add(self, usage: UsageStats | None, *, model: str | None = None) -> None:
"""Fold one call's normalized usage into this accumulator."""
if usage is None:
return
self.calls += 1
ipt = int(usage.prompt_tokens or 0)
opt = int(usage.completion_tokens or 0)
cit = int(usage.cached_input_tokens or 0)
cct = int(usage.cache_creation_input_tokens or 0)
self.input_tokens += ipt
self.output_tokens += opt
self.cached_input_tokens += cit
self.cache_creation_input_tokens += cct
key = model or "?"
bucket = self.per_model.setdefault(
key,
{
"calls": 0,
"input_tokens": 0,
"output_tokens": 0,
"cached_input_tokens": 0,
"cache_creation_input_tokens": 0,
},
)
bucket["calls"] += 1
bucket["input_tokens"] += ipt
bucket["output_tokens"] += opt
bucket["cached_input_tokens"] += cit
bucket["cache_creation_input_tokens"] += cct
def cache_hit_rate(self) -> float:
"""Return cached_input_tokens / input_tokens, or 0.0 when no input tokens."""
if self.input_tokens <= 0:
return 0.0
return self.cached_input_tokens / self.input_tokens
def to_dict(self) -> dict[str, Any]:
"""JSON-serializable snapshot of this accumulator."""
return {
"calls": self.calls,
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
"cached_input_tokens": self.cached_input_tokens,
"cache_creation_input_tokens": self.cache_creation_input_tokens,
"cache_hit_rate": self.cache_hit_rate(),
"per_model": dict(self.per_model),
}
# Context-local slot holding the accumulator installed by ``track_usage``.
# The default ``None`` means no tracking block is active, in which case
# ``_record_usage`` does nothing.
_USAGE_ACCUMULATOR: contextvars.ContextVar[UsageAccumulator | None] = contextvars.ContextVar(
    "_assert_ai_usage_accumulator",
    default=None,
)
@contextlib.contextmanager
def track_usage() -> Iterator[UsageAccumulator]:
    """Collect token usage from every ``generate*`` call made inside the block.

    A fresh :class:`UsageAccumulator` is published through a ``ContextVar``
    for the duration of the ``with`` statement and yielded to the caller.
    ``asyncio.run(...)`` blocks started inside the statement inherit it, and
    concurrent tasks all add into the same object. Only calls made on the
    same ``async`` stack (or the same thread) are seen — independent threads
    or coroutines started in a fresh context do not contribute.

    The previous context value is restored on exit, even when the block
    raises.
    """
    collector = UsageAccumulator()
    restore_token = _USAGE_ACCUMULATOR.set(collector)
    try:
        yield collector
    finally:
        # Put back whatever was active before this block.
        _USAGE_ACCUMULATOR.reset(restore_token)
def _record_usage(usage: UsageStats | None, *, model: str | None) -> None:
    """Forward one call's usage to the accumulator active in this context.

    Does nothing when no ``track_usage`` block is active.
    """
    active = _USAGE_ACCUMULATOR.get()
    if active is not None:
        active.add(usage, model=model)
@dataclass(slots=True)
class GenerateOptions:
    """Transport options shared across chat, tool, and responses calls.

    Fields left at ``None`` are omitted from the request payloads built by
    ``_build_chat_payload`` and ``_build_responses_payload``.
    """

    # Sent as ``temperature`` when set.
    temperature: float | None = None
    # Output-token cap. The chat builder sends ``max_tokens`` and prefers this
    # field, falling back to ``max_output_tokens``.
    max_tokens: int | None = None
    # Output-token cap. The responses builder sends ``max_output_tokens`` and
    # prefers this field, falling back to ``max_tokens``.
    max_output_tokens: int | None = None
    # Web grounding via the Responses API ``web_search_preview`` tool
    # (see ``_supports_web_search_preview``).
    web_search: bool = False
    # Sent as ``reasoning_effort`` when set.
    reasoning_effort: str | None = None
    # NOTE(review): not read by the payload builders visible here; presumably
    # forwarded on tool-calling requests — confirm at the call sites.
    tool_choice: str | dict[str, Any] | None = None
    # NOTE(review): presumably the per-call timeout in seconds handed to
    # ``_await_with_timeout`` — confirm at the call sites.
    timeout_s: float | None = None
    # Optional label; shown as `` [label]`` in the web_search degradation
    # warning emitted by ``_drop_web_search_for_fallback``.
    call_label: str | None = None
    # Merged into the payload last, so these entries override the keys above.
    extra_kwargs: dict[str, Any] = field(default_factory=dict)
@dataclass(slots=True)
class ModelResponse:
"""Normalized model response."""
text: str = ""
content: Any = None
parsed: Any = None
reasoning: str = ""
tool_calls: list[ToolCall] = field(default_factory=list)
finish_reason: str | None = None
status: str | None = None
incomplete_details: Any = None
model: str | None = None
response_id: str | None = None
usage: UsageStats | None = None
api_mode: str | None = None
request_payload: dict[str, Any] | None = None
raw: Any = None
@property
def message(self) -> Message:
return Message(role="assistant", content=self.text, tool_calls=list(self.tool_calls))
# ── Transport helpers ──────────────────────────────────────────
# Cached ``litellm`` module. Stays ``None`` until ``_get_litellm_module``
# imports and configures it on first use.
_LITELLM_MODULE: Any | None = None
# Finish-reason values that indicate the model hit the output-token limit
# rather than completing its response. ``length`` is Chat Completions;
# ``max_tokens`` / ``max_output_tokens`` come from the OpenAI Responses API
# (surfaced via ``incomplete_details.reason`` in ``normalize_response``).
# We deliberately do NOT include bare ``"incomplete"`` (the Responses API
# ``status`` field) because that value covers any non-finished state -- including
# content-filter refusals or provider errors that won't be fixed by enlarging
# the token budget. Truncation must be confirmed by a specific ``reason``.
_TRUNCATED_FINISH_REASONS: frozenset[str] = frozenset({
    "length",
    "max_tokens",
    "max_output_tokens",
})
def is_truncated_response(response: "ModelResponse") -> bool:
    """Report whether *response* stopped because it ran out of output tokens.

    Two sources are consulted in order: the normalized ``finish_reason``
    (which already prefers the Responses API ``incomplete_details.reason``
    over the ambiguous ``status='incomplete'``), then the raw
    ``incomplete_details.reason`` as a fallback for callers that bypass
    ``normalize_response``. A source counts only when it is a string listed
    in ``_TRUNCATED_FINISH_REASONS``.
    """

    def _candidate_reasons() -> Iterator[Any]:
        # Lazy, so the raw fallback is only looked up when the normalized
        # finish_reason did not already settle the question.
        yield getattr(response, "finish_reason", None)
        yield _get_value(getattr(response, "incomplete_details", None), "reason")

    return any(
        isinstance(candidate, str) and candidate in _TRUNCATED_FINISH_REASONS
        for candidate in _candidate_reasons()
    )
def build_json_schema_response_format(
    name: str,
    schema: dict[str, Any],
    *,
    strict: bool = True,
) -> dict[str, Any]:
    """Return a ``json_schema`` output config in the nested OpenAI layout.

    The schema details sit under the ``"json_schema"`` key. ``schema`` is
    referenced, not copied.
    """
    schema_spec: dict[str, Any] = dict(name=name, strict=strict, schema=schema)
    return {"type": "json_schema", "json_schema": schema_spec}
def build_json_schema_text_format(
    name: str,
    schema: dict[str, Any],
    *,
    strict: bool = True,
) -> dict[str, Any]:
    """Return a ``json_schema`` output config in the flat Responses API layout.

    ``name``, ``strict`` and ``schema`` sit next to ``"type"`` at the top
    level. ``schema`` is referenced, not copied.
    """
    text_format: dict[str, Any] = {"type": "json_schema"}
    text_format.update(name=name, strict=strict, schema=schema)
    return text_format
def _model_family(model: str) -> str:
    """Derive the provider/model-family key used for transport capability checks.

    The name is trimmed and lower-cased. A ``provider/model`` name yields the
    part before the first slash; a bare name starting with ``gpt-``, ``o1``,
    ``o3`` or ``o4`` yields ``"openai"``; any other name is returned in its
    trimmed, lower-cased form. A falsy *model* yields ``""``.
    """
    cleaned = (model or "").strip().lower()
    provider, slash, _remainder = cleaned.partition("/")
    if slash:
        return provider
    looks_like_openai = any(
        cleaned.startswith(prefix) for prefix in ("gpt-", "o1", "o3", "o4")
    )
    return "openai" if looks_like_openai else cleaned
def _supports_web_search_preview(model: str) -> bool:
    """Tell whether *model* may use the Responses API ``web_search_preview`` tool.

    The current implementation sends OpenAI Responses API payloads with
    ``tools=[{"type": "web_search_preview"}]``. LiteLLM exposes that path
    for OpenAI-compatible models, but provider smoke testing showed Gemini
    fails before useful generation when this tool is combined with
    structured output. The gate therefore stays intentionally narrow — only
    the ``openai`` and ``azure`` families pass — until each provider has an
    explicit, tested web-search path.
    """
    family = _model_family(model)
    return family == "openai" or family == "azure"
def _require_web_search_preview_support(model: str) -> None:
    """Raise ``ValueError`` unless *model* passes ``_supports_web_search_preview``."""
    if not _supports_web_search_preview(model):
        raise ValueError(
            "web_search uses the OpenAI Responses API web_search_preview tool, "
            f"which is not enabled for model '{model}'. Disable web_search for this "
            "stage or use an OpenAI/Azure OpenAI model."
        )
def messages_to_openai(messages: str | Sequence[MessageLike]) -> list[dict[str, Any]]:
    """Turn a prompt string or a message sequence into OpenAI-format dicts.

    A bare string becomes a single ``user`` message. Sequence items are
    coerced with ``_coerce_message`` first, so ``Message`` objects and raw
    mappings can be mixed.
    """
    if isinstance(messages, str):
        only_message = Message(role="user", content=messages)
        return [only_message.to_openai_dict()]
    return [_coerce_message(item).to_openai_dict() for item in messages]
def normalize_tool_calls(raw_tool_calls: Sequence[ToolCallLike] | None) -> list[ToolCall]:
    """Build ``ToolCall`` objects from raw OpenAI/LiteLLM tool-call entries.

    Name and arguments are read from the entry's ``function`` payload when it
    has a truthy one, otherwise from the entry itself. Arguments may arrive
    as JSON text or as a dict. Text that is blank, is not valid JSON, or
    does not decode to a dict yields empty ``arguments``, but the original
    text is still kept in ``raw_arguments``. ``None`` or an empty input
    yields an empty list.
    """
    result: list[ToolCall] = []
    for entry in raw_tool_calls or []:
        function_part = _get_value(entry, "function") or entry
        arguments_field = _get_value(function_part, "arguments")
        arguments_text = arguments_field if isinstance(arguments_field, str) else None

        decoded_arguments: dict[str, Any] = {}
        if arguments_text is not None:
            if arguments_text.strip():
                try:
                    candidate = json.loads(arguments_text)
                except json.JSONDecodeError:
                    # Best effort: malformed text is preserved in
                    # raw_arguments rather than failing the whole response.
                    candidate = None
                if isinstance(candidate, dict):
                    decoded_arguments = candidate
        elif isinstance(arguments_field, dict):
            decoded_arguments = dict(arguments_field)

        tool_call = ToolCall(
            name=str(_get_value(function_part, "name") or ""),
            arguments=decoded_arguments,
            call_id=_get_value(entry, "id"),
            raw_arguments=arguments_text,
            type=str(_get_value(entry, "type") or "function"),
            raw=entry,
        )
        result.append(tool_call)
    return result
def normalize_response(
    raw_response: Any,
    *,
    api_mode: str | None = None,
    request_payload: dict[str, Any] | None = None,
) -> ModelResponse:
    """Convert a LiteLLM/OpenAI-style response object into a ``ModelResponse``.

    Content is taken from the first choice's message; when that is ``None``
    or empty the response's ``output_text`` is tried, and then the text
    extracted from its ``output`` items. ``finish_reason`` is the first
    truthy value among the choice's ``finish_reason``, the response's
    ``stop_reason``, ``incomplete_details.reason`` and ``status``.

    ``api_mode`` and ``request_payload`` are stored on the result unchanged.
    """
    first_choice = _first_choice(raw_response)
    assistant_message = _get_value(first_choice, "message")

    # Walk the content fallbacks until one yields something non-empty.
    body = _get_value(assistant_message, "content")
    if body in (None, ""):
        body = _get_value(raw_response, "output_text")
    if body in (None, ""):
        body = _extract_responses_output_text(_get_value(raw_response, "output"))

    plain_text = _extract_text_content(body)
    reasoning_text = _extract_text_content(_get_value(assistant_message, "reasoning_content"))
    parsed_payload = _maybe_parse_json(plain_text)
    tool_calls = normalize_tool_calls(_get_value(assistant_message, "tool_calls"))

    status = _get_value(raw_response, "status")
    incomplete = _get_value(raw_response, "incomplete_details")
    # The specific incomplete reason outranks the bare status.
    stop_reason = (
        _get_value(first_choice, "finish_reason")
        or _get_value(raw_response, "stop_reason")
        or _get_value(incomplete, "reason")
        or status
    )

    return ModelResponse(
        text=plain_text,
        content=body,
        parsed=parsed_payload,
        reasoning=reasoning_text,
        tool_calls=tool_calls,
        finish_reason=stop_reason,
        status=status,
        incomplete_details=incomplete,
        model=_get_value(raw_response, "model"),
        response_id=_get_value(raw_response, "id"),
        usage=_normalize_usage(_get_value(raw_response, "usage")),
        api_mode=api_mode,
        request_payload=request_payload,
        raw=raw_response,
    )
def to_jsonable(value: Any) -> Any:
    """Recursively reduce a provider payload to JSON-safe builtins.

    Checks run in this order, and the first match wins:

    * ``None``, ``str``, ``int``, ``float`` and ``bool`` pass through;
    * mappings become dicts with stringified keys;
    * dataclasses are expanded with ``asdict``;
    * objects with a callable ``model_dump()``, then ``dict()``, are dumped;
    * non-string, non-bytes sequences become lists;
    * objects with a ``__dict__`` are expanded via ``vars``;
    * anything else is rendered with ``str``.
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, Mapping):
        return {str(name): to_jsonable(member) for name, member in value.items()}
    if is_dataclass(value):
        return to_jsonable(asdict(value))
    for dumper_name in ("model_dump", "dict"):
        dumper = getattr(value, dumper_name, None)
        if callable(dumper):
            return to_jsonable(dumper())
    is_text_like = isinstance(value, (str, bytes, bytearray))
    if isinstance(value, Sequence) and not is_text_like:
        return list(map(to_jsonable, value))
    if hasattr(value, "__dict__"):
        return to_jsonable(vars(value))
    return str(value)
def summarize_response(response: ModelResponse) -> dict[str, Any]:
    """Produce a compact, display-oriented summary of *response*.

    ``content`` and ``stop_reason`` are always present (empty string when
    missing). ``tool_calls``, ``usage``, ``model`` and ``response_id`` appear
    only when the response has a truthy value for them, and the two cache
    counters inside ``usage`` appear only when non-zero.
    """
    summary: dict[str, Any] = {
        "content": response.text or "",
        "stop_reason": response.finish_reason or "",
    }
    if response.tool_calls:
        summary["tool_calls"] = [call.to_openai_dict() for call in response.tool_calls]

    usage = response.usage
    if usage:
        usage_summary: dict[str, Any] = {
            "input_tokens": usage.prompt_tokens,
            "output_tokens": usage.completion_tokens,
            "total_tokens": usage.total_tokens,
        }
        for counter_name in ("cached_input_tokens", "cache_creation_input_tokens"):
            counter_value = getattr(usage, counter_name)
            if counter_value:
                usage_summary[counter_name] = counter_value
        summary["usage"] = usage_summary

    identifiers = (("model", response.model), ("response_id", response.response_id))
    for key, identifier in identifiers:
        if identifier:
            summary[key] = identifier
    return summary
def build_llm_call_trace(response: ModelResponse, *, source: str) -> dict[str, Any]:
    """Assemble the artifact record for one LLM call.

    The request payload is converted with ``to_jsonable`` and then passed
    through ``sanitize_payload`` to prevent credential leakage in artifact
    files. The raw response is converted with ``to_jsonable`` only, and
    ``derived`` holds the ``summarize_response`` summary.
    """
    # Imported here rather than at module top, as in the original.
    from assert_ai.core.security import sanitize_payload

    request_json = to_jsonable(response.request_payload or {})
    trace: dict[str, Any] = {
        "source": source,
        "api_mode": response.api_mode or "",
    }
    trace["request"] = sanitize_payload(request_json)
    trace["response"] = to_jsonable(response.raw)
    trace["derived"] = summarize_response(response)
    return trace
def _coerce_message(message: MessageLike) -> Message:
    """Return *message* as a ``Message``, building one from a raw payload if needed.

    ``Message`` instances are returned as-is. Anything else is read through
    ``_get_value``; a missing or falsy role defaults to ``"user"``, and any
    tool calls are normalized into ``ToolCall`` objects.
    """
    if not isinstance(message, Message):
        converted_calls = normalize_tool_calls(_get_value(message, "tool_calls"))
        role = str(_get_value(message, "role") or "user")
        return Message(
            role=role,
            content=_get_value(message, "content"),
            name=_get_value(message, "name"),
            tool_call_id=_get_value(message, "tool_call_id"),
            tool_calls=converted_calls,
        )
    return message
def _build_chat_payload(
    model: str,
    messages: str | Sequence[MessageLike],
    options: GenerateOptions | None,
) -> dict[str, Any]:
    """Assemble the keyword payload for a Chat Completions call.

    Optional settings are included only when not ``None``. The token cap is
    sent as ``max_tokens``, taken from ``options.max_tokens`` and falling
    back to ``options.max_output_tokens``. ``extra_kwargs`` are merged last
    and therefore override any key set here.
    """
    opts = options if options else GenerateOptions()
    token_cap = opts.max_tokens if opts.max_tokens is not None else opts.max_output_tokens
    optional_settings = (
        ("temperature", opts.temperature),
        ("max_tokens", token_cap),
        ("reasoning_effort", opts.reasoning_effort),
    )

    request: dict[str, Any] = {"model": model, "messages": messages_to_openai(messages)}
    request.update((key, setting) for key, setting in optional_settings if setting is not None)
    request.update(opts.extra_kwargs)
    return request
def _build_responses_payload(
    model: str,
    messages: str | Sequence[MessageLike],
    options: GenerateOptions | None,
) -> dict[str, Any]:
    """Assemble the keyword payload for a Responses API call.

    A string prompt is sent as ``input`` unchanged; a message sequence is
    converted with ``messages_to_openai``. Optional settings are included
    only when not ``None``. The token cap is sent as ``max_output_tokens``,
    taken from ``options.max_output_tokens`` and falling back to
    ``options.max_tokens``. ``extra_kwargs`` are merged last and therefore
    override any key set here.
    """
    opts = options if options else GenerateOptions()
    request_input: Any = messages if isinstance(messages, str) else messages_to_openai(messages)
    token_cap = (
        opts.max_output_tokens if opts.max_output_tokens is not None else opts.max_tokens
    )
    optional_settings = (
        ("temperature", opts.temperature),
        ("max_output_tokens", token_cap),
        ("reasoning_effort", opts.reasoning_effort),
    )

    request: dict[str, Any] = {"model": model, "input": request_input}
    request.update((key, setting) for key, setting in optional_settings if setting is not None)
    request.update(opts.extra_kwargs)
    return request
def _responses_client(litellm: Any) -> tuple[Any, bool]:
    """Pick the Responses API entry point exposed by *litellm*.

    Returns ``(callable, is_async)``. ``aresponses`` is preferred (with
    ``is_async=True``) and ``responses`` is the fallback (``is_async=False``).

    Raises:
        ValueError: if the module exposes neither attribute.
    """
    for attribute, is_async in (("aresponses", True), ("responses", False)):
        if hasattr(litellm, attribute):
            return getattr(litellm, attribute), is_async
    raise ValueError("web_search requires a LiteLLM responses client")
def _get_litellm_module() -> Any:
    """Import, configure, and cache the ``litellm`` module on first use.

    Later calls return the cached module without reconfiguring it. On the
    first call the module's debug output and internal retries are turned
    off, and the ``ASSERT_PREFER_CHAT_COMPLETIONS`` environment variable is
    consulted: a value of ``1``, ``true`` or ``yes`` (surrounding whitespace
    and letter case ignored) proactively activates the Chat Completions
    preference.

    Returns:
        The imported ``litellm`` module.

    Raises:
        RuntimeError: if ``litellm`` is not installed.
    """
    global _LITELLM_MODULE
    if _LITELLM_MODULE is None:
        try:
            _LITELLM_MODULE = importlib.import_module("litellm")
        except ModuleNotFoundError as exc:
            raise RuntimeError(
                "litellm is not installed. Install it with `python -m pip install litellm` "
                "before using assert_ai.core.model_client."
            ) from exc
        # Silence noisy litellm warnings that pollute stderr
        _LITELLM_MODULE.suppress_debug_info = True
        # Disable LiteLLM's internal retry so _with_retries is the
        # sole retry layer — avoids double-retry and lets the
        # coordinated per-model cooldown work correctly.
        _LITELLM_MODULE.num_retries = 0
        # If the user has opted into Chat Completions proactively,
        # disable the Responses API now so LiteLLM never attempts it.
        # The value is lower-cased so "True" / "YES" are honoured instead
        # of being silently ignored.
        preference = os.environ.get("ASSERT_PREFER_CHAT_COMPLETIONS", "").strip().lower()
        if preference in ("1", "true", "yes"):
            _apply_chat_completions_preference()
            log.info(
                "ASSERT_PREFER_CHAT_COMPLETIONS is set; using Chat "
                "Completions API for all models."
            )
    return _LITELLM_MODULE
async def _await_with_timeout(awaitable: Any, *, timeout_s: float | None) -> Any:
if timeout_s is None:
return await awaitable
async with asyncio.timeout(timeout_s):
return await awaitable
async def _run_sync_with_timeout(callable_obj: Any, *, timeout_s: float | None, **kwargs: Any) -> Any:
    """Run a blocking callable in a worker thread, bounded by ``timeout_s``.

    ``callable_obj(**kwargs)`` is dispatched through ``asyncio.to_thread``
    and awaited via ``_await_with_timeout``.

    NOTE(review): on timeout only the awaiting side is interrupted; the
    worker thread presumably keeps running until the callable returns —
    confirm this is acceptable for the callers.
    """
    return await _await_with_timeout(
        asyncio.to_thread(callable_obj, **kwargs),
        timeout_s=timeout_s,
    )
# ── LiteLLM error classification ──────────────────────────────
class LLMAuthError(Exception):
    """Bad API key or credentials — not retryable.

    ``_classify_llm_error`` produces this for LiteLLM ``AuthenticationError``
    (401) and ``PermissionDeniedError`` (403).
    """


class LLMInputError(Exception):
    """Invalid request (prompt too long, bad params) — not retryable."""


class LLMContentFilterError(LLMInputError):
    """Provider-side content filter rejected the prompt — not retryable.

    This is a *subclass* of LLMInputError so existing handlers that catch
    LLMInputError still see content-filter rejections. Adversarial-eval
    workloads (judge.py, rollout.py) catch this subclass specifically and
    treat it as a soft per-row failure rather than aborting the run, since
    sending content the provider will reject is the *whole point* of those
    workloads. Routine eval workloads still see the parent LLMInputError
    and fail loudly as before.
    """


class LLMRateLimitError(Exception):
    """Rate limited — retryable after backoff.

    ``_classify_llm_error`` produces this for LiteLLM ``RateLimitError`` (429).
    """


class LLMProviderError(Exception):
    """Provider-side error (5xx) — may be retryable."""


class _ResponsesApiNotAvailableError(LLMProviderError):
    """Region does not support Azure Responses API — triggers automatic
    fallback to Chat Completions for the remainder of the run.

    Inherits from :class:`LLMProviderError` so that callers up the stack
    that catch ``LLMProviderError`` (notably ``stages/inference.py`` and
    ``init/_design_agent.py``) still treat this as a real failure when
    the in-loop fallback exhausts its retry and re-raises. Without this
    inheritance the exception falls through generic ``except Exception``
    catch-alls and silently produces empty content, which a judge then
    happily scores ✓.
    """
# ── Responses API → Chat Completions fallback state ────────────
# Process-wide module flags. Each one is flipped to True through a ``global``
# statement in the helper functions that follow; none of the code visible in
# this part of the file resets them to False.
_responses_api_fallback_warned: bool = False
"""Set to True after the first Responses-API-not-available warning is
emitted so we only log the user-facing message once per run."""
_web_search_drop_warned: bool = False
"""Set to True after the first ``web_search`` degradation warning so
the message is emitted once per run rather than once per task."""
_force_chat_completions: bool = False
"""When True, the monkey-patched ``responses_api_bridge_check`` forces
``mode=chat`` so LiteLLM never routes through the Responses API bridge."""
_responses_api_guard_installed: bool = False
"""Set to True after the bridge-check monkey-patch has been installed."""
def _install_responses_api_guard() -> None:
    """Monkey-patch ``litellm.main.responses_api_bridge_check``.

    LiteLLM 1.82+ auto-routes GPT-5.4+ calls that include both ``tools``
    and ``reasoning_effort`` through the Responses API bridge (see
    ``litellm/main.py:responses_api_bridge_check``). There is no
    kwarg or environment variable to opt out — the routing decision is
    made inside the bridge check itself.

    The patch wraps the original function and overrides the returned
    ``mode`` from ``"responses"`` back to ``"chat"`` whenever
    ``_force_chat_completions`` is True. Idempotent — safe to call
    multiple times; the guard flag prevents double-wrapping.

    Installing the patch does not by itself force Chat Completions: the
    wrapper reads ``_force_chat_completions`` on every call, so it only
    rewrites the mode while that flag is set.
    """
    global _responses_api_guard_installed
    if _responses_api_guard_installed:
        return
    from litellm import main as _litellm_main  # noqa: WPS433
    _original = _litellm_main.responses_api_bridge_check
    # Accept ``*args, **kwargs`` and forward them as-is so the patch
    # is forward-compatible with LiteLLM minor releases that add new
    # parameters to ``responses_api_bridge_check``. Pinning a fixed
    # signature here would reject or lose any newly added parameters and
    # break Responses-API routing for callers that need it.
    def _guarded_bridge_check(*args: Any, **kwargs: Any) -> tuple:
        model_info, out_model = _original(*args, **kwargs)
        # NOTE(review): this mutates the dict returned by LiteLLM in place;
        # confirm LiteLLM does not cache or share that object across calls.
        if _force_chat_completions and model_info.get("mode") == "responses":
            model_info["mode"] = "chat"
        return model_info, out_model
    _litellm_main.responses_api_bridge_check = _guarded_bridge_check
    # Mark as installed only after the replacement is in place.
    _responses_api_guard_installed = True
def _activate_chat_completions_fallback(
    reason: str,
    *,
    model: str | None = None,
    tag: str = "",
    proactive: bool = False,
) -> None:
    """Switch the whole process over to Chat Completions.

    Every call sets the sticky ``_force_chat_completions`` flag and makes
    sure the bridge-check guard is installed (idempotent). Only the first
    call per run logs a user-facing message, tracked through
    ``_responses_api_fallback_warned``.

    With ``proactive=True`` (the environment-variable opt-in or an explicit
    API preference) the message is logged at INFO and omits the
    "set ASSERT_PREFER_CHAT_COMPLETIONS=1" hint, since the user has already
    opted in. With ``proactive=False`` (reactive recovery from a region
    error) it is logged at WARNING and includes the hint.

    ``model`` and ``tag`` only shape the message prefix; the prefix is empty
    when ``model`` is falsy.
    """
    global _force_chat_completions, _responses_api_fallback_warned
    _force_chat_completions = True
    _install_responses_api_guard()

    already_announced = _responses_api_fallback_warned
    _responses_api_fallback_warned = True
    if already_announced:
        return

    subject = f"{model}{tag}: " if model else ""
    if not proactive:
        log.warning(
            "%sFalling back from Azure Responses API to Chat Completions "
            "for the remainder of this run (%s). Set "
            "ASSERT_PREFER_CHAT_COMPLETIONS=1 to skip this round-trip "
            "upfront in unsupported regions.",
            subject, reason,
        )
        return
    log.info(
        "%sUsing Azure Chat Completions instead of Responses API "
        "for this run (%s).",
        subject, reason,
    )
def _apply_chat_completions_preference() -> None:
    """Proactively activate the Chat Completions fallback (kept for back-compat).

    Thin wrapper that calls ``_activate_chat_completions_fallback`` with a
    fixed reason and ``proactive=True``.
    """
    _activate_chat_completions_fallback(
        reason="preference set via API",
        proactive=True,
    )
def _drop_web_search_for_fallback(
    options: "GenerateOptions", model: str, *, reason: str
) -> "GenerateOptions":
    """Return a copy of ``options`` with ``web_search`` off, warning once per run.

    ``web_search`` is implemented via the Responses API
    ``web_search_preview`` tool — there is no Chat Completions equivalent on
    Azure. When the Responses API is unavailable in the target region (or
    the Chat-Completions fallback is already active), the call degrades
    gracefully: it still succeeds but without web grounding. The first
    occurrence is logged loudly so the user knows the run produced different
    output than a Responses-API-supporting region would have.

    The input ``options`` object is not modified; ``dataclasses.replace``
    builds the copy.
    """
    global _web_search_drop_warned
    if _web_search_drop_warned:
        return replace(options, web_search=False)

    _web_search_drop_warned = True
    label_suffix = f" [{options.call_label}]" if options.call_label else ""
    log.warning(
        "%s%s: dropping web_search and routing via Chat Completions "
        "(%s). Web grounding is disabled for the remainder of this run.",
        model, label_suffix, reason,
    )
    return replace(options, web_search=False)
def _classify_llm_error(exc: Exception) -> Exception:
    """Translate a LiteLLM exception into this module's error categories.

    LiteLLM surfaces provider HTTP failures as OpenAI-compatible exception
    types (see https://docs.litellm.ai/docs/exception_mapping). This function
    re-buckets them into the categories that drive retry behaviour in
    ``_with_retries``. The wrapper is *returned*, not raised, and carries the
    original exception as its ``__cause__``.

    Checks run top to bottom; the first match wins:

      1. AuthenticationError (401)              -> LLMAuthError
      2. PermissionDeniedError (403)            -> LLMAuthError
      3. RateLimitError (429)                   -> LLMRateLimitError
      4. message contains a Responses-API
         "not available" marker                 -> _ResponsesApiNotAvailableError
      5. BadRequestError (400), content filter  -> LLMContentFilterError
         BadRequestError (400), anything else   -> LLMInputError
      6. NotFoundError (404)                    -> LLMInputError
      7. UnprocessableEntityError (422)         -> LLMInputError
      8. APIResponseValidationError             -> LLMInputError
      9. APIError / APIConnectionError
         (500, 503, timeouts, connection)       -> LLMProviderError
     10. anything else                          -> returned unchanged

    The order is significant: specific subclasses have to be tested before
    their bases (for example NotFoundError before APIError), and the
    Responses-API marker check has to run before the BadRequestError and
    NotFoundError handlers.

    Args:
        exc: The exception caught around a LiteLLM call.

    Returns:
        A categorized wrapper exception with ``__cause__`` set to ``exc``, or
        ``exc`` itself when no category applies.
    """
    litellm = _get_litellm_module()

    def _chain(wrapped: Exception) -> Exception:
        # Keep the provider exception attached as the explicit cause.
        wrapped.__cause__ = exc
        return wrapped

    # 401: bad credentials.
    if isinstance(exc, litellm.AuthenticationError):
        return _chain(LLMAuthError(f"Authentication failed: {exc}"))

    # 403: insufficient permissions. The class is looked up defensively; an
    # empty tuple makes isinstance() evaluate to False when it is missing.
    permission_denied_cls = getattr(litellm, "PermissionDeniedError", ())
    if isinstance(exc, permission_denied_cls):
        return _chain(LLMAuthError(f"Permission denied: {exc}"))

    # 429: rate limited (retryable with coordinated backoff).
    if isinstance(exc, litellm.RateLimitError):
        return _chain(LLMRateLimitError(f"Rate limited: {exc}"))

    # Responses API region fallback -- must come before the BadRequestError,
    # NotFoundError and APIError handlers below.
    #
    # Azure OpenAI rejects Responses API requests in unsupported regions
    # (West Europe etc.) with one of these observed messages:
    #   - "API version not supported"    (HTTP 400, BadRequestError)
    #   - "responses api is not enabled" (HTTP 404, NotFoundError/APIError)
    # Matching here routes the failure to the Chat-Completions fallback in
    # ``_with_retries`` instead of recording it as an unrecoverable bad
    # request.
    #
    # This check is deliberately NOT gated on ``_force_chat_completions``.
    # Under high concurrency the first task to see the marker switches the
    # fallback on, but tasks already in flight against the Responses API hit
    # the same marker moments later. With a gate, those stragglers fell
    # through to the handlers below, became ``LLMInputError`` and were never
    # retried on the Chat path. Per-task loop prevention is handled in
    # ``_with_retries`` through its ``chat_fallback_attempts`` counter.
    lowered_message = str(exc).lower()
    responses_unavailable_markers = (
        "responses api is not enabled",
        "api version not supported",
    )
    for marker in responses_unavailable_markers:
        if marker in lowered_message:
            return _chain(
                _ResponsesApiNotAvailableError(
                    f"Azure Responses API not available: {exc}"
                )
            )

    # 400: bad request (covers ContextWindowExceeded, ContentPolicyViolation).
    if isinstance(exc, litellm.BadRequestError):
        # Most providers raise ContentPolicyViolationError (a BadRequestError
        # subclass) for content filtering. Others, notably Azure OpenAI,
        # raise a plain BadRequestError whose message carries a stable
        # marker. Both routes end up as LLMContentFilterError so that
        # adversarial-eval workloads can tolerate them per row.
        #
        # Observed Azure / OpenAI variants (each can appear on its own):
        #   - "content_filter" / "content filter"
        #   - "ResponsibleAIPolicyViolation"
        #   - "high-risk cyber activity" / "potentially high-risk"
        #   - "your prompt was flagged" + "usage policy"
        #   - "Invalid prompt: ..." (OpenAI reasoning-models guard text)
        content_filter_markers = (
            "content_filter",
            "content filter",
            "responsibleaipolicyviolation",
            "high-risk cyber activity",
            "potentially high-risk",
            "flagged as potentially violating",
            "violating our usage policy",
            "prompt was flagged",
            "invalid prompt:",
        )
        policy_cls = getattr(litellm, "ContentPolicyViolationError", None)
        filtered = policy_cls is not None and isinstance(exc, policy_cls)
        if not filtered:
            filtered = any(
                marker in lowered_message for marker in content_filter_markers
            )
        if filtered:
            return _chain(LLMContentFilterError(f"Content filtered: {exc}"))
        return _chain(LLMInputError(f"Bad request: {exc}"))

    # 404: model/deployment not found.
    if isinstance(exc, litellm.NotFoundError):
        return _chain(LLMInputError(f"Model/deployment not found: {exc}"))

    # 422: unprocessable entity.
    unprocessable_cls = getattr(litellm, "UnprocessableEntityError", ())
    if isinstance(exc, unprocessable_cls):
        return _chain(LLMInputError(f"Unprocessable entity: {exc}"))

    # No HTTP status: the response failed schema validation.
    validation_cls = getattr(litellm, "APIResponseValidationError", ())
    if isinstance(exc, validation_cls):
        return _chain(LLMInputError(f"Response validation failed: {exc}"))

    # Retryable provider-side failures. This is the catch-all for APIError,
    # APIConnectionError, InternalServerError, ServiceUnavailableError and
    # Timeout (all of which inherit from APIError or APIConnectionError).
    retryable_bases = (litellm.APIError, litellm.APIConnectionError)
    if isinstance(exc, retryable_bases):
        return _chain(LLMProviderError(f"Provider error: {exc}"))

    # Not a recognised LiteLLM failure: hand it back untouched.
    return exc
def _first_choice(raw_response: Any) -> Any:
    """Return the first element of the response's ``choices``, if any.

    ``choices`` is read through ``_get_value``. ``None`` is returned when the
    value is not a sequence, is a ``str``/``bytes`` (which are sequences but
    never a list of choices), or is empty.
    """
    candidates = _get_value(raw_response, "choices")
    if not isinstance(candidates, Sequence):
        return None
    if isinstance(candidates, (str, bytes)):
        return None
    if not candidates:
        return None
    return candidates[0]
def _first_token_count(raw_usage: Any, *keys: str) -> int | None:
    """Return the first usable token count found under ``keys``.

    Keys are tried in order. A non-zero count is returned immediately. A
    zero is remembered but does not end the search, so a later key holding
    the real figure still takes precedence. If no key yields a non-zero
    count, the remembered zero is returned rather than being collapsed into
    ``None``; ``None`` is returned only when no key produced a count at all.
    """
    saw_zero = False
    for key in keys:
        count = _coerce_int(_get_value(raw_usage, key))
        if count:
            return count
        if count is not None:
            saw_zero = True
    return 0 if saw_zero else None


def _normalize_usage(raw_usage: Any) -> UsageStats | None:
    """Convert a provider usage payload into a ``UsageStats`` record.

    Args:
        raw_usage: The usage object or mapping taken from a provider
            response, or ``None`` when the response carried no usage data.

    Returns:
        ``None`` when ``raw_usage`` is ``None``; otherwise a ``UsageStats``
        holding prompt, completion and total token counts, the prompt-cache
        counters, and the untouched payload under ``raw``.
    """
    if raw_usage is None:
        return None

    # Chat Completions API uses prompt_tokens/completion_tokens;
    # Responses API uses input_tokens/output_tokens.
    # A reported zero is kept as 0: a plain ``a or b`` would turn it into
    # None whenever the alternate key is absent, and that in turn would
    # prevent the total from being derived below.
    prompt = _first_token_count(raw_usage, "prompt_tokens", "input_tokens")
    completion = _first_token_count(
        raw_usage, "completion_tokens", "output_tokens"
    )

    total = _coerce_int(_get_value(raw_usage, "total_tokens"))
    if total is None and prompt is not None and completion is not None:
        total = prompt + completion

    # Prompt-cache accounting. OpenAI/Azure expose cached prompt tokens
    # under {prompt,input}_tokens_details.cached_tokens (Chat Completions
    # vs Responses API). Anthropic exposes them as top-level
    # cache_read_input_tokens / cache_creation_input_tokens. LiteLLM
    # passes both shapes through unchanged.
    cached_input = _coerce_int(_get_value(raw_usage, "cache_read_input_tokens"))
    if cached_input is None:
        prompt_details = (
            _get_value(raw_usage, "prompt_tokens_details")
            or _get_value(raw_usage, "input_tokens_details")
        )
        if prompt_details is not None:
            cached_input = _coerce_int(_get_value(prompt_details, "cached_tokens"))
    cache_creation = _coerce_int(
        _get_value(raw_usage, "cache_creation_input_tokens")
    )

    # Diagnose providers that return a usage object but with all zero/None
    # token counts. Seen with truncated Azure Responses API calls (status
    # 'incomplete'): the accumulator records 1 call but reports 0 in / 0 out,
    # which masks the real token spend. Logged at debug level so mystery
    # zero-token rows can be attributed later without spamming normal runs.
    if not prompt and not completion and not total:
        log.debug(
            "Usage payload contained zero/None tokens (raw=%r) -- "
            "provider likely returned an incomplete or error response",
            raw_usage,
        )

    return UsageStats(
        prompt_tokens=prompt,
        completion_tokens=completion,
        total_tokens=total,
        cached_input_tokens=cached_input,
        cache_creation_input_tokens=cache_creation,
        raw=raw_usage,
    )