Skip to content

Commit a367de8

Browse files
Add --parallel N flag to vera-bench run
Run N problems concurrently via ThreadPoolExecutor. Each worker is I/O-bound on its LLM HTTP call + subprocess-based check/run, so the GIL is not a bottleneck. Use case: slow models like Kimi K2.5 averaged 49s/problem sequentially across the 60-problem AILANG sweep (~50 min total). With --parallel 10 the same sweep should drop to ~5 min, which makes release-time re-evals practical. Implementation: - ThreadPoolExecutor with max_workers=parallel - Per-problem futures collected via as_completed - threading.Lock around the JSONL append so concurrent writes don't interleave. Lines are still self-contained (carry problem_id) so completion-order writes are fine. - Workers share the same work_dir; per-problem temp files are uniquified by problem_id (existing behavior). - Exception per worker is caught and logged; the sweep continues. Default parallel=1 preserves the existing sequential path with no behavior change. Smoke-tested with claude-haiku-4-5 --tier 1 --parallel 4: 10/10 problems, no duplicates, 100%/100% run_correct. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 5ad6475 commit a367de8

2 files changed

Lines changed: 85 additions & 13 deletions

File tree

vera_bench/cli.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,17 @@ def validate(problems_dir: Path | None, solutions_dir: Path | None):
7676
is_flag=True,
7777
help="Keep temporary generated files",
7878
)
79+
@click.option(
80+
"--parallel",
81+
type=int,
82+
default=1,
83+
show_default=True,
84+
help=(
85+
"Run N problems concurrently via ThreadPoolExecutor. "
86+
"Use >1 for slow models (e.g. Kimi K2.5). "
87+
"Each worker is I/O-bound on its LLM call + subprocess runs."
88+
),
89+
)
7990
def run(
8091
model: str,
8192
tier: int | None,
@@ -86,6 +97,7 @@ def run(
8697
output_dir: Path | None,
8798
max_tokens: int,
8899
keep_temps: bool,
100+
parallel: int,
89101
):
90102
"""Run benchmark against an LLM model."""
91103
from vera_bench.metrics import compute_metrics
@@ -274,6 +286,7 @@ def _ver_slug(v: str) -> str:
274286
keep_temps=keep_temps,
275287
bench_version=bench_ver,
276288
vera_version=vera_ver,
289+
parallel=parallel,
277290
)
278291

279292
# Print summary

vera_bench/runner.py

Lines changed: 72 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,20 +1108,62 @@ def run_benchmark(
11081108
keep_temps: bool = False,
11091109
bench_version: str = "",
11101110
vera_version: str = "",
1111+
parallel: int = 1,
11111112
) -> list[ProblemResult]:
11121113
"""Run the full benchmark across all problems.
11131114
11141115
Results are written to JSONL incrementally (survives crashes).
1116+
1117+
When ``parallel > 1``, problems are dispatched to a ThreadPoolExecutor
1118+
with ``parallel`` workers. Each problem runs independently (its own
1119+
LLM call, its own subprocess-based check/run), so threads only block
1120+
on I/O (HTTP to the LLM provider, subprocess spawns to the toolchain).
1121+
The GIL is not a bottleneck. Use this when sweeping slow models —
1122+
e.g. Kimi K2.5 at ~50s/problem sequential becomes ~5s/problem with
1123+
parallel=10.
1124+
1125+
JSONL output ordering is by completion order, not by problem index,
1126+
when running in parallel. Each line is self-contained (carries
1127+
``problem_id``) so downstream consumers can sort if needed.
11151128
"""
11161129
work_dir = Path(tempfile.mkdtemp(prefix="verabench_"))
11171130
all_results: list[ProblemResult] = []
11181131

11191132
try:
1120-
with Progress(console=console) as progress:
1121-
task = progress.add_task("Running benchmark...", total=len(problems))
1122-
for problem in problems:
1123-
problem_results = run_single_problem(
1124-
problem=problem,
1133+
if parallel <= 1:
1134+
with Progress(console=console) as progress:
1135+
task = progress.add_task("Running benchmark...", total=len(problems))
1136+
for problem in problems:
1137+
problem_results = run_single_problem(
1138+
problem=problem,
1139+
client=client,
1140+
skill_md=skill_md,
1141+
vera=vera,
1142+
work_dir=work_dir,
1143+
mode=mode,
1144+
language=language,
1145+
max_fix_attempts=max_fix_attempts,
1146+
max_tokens=max_tokens,
1147+
bench_version=bench_version,
1148+
vera_version=vera_version,
1149+
)
1150+
all_results.extend(problem_results)
1151+
1152+
if output_path:
1153+
with open(output_path, "a", encoding="utf-8") as f:
1154+
for r in problem_results:
1155+
f.write(r.to_jsonl() + "\n")
1156+
1157+
progress.advance(task)
1158+
else:
1159+
import threading
1160+
from concurrent.futures import ThreadPoolExecutor, as_completed
1161+
1162+
write_lock = threading.Lock()
1163+
1164+
def _run_one(p: dict) -> list[ProblemResult]:
1165+
return run_single_problem(
1166+
problem=p,
11251167
client=client,
11261168
skill_md=skill_md,
11271169
vera=vera,
@@ -1133,15 +1175,32 @@ def run_benchmark(
11331175
bench_version=bench_version,
11341176
vera_version=vera_version,
11351177
)
1136-
all_results.extend(problem_results)
1137-
1138-
# Write JSONL incrementally
1139-
if output_path:
1140-
with open(output_path, "a", encoding="utf-8") as f:
1141-
for r in problem_results:
1142-
f.write(r.to_jsonl() + "\n")
11431178

1144-
progress.advance(task)
1179+
with Progress(console=console) as progress:
1180+
task = progress.add_task(
1181+
f"Running benchmark (parallel={parallel})...",
1182+
total=len(problems),
1183+
)
1184+
with ThreadPoolExecutor(max_workers=parallel) as executor:
1185+
futures = {executor.submit(_run_one, p): p for p in problems}
1186+
for fut in as_completed(futures):
1187+
try:
1188+
problem_results = fut.result()
1189+
except Exception as exc: # noqa: BLE001
1190+
pid = futures[fut].get("id", "?")
1191+
console.print(
1192+
f"[red]Worker failed on {pid}: {exc}[/red]"
1193+
)
1194+
progress.advance(task)
1195+
continue
1196+
all_results.extend(problem_results)
1197+
if output_path:
1198+
with write_lock, open(
1199+
output_path, "a", encoding="utf-8"
1200+
) as f:
1201+
for r in problem_results:
1202+
f.write(r.to_jsonl() + "\n")
1203+
progress.advance(task)
11451204
finally:
11461205
if not keep_temps:
11471206
shutil.rmtree(work_dir, ignore_errors=True)

0 commit comments

Comments
 (0)