Skip to content

Commit

Permalink
Incrementally build iter actions list (#434)
Browse files Browse the repository at this point in the history
* Incrementally build iter actions list

* Add TypedDict for iter_state

* Format with ruff

* Make IterState a dataclass

* Fix typing

* Conditionally add slots to dataclass

---------

Co-authored-by: Julien Danjou <[email protected]>
  • Loading branch information
hasier and jd authored Feb 6, 2024
1 parent 24b4a5c commit 17aefd9
Showing 1 changed file with 100 additions and 26 deletions.
126 changes: 100 additions & 26 deletions tenacity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import dataclasses
import functools
import sys
import threading
Expand Down Expand Up @@ -97,6 +96,29 @@
WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Any])


dataclass_kwargs = {}
if sys.version_info >= (3, 10):
dataclass_kwargs.update({"slots": True})


@dataclasses.dataclass(**dataclass_kwargs)
class IterState:
actions: t.List[t.Callable[["RetryCallState"], t.Any]] = dataclasses.field(
default_factory=list
)
retry_run_result: bool = False
delay_since_first_attempt: int = 0
stop_run_result: bool = False
is_explicit_retry: bool = False

def reset(self) -> None:
self.actions = []
self.retry_run_result = False
self.delay_since_first_attempt = 0
self.stop_run_result = False
self.is_explicit_retry = False


class TryAgain(Exception):
"""Always retry the executed function when raised."""

Expand Down Expand Up @@ -287,6 +309,14 @@ def statistics(self) -> t.Dict[str, t.Any]:
self._local.statistics = t.cast(t.Dict[str, t.Any], {})
return self._local.statistics

@property
def iter_state(self) -> IterState:
try:
return self._local.iter_state # type: ignore[no-any-return]
except AttributeError:
self._local.iter_state = IterState()
return self._local.iter_state

def wraps(self, f: WrappedFn) -> WrappedFn:
"""Wrap a function for retrying.
Expand All @@ -313,45 +343,89 @@ def begin(self) -> None:
self.statistics["attempt_number"] = 1
self.statistics["idle_for"] = 0

def iter(self, retry_state: "RetryCallState") -> t.Union[DoAttempt, DoSleep, t.Any]: # noqa
fut = retry_state.outcome
if fut is None:
if self.before is not None:
self.before(retry_state)
return DoAttempt()

is_explicit_retry = fut.failed and isinstance(fut.exception(), TryAgain)
if not (is_explicit_retry or self.retry(retry_state)):
return fut.result()
def _add_action_func(self, fn: t.Callable[..., t.Any]) -> None:
self.iter_state.actions.append(fn)

if self.after is not None:
self.after(retry_state)
def _run_retry(self, retry_state: "RetryCallState") -> None:
self.iter_state.retry_run_result = self.retry(retry_state)

def _run_wait(self, retry_state: "RetryCallState") -> None:
if self.wait:
sleep = self.wait(retry_state)
else:
sleep = 0.0

retry_state.upcoming_sleep = sleep

def _run_stop(self, retry_state: "RetryCallState") -> None:
self.statistics["delay_since_first_attempt"] = retry_state.seconds_since_start
if self.stop(retry_state):
self.iter_state.stop_run_result = self.stop(retry_state)

def iter(self, retry_state: "RetryCallState") -> t.Union[DoAttempt, DoSleep, t.Any]: # noqa
self._begin_iter(retry_state)
result = None
for action in self.iter_state.actions:
result = action(retry_state)
return result

def _begin_iter(self, retry_state: "RetryCallState") -> None: # noqa
self.iter_state.reset()

fut = retry_state.outcome
if fut is None:
if self.before is not None:
self._add_action_func(self.before)
self._add_action_func(lambda rs: DoAttempt())
return

self.iter_state.is_explicit_retry = fut.failed and isinstance(
fut.exception(), TryAgain
)
if not self.iter_state.is_explicit_retry:
self._add_action_func(self._run_retry)
self._add_action_func(self._post_retry_check_actions)

def _post_retry_check_actions(self, retry_state: "RetryCallState") -> None:
if not (self.iter_state.is_explicit_retry or self.iter_state.retry_run_result):
self._add_action_func(lambda rs: rs.outcome.result())
return

if self.after is not None:
self._add_action_func(self.after)

self._add_action_func(self._run_wait)
self._add_action_func(self._run_stop)
self._add_action_func(self._post_stop_check_actions)

def _post_stop_check_actions(self, retry_state: "RetryCallState") -> None:
if self.iter_state.stop_run_result:
if self.retry_error_callback:
return self.retry_error_callback(retry_state)
retry_exc = self.retry_error_cls(fut)
if self.reraise:
raise retry_exc.reraise()
raise retry_exc from fut.exception()
self._add_action_func(self.retry_error_callback)
return

def exc_check(rs: "RetryCallState") -> None:
fut = t.cast(Future, rs.outcome)
retry_exc = self.retry_error_cls(fut)
if self.reraise:
raise retry_exc.reraise()
raise retry_exc from fut.exception()

self._add_action_func(exc_check)
return

def next_action(rs: "RetryCallState") -> None:
sleep = rs.upcoming_sleep
rs.next_action = RetryAction(sleep)
rs.idle_for += sleep
self.statistics["idle_for"] += sleep
self.statistics["attempt_number"] += 1

retry_state.next_action = RetryAction(sleep)
retry_state.idle_for += sleep
self.statistics["idle_for"] += sleep
self.statistics["attempt_number"] += 1
self._add_action_func(next_action)

if self.before_sleep is not None:
self.before_sleep(retry_state)
self._add_action_func(self.before_sleep)

return DoSleep(sleep)
self._add_action_func(lambda rs: DoSleep(rs.upcoming_sleep))

def __iter__(self) -> t.Generator[AttemptManager, None, None]:
self.begin()
Expand Down

0 comments on commit 17aefd9

Please sign in to comment.