From fe7ddf6da78f8dbbc395eb98ff1200b8117f0cc0 Mon Sep 17 00:00:00 2001 From: "Jonas L." Date: Mon, 22 Jul 2024 15:34:48 +0200 Subject: [PATCH] feat: add exponential and constant backoff function (#416) - Implement the same backoff function as in the hcloud-go libary - Preparation work to change the retry backoff function to use an exponential backoff interval. - Rename PollIntervalFunction to BackoffFunction, as it is not only used for polling. --- hcloud/__init__.py | 6 ++++- hcloud/_client.py | 48 ++++++++++++++++++++++++++++++++++++--- tests/unit/test_client.py | 28 ++++++++++++++++++++++- 3 files changed, 77 insertions(+), 5 deletions(-) diff --git a/hcloud/__init__.py b/hcloud/__init__.py index 2cf9921..95709bb 100644 --- a/hcloud/__init__.py +++ b/hcloud/__init__.py @@ -1,6 +1,10 @@ from __future__ import annotations -from ._client import Client as Client # noqa pylint: disable=C0414 +from ._client import ( # noqa pylint: disable=C0414 + Client as Client, + constant_backoff_function as constant_backoff_function, + exponential_backoff_function as exponential_backoff_function, +) from ._exceptions import ( # noqa pylint: disable=C0414 APIException as APIException, HCloudException as HCloudException, diff --git a/hcloud/_client.py b/hcloud/_client.py index a0de13a..3ad9642 100644 --- a/hcloud/_client.py +++ b/hcloud/_client.py @@ -1,6 +1,7 @@ from __future__ import annotations import time +from random import uniform from typing import Protocol import requests @@ -26,7 +27,7 @@ from .volumes import VolumesClient -class PollIntervalFunction(Protocol): +class BackoffFunction(Protocol): def __call__(self, retries: int) -> float: """ Return a interval in seconds to wait between each API call. @@ -35,6 +36,47 @@ def __call__(self, retries: int) -> float: """ +def constant_backoff_function(interval: float) -> BackoffFunction: + """ + Return a backoff function, implementing a constant backoff. + + :param interval: Constant interval to return. + """ + + # pylint: disable=unused-argument + def func(retries: int) -> float: + return interval + + return func + + +def exponential_backoff_function( + *, + base: float, + multiplier: int, + cap: float, + jitter: bool = False, +) -> BackoffFunction: + """ + Return a backoff function, implementing a truncated exponential backoff with + optional full jitter. + + :param base: Base for the exponential backoff algorithm. + :param multiplier: Multiplier for the exponential backoff algorithm. + :param cap: Value at which the interval is truncated. + :param jitter: Whether to add jitter. + """ + + def func(retries: int) -> float: + interval = base * multiplier**retries # Exponential backoff + interval = min(cap, interval) # Cap backoff + if jitter: + interval = uniform(base, interval) # Add jitter + return interval + + return func + + class Client: """Base Client for accessing the Hetzner Cloud API""" @@ -48,7 +90,7 @@ def __init__( api_endpoint: str = "https://api.hetzner.cloud/v1", application_name: str | None = None, application_version: str | None = None, - poll_interval: int | float | PollIntervalFunction = 1.0, + poll_interval: int | float | BackoffFunction = 1.0, poll_max_retries: int = 120, timeout: float | tuple[float, float] | None = None, ): @@ -73,7 +115,7 @@ def __init__( self._requests_timeout = timeout if isinstance(poll_interval, (int, float)): - self._poll_interval_func = lambda _: poll_interval # Constant poll interval + self._poll_interval_func = constant_backoff_function(poll_interval) else: self._poll_interval_func = poll_interval self._poll_max_retries = poll_max_retries diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index e258a51..66df7ee 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -6,7 +6,12 @@ import pytest import requests -from hcloud import APIException, Client +from hcloud import ( + APIException, + Client, + constant_backoff_function, + exponential_backoff_function, +) class TestHetznerClient: @@ -202,3 +207,24 @@ def test_request_limit_then_success(self, client, rate_limit_response): "POST", "http://url.com", params={"argument": "value"}, timeout=2 ) assert client._requests_session.request.call_count == 2 + + +def test_constant_backoff_function(): + backoff = constant_backoff_function(interval=1.0) + max_retries = 5 + + for i in range(max_retries): + assert backoff(i) == 1.0 + + +def test_exponential_backoff_function(): + backoff = exponential_backoff_function( + base=1.0, + multiplier=2, + cap=60.0, + ) + max_retries = 5 + + results = [backoff(i) for i in range(max_retries)] + assert sum(results) == 31.0 + assert results == [1.0, 2.0, 4.0, 8.0, 16.0]