Skip to content

Commit 0ad9bf4

Browse files
committed
fix: address review findings from 5 agents (18 items)
Pre-reviewed by 5 agents, 18 findings addressed: Security (Critical): - XML-escape all interpolated values in LLM prompt (prevent tag injection) - Strip action_type/tool_name through InformationStripper before LLM - Pass stripped description (not raw verdict.reason) to uncertainty checker - Extend control-char regex to cover Unicode bidi overrides Correctness (Major): - Remove MemoryError/RecursionError re-raise inside TaskGroup (prevents ExceptionGroup propagation) - Clamp confidence score to max 1.0 (floating-point edge case) - Filter empty/None provider responses from similarity computation - Add uncertainty_check_error sentinel to metadata on failure - Fix auto_reject_blocked=False path (was always auto-rejecting) - Change _parse_response param from object to CompletionResponse - Change _run_safety_classifier return type to bool (clearer contract) Frontend: - Replace IIFEs with precomputed variables (ESLint React Compiler rule) - Add NaN guard for parseFloat on confidence scores - Remove misleading 'Show original' toggle (description IS stripped) Tests: - Add factory wiring tests for SafetyClassifier/UncertaintyChecker - Add auto_reject_blocked=False test - Fix timeout test: asyncio.Event().wait() instead of sleep(100)
1 parent 65814e6 commit 0ad9bf4

8 files changed

Lines changed: 314 additions & 94 deletions

File tree

