Skip to content

Commit 49cc9bf

Browse files
committed
fix(custom): preserve retryability through CustomColumnGenerator wrap
A real-workload run of #575 showed that the early-shutdown gate still trips even with the gate-exclusion fix in place. The trigger is 10 timeouts inside Anonymizer's QA-repair custom columns, each of which is wrapped in CustomColumnGenerationError (non-retryable) by the catch-all in CustomColumnGenerator.

Two fixes here:

1. Re-raise RETRYABLE_MODEL_ERRORS unchanged before the wrap so the scheduler's _is_retryable classifies them correctly.
2. Surface _AsyncBridgedModelFacade timeouts as ModelTimeoutError instead of the stdlib TimeoutError. Without this, the sync bridge times out with the wrong exception type and is still classified as non-retryable even after the first fix.

This commit also moves _RETRYABLE_MODEL_ERRORS from async_scheduler to models/errors as the public RETRYABLE_MODEL_ERRORS tuple: both the scheduler and the wrap site need it, and models/errors is the appropriate home alongside the error class definitions.

Refs #575.
1 parent 25d53b6 commit 49cc9bf

4 files changed

Lines changed: 131 additions & 15 deletions

File tree

packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from data_designer.config.column_configs import CustomColumnConfig, GenerationStrategy
1616
from data_designer.engine.column_generators.generators.base import SYNC_BRIDGE_TIMEOUT, ColumnGenerator
1717
from data_designer.engine.column_generators.utils.errors import CustomColumnGenerationError
18+
from data_designer.engine.models.errors import RETRYABLE_MODEL_ERRORS, ModelTimeoutError
1819
from data_designer.logging import LOG_INDENT
1920

2021
if TYPE_CHECKING:
@@ -69,7 +70,8 @@ def generate(self, *args: Any, **kwargs: Any) -> tuple[Any, list]:
6970
except concurrent.futures.TimeoutError as exc:
7071
future.cancel()
7172
logger.warning("Async model bridge timed out after %ss; coroutine cancelled", SYNC_BRIDGE_TIMEOUT)
72-
raise TimeoutError(f"model.generate() bridge timed out after {SYNC_BRIDGE_TIMEOUT}s") from exc
73+
# Raise as ModelTimeoutError so the scheduler classifies it retryable.
74+
raise ModelTimeoutError(f"model.generate() bridge timed out after {SYNC_BRIDGE_TIMEOUT}s") from exc
7375

