|
14 | 14 | # See the License for the specific language governing permissions and |
15 | 15 | # limitations under the License. |
16 | 16 | import abc |
17 | | -import inspect |
18 | 17 | import typing |
19 | 18 |
|
20 | 19 | from tenacity import _utils |
21 | 20 | from tenacity import retry_base |
22 | | -from tenacity import retry_if_exception as _retry_if_exception |
23 | | -from tenacity import retry_if_result as _retry_if_result |
24 | 21 |
|
25 | 22 | if typing.TYPE_CHECKING: |
26 | 23 | from tenacity import RetryCallState |
@@ -54,35 +51,48 @@ def __ror__( # type: ignore[misc,override] |
54 | 51 | return retry_any(other, self) |
55 | 52 |
|
56 | 53 |
|
57 | | -class async_predicate_mixin: |
58 | | - async def __call__(self, retry_state: "RetryCallState") -> bool: |
59 | | - result = super().__call__(retry_state) # type: ignore[misc] |
60 | | - if inspect.isawaitable(result): |
61 | | - result = await result |
62 | | - return typing.cast(bool, result) |
63 | | - |
64 | | - |
65 | 54 | RetryBaseT = typing.Union[ |
66 | 55 | async_retry_base, typing.Callable[["RetryCallState"], typing.Awaitable[bool]] |
67 | 56 | ] |
68 | 57 |
|
69 | 58 |
|
70 | | -class retry_if_exception(async_predicate_mixin, _retry_if_exception, async_retry_base): # type: ignore[misc] |
| 59 | +class retry_if_exception(async_retry_base): |
71 | 60 | """Retry strategy that retries if an exception verifies a predicate.""" |
72 | 61 |
|
73 | 62 | def __init__( |
74 | 63 | self, predicate: typing.Callable[[BaseException], typing.Awaitable[bool]] |
75 | 64 | ) -> None: |
76 | | - super().__init__(predicate) # type: ignore[arg-type] |
| 65 | + self.predicate = predicate |
| 66 | + |
| 67 | + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] |
| 68 | + if retry_state.outcome is None: |
| 69 | + raise RuntimeError("__call__() called before outcome was set") |
77 | 70 |
|
| 71 | + if retry_state.outcome.failed: |
| 72 | + exception = retry_state.outcome.exception() |
| 73 | + if exception is None: |
| 74 | + raise RuntimeError("outcome failed but the exception is None") |
| 75 | + return await self.predicate(exception) |
| 76 | + else: |
| 77 | + return False |
78 | 78 |
|
79 | | -class retry_if_result(async_predicate_mixin, _retry_if_result, async_retry_base): # type: ignore[misc] |
| 79 | + |
| 80 | +class retry_if_result(async_retry_base): |
80 | 81 | """Retries if the result verifies a predicate.""" |
81 | 82 |
|
82 | 83 | def __init__( |
83 | 84 | self, predicate: typing.Callable[[typing.Any], typing.Awaitable[bool]] |
84 | 85 | ) -> None: |
85 | | - super().__init__(predicate) # type: ignore[arg-type] |
| 86 | + self.predicate = predicate |
| 87 | + |
| 88 | + async def __call__(self, retry_state: "RetryCallState") -> bool: # type: ignore[override] |
| 89 | + if retry_state.outcome is None: |
| 90 | + raise RuntimeError("__call__() called before outcome was set") |
| 91 | + |
| 92 | + if not retry_state.outcome.failed: |
| 93 | + return await self.predicate(retry_state.outcome.result()) |
| 94 | + else: |
| 95 | + return False |
86 | 96 |
|
87 | 97 |
|
88 | 98 | class retry_any(async_retry_base): |
|
0 commit comments