src/synthorg/security/safety_classifier.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"""
2020

2121
import asyncio
22+
import html
2223
import re
2324
import time
2425
from enum import StrEnum
@@ -37,7 +38,12 @@
3738
)
3839
from synthorg.providers.enums import MessageRole
3940
from synthorg.providers.family import get_family, providers_excluding_family
40-
from synthorg.providers.models import ChatMessage, CompletionConfig, ToolDefinition
41+
from synthorg.providers.models import (
42+
ChatMessage,
43+
CompletionConfig,
44+
CompletionResponse,
45+
ToolDefinition,
46+
)
4147
from synthorg.security.config import SafetyClassifierConfig # noqa: TC001
4248
from synthorg.security.rules.credential_detector import CREDENTIAL_PATTERNS
4349
from synthorg.security.rules.data_leak_detector import PII_PATTERNS
@@ -80,8 +86,17 @@
8086
# Maximum length for LLM-returned reason string.
8187
_MAX_REASON_LENGTH: Final[int] = 300
8288

83-
# Regex to strip control characters from LLM-returned reason.
84-
_CONTROL_CHAR_RE: Final[re.Pattern[str]] = re.compile(r"[\x00-\x1f\x7f]")
89+
# Regex to strip control and formatting characters from LLM-returned
90+
# reason. Covers ASCII control (C0/DEL), Unicode bidi overrides
91+
# (U+200E-200F, U+202A-202E, U+2066-2069), and zero-width chars.
92+
_CONTROL_CHAR_RE: Final[re.Pattern[str]] = re.compile(
93+
r"[\x00-\x1f\x7f"
94+
r"\u200b-\u200f" # zero-width and bidi marks
95+
r"\u202a-\u202e" # bidi embedding/override
96+
r"\u2066-\u2069" # bidi isolate
97+
r"\ufeff" # BOM / zero-width no-break space
98+
r"]",
99+
)
85100

86101

87102
# ── Enums and models ─────────────────────────────────────────────
@@ -374,7 +389,7 @@ def _select_provider(
374389
if not available:
375390
return None, None
376391

377-
# Try cross-family from a random starting point.
392+
# Try cross-family selection.
378393
for name in available:
379394
family = get_family(name, self._configs)
380395
cross = providers_excluding_family(family, self._configs)
@@ -406,13 +421,22 @@ def _build_messages(
406421
tool_name: str,
407422
risk_level: ApprovalRiskLevel,
408423
) -> list[ChatMessage]:
409-
"""Build prompt messages from the stripped context."""
424+
"""Build prompt messages from the stripped context.
425+
426+
All interpolated values are XML-escaped to prevent tag
427+
injection from agent-controlled fields, and stripped of
428+
PII/secrets via the same ``InformationStripper``.
429+
"""
430+
safe_tool = html.escape(self._stripper.strip(tool_name))
431+
safe_type = html.escape(self._stripper.strip(action_type))
432+
safe_risk = html.escape(risk_level.value)
433+
safe_desc = html.escape(stripped_description)
410434
user_content = (
411435
"<action>\n"
412-
f" <tool>{tool_name}</tool>\n"
413-
f" <type>{action_type}</type>\n"
414-
f" <risk_level>{risk_level.value}</risk_level>\n"
415-
f" <description>{stripped_description}</description>\n"
436+
f" <tool>{safe_tool}</tool>\n"
437+
f" <type>{safe_type}</type>\n"
438+
f" <risk_level>{safe_risk}</risk_level>\n"
439+
f" <description>{safe_desc}</description>\n"
416440
"</action>"
417441
)
418442

@@ -427,14 +451,14 @@ def _build_messages(
427451

428452
def _parse_response(
429453
self,
430-
response: object,
454+
response: CompletionResponse,
431455
stripped_description: str,
432456
start: float,
433457
) -> SafetyClassifierResult:
434458
"""Parse LLM response into a SafetyClassifierResult."""
435459
duration_ms = (time.monotonic() - start) * 1000
436460

437-
for tc in response.tool_calls: # type: ignore[attr-defined]
461+
for tc in response.tool_calls:
438462
if tc.name == "safety_classification_verdict":
439463
return self._parse_tool_call(
440464
tc.arguments,

src/synthorg/security/service.py

Lines changed: 45 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@
6060
)
6161
from synthorg.security.output_scanner import OutputScanner # noqa: TC001
6262
from synthorg.security.rules.engine import RuleEngine # noqa: TC001
63+
from synthorg.security.safety_classifier import (
64+
SafetyClassification,
65+
)
6366
from synthorg.security.timeout.protocol import RiskTierClassifier # noqa: TC001
6467

6568
if TYPE_CHECKING:
@@ -564,32 +567,36 @@ async def _handle_escalation(
564567

565568
# Stage 1+2: safety classification (if configured).
566569
if self._safety_classifier is not None:
567-
classify_result = await self._run_safety_classifier(
570+
auto_rejected = await self._run_safety_classifier(
568571
context,
569572
verdict,
570573
metadata,
571574
)
572-
if classify_result is not None:
573-
# BLOCKED -> auto-reject.
574-
if classify_result == "blocked":
575-
return verdict.model_copy(
576-
update={
577-
"verdict": SecurityVerdictType.DENY,
578-
"reason": (
579-
f"{verdict.reason} (auto-rejected: "
580-
f"safety classifier blocked)"
581-
),
582-
},
583-
)
584-
# Use stripped description for the reviewer view.
585-
stripped = metadata.get("stripped_description")
586-
if stripped:
587-
description = stripped
575+
if auto_rejected:
576+
return verdict.model_copy(
577+
update={
578+
"verdict": SecurityVerdictType.DENY,
579+
"reason": (
580+
f"{verdict.reason} (auto-rejected: "
581+
f"safety classifier blocked)"
582+
),
583+
},
584+
)
585+
# Use stripped description for the reviewer view.
586+
stripped = metadata.get("stripped_description")
587+
if stripped:
588+
description = stripped
588589

589590
# Cross-provider uncertainty check (if configured).
591+
# Use stripped description when available to avoid
592+
# broadcasting raw PII/secrets to all providers.
590593
if self._uncertainty_checker is not None:
594+
check_text = metadata.get(
595+
"stripped_description",
596+
verdict.reason,
597+
)
591598
await self._run_uncertainty_check(
592-
verdict,
599+
check_text,
593600
metadata,
594601
)
595602

@@ -637,17 +644,16 @@ async def _run_safety_classifier(
637644
context: SecurityContext,
638645
verdict: SecurityVerdict,
639646
metadata: dict[str, str],
640-
) -> str | None:
647+
) -> bool:
641648
"""Run the safety classifier and populate metadata.
642649
643-
Returns the classification string, or ``None`` on error.
644-
On BLOCKED, the caller should auto-reject.
650+
Returns ``True`` if the action was auto-rejected (BLOCKED
651+
with ``auto_reject_blocked`` enabled), ``False`` otherwise.
652+
Metadata is populated with classification results on success.
653+
On error, metadata is left unchanged and ``False`` is returned
654+
(fail-safe: proceed to human review).
645655
"""
646656
try:
647-
from synthorg.security.safety_classifier import ( # noqa: PLC0415
648-
SafetyClassification,
649-
)
650-
651657
result = await self._safety_classifier.classify( # type: ignore[union-attr]
652658
verdict.reason,
653659
context.action_type,
@@ -658,15 +664,16 @@ async def _run_safety_classifier(
658664
metadata["stripped_description"] = result.stripped_description
659665
metadata["safety_reason"] = result.reason
660666

661-
if result.classification == SafetyClassification.BLOCKED:
662-
auto_reject = self._config.safety_classifier.auto_reject_blocked
663-
if auto_reject:
664-
logger.warning(
665-
SECURITY_SAFETY_CLASSIFY_BLOCKED,
666-
tool_name=context.tool_name,
667-
reason=result.reason,
668-
)
669-
return "blocked"
667+
if (
668+
result.classification == SafetyClassification.BLOCKED
669+
and self._config.safety_classifier.auto_reject_blocked
670+
):
671+
logger.warning(
672+
SECURITY_SAFETY_CLASSIFY_BLOCKED,
673+
tool_name=context.tool_name,
674+
reason=result.reason,
675+
)
676+
return True
670677
except (MemoryError, RecursionError):
671678
raise
672679
except Exception:
@@ -675,19 +682,17 @@ async def _run_safety_classifier(
675682
tool_name=context.tool_name,
676683
note="Safety classifier failed -- proceeding without classification",
677684
)
678-
return None
679-
else:
680-
return result.classification.value
685+
return False
681686

682687
async def _run_uncertainty_check(
683688
self,
684-
verdict: SecurityVerdict,
689+
prompt: str,
685690
metadata: dict[str, str],
686691
) -> None:
687692
"""Run the uncertainty checker and populate metadata."""
688693
try:
689694
result = await self._uncertainty_checker.check( # type: ignore[union-attr]
690-
verdict.reason,
695+
prompt,
691696
)
692697
metadata["confidence_score"] = str(result.confidence_score)
693698
if result.keyword_overlap is not None:
@@ -703,3 +708,4 @@ async def _run_uncertainty_check(
703708
SECURITY_UNCERTAINTY_CHECK_ERROR,
704709
note="Uncertainty check failed -- proceeding without score",
705710
)
711+
metadata["uncertainty_check_error"] = "true"

src/synthorg/security/uncertainty.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ async def check(self, prompt: str) -> UncertaintyResult:
306306
# Compute similarity metrics.
307307
keyword_overlap = _compute_keyword_overlap(responses)
308308
embedding_sim = _compute_tfidf_cosine_similarity(responses)
309-
confidence = 0.6 * embedding_sim + 0.4 * keyword_overlap
309+
confidence = min(1.0, 0.6 * embedding_sim + 0.4 * keyword_overlap)
310310

311311
if confidence < self._config.low_confidence_threshold:
312312
logger.warning(
@@ -354,6 +354,13 @@ async def _collect_responses(
354354
results: list[str] = []
355355

356356
async def _call_provider(candidate: ResolvedModel) -> str | None:
357+
"""Call a single provider.
358+
359+
Inside a TaskGroup, all exceptions must be caught to
360+
avoid ExceptionGroup propagation (even MemoryError /
361+
RecursionError -- re-raising them would wrap in an
362+
ExceptionGroup that escapes outer except clauses).
363+
"""
357364
driver: BaseCompletionProvider = self._registry.get(
358365
candidate.provider_name,
359366
)
@@ -366,8 +373,6 @@ async def _call_provider(candidate: ResolvedModel) -> str | None:
366373
),
367374
timeout=self._config.timeout_seconds,
368375
)
369-
except (MemoryError, RecursionError):
370-
raise
371376
except Exception:
372377
logger.exception(
373378
SECURITY_UNCERTAINTY_CHECK_ERROR,
@@ -376,7 +381,18 @@ async def _call_provider(candidate: ResolvedModel) -> str | None:
376381
)
377382
return None
378383
else:
379-
return response.content or ""
384+
# Filter empty/None content to avoid diluting
385+
# similarity metrics (e.g. content-filtered responses).
386+
text = response.content
387+
if not text:
388+
logger.debug(
389+
SECURITY_UNCERTAINTY_CHECK_ERROR,
390+
provider=candidate.provider_name,
391+
model=candidate.model_id,
392+
note="Provider returned empty content",
393+
)
394+
return None
395+
return text
380396

381397
async with asyncio.TaskGroup() as tg:
382398
tasks = [tg.create_task(_call_provider(c)) for c in candidates]

0 commit comments

Comments
 (0)