7476
def __getattr__(self, name: str) -> Any:
7577
return getattr(object.__getattribute__(self, "_facade"), name)
@@ -147,6 +149,10 @@ async def agenerate(self, data: dict | pd.DataFrame) -> dict | pd.DataFrame | li
147149
result = await self._ainvoke_generator_function(data)
148150
except CustomColumnGenerationError:
149151
raise
152+
except RETRYABLE_MODEL_ERRORS:
153+
# Preserve retryability so the scheduler can salvage these
154+
# instead of counting them toward the early-shutdown gate.
155+
raise
150156
except Exception as e:
151157
logger.warning(
152158
f"⚠️ Custom generator function {self.config.generator_function.__name__!r} "
@@ -193,6 +199,10 @@ def _generate(self, data: dict | pd.DataFrame, is_dataframe: bool) -> dict | pd.
193199
result = self._invoke_generator_function(data)
194200
except CustomColumnGenerationError:
195201
raise
202+
except RETRYABLE_MODEL_ERRORS:
203+
# Preserve retryability so the scheduler can salvage these
204+
# instead of counting them toward the early-shutdown gate.
205+
raise
196206
except Exception as e:
197207
if not is_dataframe:
198208
logger.warning(

packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,7 @@
3030
)
3131
from data_designer.engine.dataset_builders.utils.sticky_progress_bar import StickyProgressBar
3232
from data_designer.engine.dataset_builders.utils.task_model import SliceRef, Task, TaskTrace
33-
from data_designer.engine.models.errors import (
34-
ModelAPIConnectionError,
35-
ModelInternalServerError,
36-
ModelRateLimitError,
37-
ModelTimeoutError,
38-
)
33+
from data_designer.engine.models.errors import RETRYABLE_MODEL_ERRORS
3934

4035
if TYPE_CHECKING:
4136
from data_designer.engine.column_generators.generators.base import ColumnGenerator
@@ -54,13 +49,6 @@
5449
DEGRADED_WARN_WINDOW: int = 20
5550
DEGRADED_WARN_INTERVAL_S: float = 60.0
5651

57-
_RETRYABLE_MODEL_ERRORS = (
58-
ModelRateLimitError,
59-
ModelTimeoutError,
60-
ModelInternalServerError,
61-
ModelAPIConnectionError,
62-
)
63-
6452

6553
class TrackingSemaphore(asyncio.Semaphore):
6654
"""``asyncio.Semaphore`` subclass that exposes available permits publicly."""
@@ -1036,7 +1024,7 @@ def get_semaphore_permits(self) -> tuple[int, int]:
10361024
@staticmethod
10371025
def _is_retryable(exc: Exception) -> bool:
10381026
"""Classify whether an exception is retryable."""
1039-
return isinstance(exc, _RETRYABLE_MODEL_ERRORS)
1027+
return isinstance(exc, RETRYABLE_MODEL_ERRORS)
10401028

10411029

10421030
def build_llm_bound_lookup(generators: dict[str, ColumnGenerator]) -> dict[str, bool]:

packages/data-designer-engine/src/data_designer/engine/models/errors.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,17 @@ def __init__(
131131
class ImageGenerationError(DataDesignerError): ...
132132

133133

134+
# Errors that the async scheduler defers to salvage instead of failing the run.
135+
# Callers that wrap arbitrary exceptions (e.g. CustomColumnGenerator) should
136+
# re-raise these unchanged so retryability is preserved through the wrap.
137+
RETRYABLE_MODEL_ERRORS: tuple[type[Exception], ...] = (
138+
ModelRateLimitError,
139+
ModelTimeoutError,
140+
ModelInternalServerError,
141+
ModelAPIConnectionError,
142+
)
143+
144+
134145
class FormattedLLMErrorMessage(BaseModel):
135146
cause: str
136147
solution: str

packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@
2020
from data_designer.config.custom_column import custom_column_generator
2121
from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator
2222
from data_designer.engine.column_generators.utils.errors import CustomColumnGenerationError
23+
from data_designer.engine.models.errors import (
24+
ModelAPIConnectionError,
25+
ModelInternalServerError,
26+
ModelRateLimitError,
27+
ModelTimeoutError,
28+
)
2329
from data_designer.engine.resources.resource_provider import ResourceProvider
2430

2531

@@ -350,6 +356,53 @@ def failing_generator(row: dict) -> dict:
350356
assert "something broke" in caplog.text
351357

352358

359+
@pytest.mark.parametrize(
360+
"exc_factory",
361+
[
362+
pytest.param(lambda: ModelRateLimitError("429"), id="rate_limit"),
363+
pytest.param(lambda: ModelTimeoutError("timeout"), id="timeout"),
364+
pytest.param(lambda: ModelInternalServerError("503"), id="internal_server"),
365+
pytest.param(lambda: ModelAPIConnectionError("conn reset"), id="api_connection"),
366+
],
367+
)
368+
def test_retryable_model_errors_pass_through_sync_wrap(exc_factory: Any) -> None:
369+
"""Retryable model errors raised inside a sync generator must NOT be wrapped.
370+
371+
Without this, the scheduler classifies the wrapped error as non-retryable and
372+
counts it toward the early-shutdown gate (regression seen in #575 follow-up).
373+
"""
374+
375+
@custom_column_generator()
376+
def raising_gen(row: dict) -> dict:
377+
raise exc_factory()
378+
379+
generator = _create_test_generator(name="result", generator_function=raising_gen)
380+
with pytest.raises(type(exc_factory())):
381+
generator.generate({"input": 1})
382+
383+
384+
@pytest.mark.parametrize(
385+
"exc_factory",
386+
[
387+
pytest.param(lambda: ModelRateLimitError("429"), id="rate_limit"),
388+
pytest.param(lambda: ModelTimeoutError("timeout"), id="timeout"),
389+
pytest.param(lambda: ModelInternalServerError("503"), id="internal_server"),
390+
pytest.param(lambda: ModelAPIConnectionError("conn reset"), id="api_connection"),
391+
],
392+
)
393+
@pytest.mark.asyncio
394+
async def test_retryable_model_errors_pass_through_async_wrap(exc_factory: Any) -> None:
395+
"""Retryable errors raised inside an async user generator must propagate unchanged."""
396+
397+
@custom_column_generator()
398+
async def raising_gen(row: dict) -> dict:
399+
raise exc_factory()
400+
401+
generator = _create_test_generator(name="result", generator_function=raising_gen)
402+
with pytest.raises(type(exc_factory())):
403+
await generator.agenerate({"input": 1})
404+
405+
353406
def test_undeclared_columns_removed_with_warning(caplog: pytest.LogCaptureFixture) -> None:
354407
"""Test that undeclared columns are removed with a warning."""
355408
import logging
@@ -555,6 +608,60 @@ def test_non_client_mode_errors_propagate(self) -> None:
555608
with pytest.raises(RuntimeError, match="connection timed out"):
556609
proxy.generate(prompt="hello")
557610

611+
def test_bridge_timeout_raises_model_timeout_error(self) -> None:
612+
"""A bridge timeout must surface as ModelTimeoutError so the scheduler sees it as retryable."""
613+
import asyncio
614+
import concurrent.futures
615+
import threading
616+
from unittest.mock import patch
617+
618+
from data_designer.engine.column_generators.generators.custom import _AsyncBridgedModelFacade
619+
from data_designer.engine.models.clients.errors import SyncClientUnavailableError
620+
621+
facade = Mock()
622+
facade.generate.side_effect = SyncClientUnavailableError(
623+
"Sync methods are not available on an async-mode HttpModelClient."
624+
)
625+
626+
async def hangs_forever(*args: Any, **kwargs: Any) -> tuple:
627+
await asyncio.sleep(60)
628+
return ("never", [], {})
629+
630+
facade.agenerate = hangs_forever
631+
proxy = _AsyncBridgedModelFacade(facade)
632+
633+
engine_loop = asyncio.new_event_loop()
634+
engine_thread = threading.Thread(target=engine_loop.run_forever, daemon=True)
635+
engine_thread.start()
636+
637+
try:
638+
with (
639+
patch(
640+
"data_designer.engine.dataset_builders.utils.async_concurrency.ensure_async_engine_loop",
641+
return_value=engine_loop,
642+
),
643+
patch("data_designer.engine.column_generators.generators.custom.SYNC_BRIDGE_TIMEOUT", 0.05),
644+
pytest.raises(ModelTimeoutError, match="bridge timed out"),
645+
):
646+
proxy.generate("hello")
647+
# Sanity: the same condition should not raise stdlib TimeoutError.
648+
with (
649+
patch(
650+
"data_designer.engine.dataset_builders.utils.async_concurrency.ensure_async_engine_loop",
651+
return_value=engine_loop,
652+
),
653+
patch("data_designer.engine.column_generators.generators.custom.SYNC_BRIDGE_TIMEOUT", 0.05),
654+
):
655+
try:
656+
proxy.generate("hello2")
657+
except ModelTimeoutError:
658+
pass
659+
except concurrent.futures.TimeoutError:
660+
pytest.fail("bridge raised stdlib TimeoutError instead of ModelTimeoutError")
661+
finally:
662+
engine_loop.call_soon_threadsafe(engine_loop.stop)
663+
engine_thread.join(timeout=5)
664+
558665
def test_deadlock_guard_on_event_loop(self) -> None:
559666
"""Raises a clear error instead of deadlocking when called from the event loop."""
560667
import asyncio

0 commit comments

Comments
 (0)