Skip to content

Commit 7490cc3

Browse files
code-review-botampagent
andcommitted
code review fixes for PR NousResearch#13
- safe_print: strip rich markup tags when rich is unavailable so plain fallback output doesn't leak literal '[bold red]...[/bold red]' tags - profiling.aggregate_profiling_stats: stop reconstructing per-call timings by repeating the mean (statistically wrong; gave incorrect min/max/median across workers). Combine summary stats directly and flag median as approximate via median_time_approximate - run_agent: hoist 'reset_profiler' import to module level; drop noisy 'dir(tc)' / 'model_dump()' debug logging spam from the verbose path - batch_runner: drop unused 'import re'; replace bare 'except:' with specific (JSONDecodeError, TypeError, AttributeError) and guard isinstance(content, str) before calling .strip() - tools/simple_terminal_tool: replace bare 'except:' with 'except Exception' on the SSH-context cleanup paths Amp-Thread-ID: https://ampcode.com/threads/T-019dce4d-5fc2-703c-b2e4-b8a87ec42105 Co-authored-by: Amp <amp@ampcode.com>
1 parent a219e17 commit 7490cc3

5 files changed

Lines changed: 87 additions & 71 deletions

File tree

batch_runner.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from datetime import datetime
3636
from multiprocessing import Pool, Manager, Lock
3737
import traceback
38-
import re
3938

4039
import fire
4140

@@ -187,9 +186,9 @@ def _extract_tool_errors_from_messages(messages: List[Dict[str, Any]]) -> List[D
187186
if not error_msg:
188187
error_msg = str(content_json.get("message", content_json.get("error", "Unknown error")))
189188

190-
except:
189+
except (json.JSONDecodeError, TypeError, AttributeError):
191190
# If not JSON, check if content explicitly states an error
192-
if content.strip().lower().startswith("error:"):
191+
if isinstance(content, str) and content.strip().lower().startswith("error:"):
193192
has_error = True
194193
error_msg = content.strip()
195194

@@ -275,13 +274,13 @@ def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, i
275274
if content_json.get("success") is False:
276275
is_success = False
277276

278-
except:
277+
except (json.JSONDecodeError, TypeError, AttributeError):
279278
# If not JSON, check if content is empty or explicitly states an error
280279
# Note: We avoid simple substring matching to prevent false positives
281280
if not content:
282281
is_success = False
283282
# Only mark as failure if it explicitly starts with "Error:" or "ERROR:"
284-
elif content.strip().lower().startswith("error:"):
283+
elif isinstance(content, str) and content.strip().lower().startswith("error:"):
285284
is_success = False
286285

287286
# Update success/failure count

profiling.py

Lines changed: 56 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -260,64 +260,68 @@ def aggregate_profiling_stats(stats_list: List[Dict]) -> Dict:
260260
Returns:
261261
Dict: Aggregated statistics with combined tool and API call data
262262
"""
263-
aggregated = {
264-
"tools": defaultdict(lambda: {"times": []}),
265-
"api_calls": {"times": []}
266-
}
267-
268-
# Aggregate tool statistics
269-
for stats in stats_list:
270-
# Aggregate tool timings
271-
for tool_name, tool_stats in stats.get("tools", {}).items():
272-
# Reconstruct individual timings from aggregated stats
273-
# Since we have mean_time and call_count, we approximate
274-
aggregated["tools"][tool_name]["times"].extend(
275-
[tool_stats.get("mean_time", 0.0)] * tool_stats.get("call_count", 0)
276-
)
277-
278-
# Aggregate API call timings
279-
api_stats = stats.get("api_calls", {})
280-
if api_stats.get("call_count", 0) > 0:
281-
aggregated["api_calls"]["times"].extend(
282-
[api_stats.get("mean_time", 0.0)] * api_stats.get("call_count", 0)
283-
)
284-
285-
# Calculate final statistics for tools
286-
final_stats = {"tools": {}, "api_calls": {}}
287-
288-
for tool_name, data in aggregated["tools"].items():
289-
times = data["times"]
290-
if times:
291-
final_stats["tools"][tool_name] = {
292-
"call_count": len(times),
293-
"total_time": sum(times),
294-
"min_time": min(times),
295-
"max_time": max(times),
296-
"mean_time": statistics.mean(times),
297-
"median_time": statistics.median(times)
298-
}
263+
# Note: per-call timings are not preserved across worker boundaries, so we
264+
# combine the per-conversation summary stats directly. This gives correct
265+
# call_count/total_time/min/max/mean. ``median`` cannot be reconstructed
266+
# exactly from summaries; we surface mean as a best-effort approximation
267+
# and flag it via the ``median_time_approximate`` field.
299268

300-
# Calculate final statistics for API calls
301-
api_times = aggregated["api_calls"]["times"]
302-
if api_times:
303-
final_stats["api_calls"] = {
304-
"call_count": len(api_times),
305-
"total_time": sum(api_times),
306-
"min_time": min(api_times),
307-
"max_time": max(api_times),
308-
"mean_time": statistics.mean(api_times),
309-
"median_time": statistics.median(api_times)
310-
}
311-
else:
312-
final_stats["api_calls"] = {
269+
def _empty():
270+
return {
313271
"call_count": 0,
314272
"total_time": 0.0,
315-
"min_time": 0.0,
273+
"min_time": float("inf"),
316274
"max_time": 0.0,
317-
"mean_time": 0.0,
318-
"median_time": 0.0
319275
}
320276

277+
tool_acc: Dict[str, Dict] = defaultdict(_empty)
278+
api_acc = _empty()
279+
280+
def _merge(acc: Dict, summary: Dict) -> None:
281+
count = summary.get("call_count", 0)
282+
if count <= 0:
283+
return
284+
acc["call_count"] += count
285+
acc["total_time"] += summary.get("total_time", 0.0)
286+
# Only consider min/max if the source actually had calls; otherwise
287+
# its min_time will be the sentinel 0.0 from to_dict().
288+
acc["min_time"] = min(acc["min_time"], summary.get("min_time", float("inf")))
289+
acc["max_time"] = max(acc["max_time"], summary.get("max_time", 0.0))
290+
291+
for stats in stats_list:
292+
for tool_name, tool_stats in stats.get("tools", {}).items():
293+
_merge(tool_acc[tool_name], tool_stats)
294+
_merge(api_acc, stats.get("api_calls", {}))
295+
296+
def _finalize(acc: Dict) -> Dict:
297+
count = acc["call_count"]
298+
if count == 0:
299+
return {
300+
"call_count": 0,
301+
"total_time": 0.0,
302+
"min_time": 0.0,
303+
"max_time": 0.0,
304+
"mean_time": 0.0,
305+
"median_time": 0.0,
306+
"median_time_approximate": True,
307+
}
308+
mean_time = acc["total_time"] / count
309+
return {
310+
"call_count": count,
311+
"total_time": acc["total_time"],
312+
"min_time": acc["min_time"] if acc["min_time"] != float("inf") else 0.0,
313+
"max_time": acc["max_time"],
314+
"mean_time": mean_time,
315+
# Real median requires per-call data we don't carry across workers.
316+
"median_time": mean_time,
317+
"median_time_approximate": True,
318+
}
319+
320+
final_stats = {
321+
"tools": {name: _finalize(acc) for name, acc in tool_acc.items()},
322+
"api_calls": _finalize(api_acc),
323+
}
324+
321325
return final_stats
322326

323327

run_agent.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from tools.terminal_tool import cleanup_vm
4747

4848
# Import profiling
49-
from profiling import get_profiler
49+
from profiling import get_profiler, reset_profiler
5050

5151

5252
class AIAgent:
@@ -368,8 +368,7 @@ def run_conversation(
368368
Dict: Complete conversation result with final response and message history
369369
"""
370370
# Reset profiler for this conversation to get fresh stats
371-
from profiling import reset_profiler as reset_prof
372-
reset_prof()
371+
reset_profiler()
373372

374373
# Generate unique task_id if not provided to isolate VMs between concurrent tasks
375374
import uuid
@@ -461,11 +460,6 @@ def run_conversation(
461460
if self.verbose_logging:
462461
for tc in assistant_message.tool_calls:
463462
logging.debug(f"Tool call: {tc.function.name} with args: {tc.function.arguments[:200]}...")
464-
# Debug: Check what attributes are available on tool_call
465-
logging.debug(f"Tool call attributes: {dir(tc)}")
466-
# Try to dump the model to see all fields
467-
if hasattr(tc, 'model_dump'):
468-
logging.debug(f"Tool call data: {tc.model_dump()}")
469463

470464
# Add assistant message with tool calls to conversation
471465
# Extract thought_signature if present (required for Gemini models)

safe_print.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
#!/usr/bin/env python3
2-
"""Simple safe print that tries rich, falls back to regular print."""
2+
"""Simple safe print that tries rich, falls back to regular print.
3+
4+
When rich is unavailable, any rich-style markup like ``[bold red]...[/bold red]``
5+
is stripped from string arguments so the plain output stays readable instead of
6+
leaking literal tags.
7+
"""
8+
9+
import re
310

411
try:
512
from rich import print as rich_print
@@ -8,13 +15,25 @@
815
RICH_AVAILABLE = False
916

1017

18+
# Matches rich markup tags like ``[bold red]``, ``[/bold red]``, ``[/]``, etc.
19+
# Conservative: only strips bracketed tokens that look like style directives
20+
# (letters, digits, slashes, spaces, # for hex colors).
21+
_RICH_MARKUP_RE = re.compile(r"\[/?[a-zA-Z0-9 #_/-]*\]")
22+
23+
24+
def _strip_markup(arg):
25+
if isinstance(arg, str):
26+
return _RICH_MARKUP_RE.sub("", arg)
27+
return arg
28+
29+
1130
def safe_print(*args, **kwargs):
12-
"""Try rich.print, fall back to regular print if it fails."""
31+
"""Try rich.print, fall back to regular print (with markup stripped)."""
1332
if RICH_AVAILABLE:
1433
try:
1534
rich_print(*args, **kwargs)
1635
return
1736
except Exception:
1837
pass
19-
# Fallback to regular print
20-
print(*args, **kwargs)
38+
# Fallback: strip rich markup so we don't print literal "[bold red]..." tags
39+
print(*(_strip_markup(a) for a in args), **kwargs)

tools/simple_terminal_tool.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def _execute_ssh_command(instance, command: str, timeout: Optional[int] = None)
196196
if ssh_context_manager:
197197
try:
198198
ssh_context_manager.__exit__(None, None, None)
199-
except:
199+
except Exception:
200200
pass
201201

202202
return {
@@ -210,7 +210,7 @@ def _execute_ssh_command(instance, command: str, timeout: Optional[int] = None)
210210
if ssh_context_manager:
211211
try:
212212
ssh_context_manager.__exit__(None, None, None)
213-
except:
213+
except Exception:
214214
pass
215215

216216
# Check if it's a timeout

0 commit comments

Comments
 (0)