From ebd29dca2ad266dab149f0b821c87056851501fa Mon Sep 17 00:00:00 2001
From: jo <ljonas@riseup.net>
Date: Tue, 30 Jul 2024 19:27:09 +0200
Subject: [PATCH] feat: implement retry policy

---
 hcloud/_client.py         | 100 +++++++++++++++++++++++---------------
 tests/unit/test_client.py |  48 +++++++++++++++---
 2 files changed, 103 insertions(+), 45 deletions(-)

diff --git a/hcloud/_client.py b/hcloud/_client.py
index d198990..105b7b9 100644
--- a/hcloud/_client.py
+++ b/hcloud/_client.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import time
+from http import HTTPStatus
 from random import uniform
 from typing import Protocol
 
@@ -256,50 +257,71 @@ def request(  # type: ignore[no-untyped-def]
 
         retries = 0
         while True:
-            response = self._requests_session.request(
-                method=method,
-                url=url,
-                headers=headers,
-                **kwargs,
-            )
-
-            correlation_id = response.headers.get("X-Correlation-Id")
-            payload = {}
             try:
-                if len(response.content) > 0:
-                    payload = response.json()
-            except (TypeError, ValueError) as exc:
-                raise APIException(
-                    code=response.status_code,
-                    message=response.reason,
-                    details={"content": response.content},
-                    correlation_id=correlation_id,
-                ) from exc
-
-            if not response.ok:
-                if not payload or "error" not in payload:
-                    raise APIException(
-                        code=response.status_code,
-                        message=response.reason,
-                        details={"content": response.content},
-                        correlation_id=correlation_id,
-                    )
-
-                error: dict = payload["error"]
-
-                if (
-                    error["code"] == "rate_limit_exceeded"
-                    and retries < self._retry_max_retries
-                ):
+                response = self._requests_session.request(
+                    method=method,
+                    url=url,
+                    headers=headers,
+                    **kwargs,
+                )
+                return self._read_response(response)
+            except APIException as exception:
+                if retries < self._retry_max_retries and self._retry_policy(exception):
                     time.sleep(self._retry_interval(retries))
                     retries += 1
                     continue
-
+                raise
+            except requests.exceptions.Timeout:
+                if retries < self._retry_max_retries:
+                    time.sleep(self._retry_interval(retries))
+                    retries += 1
+                    continue
+                raise
+
+    def _read_response(self, response: requests.Response) -> dict:
+        correlation_id = response.headers.get("X-Correlation-Id")
+        payload = {}
+        try:
+            if len(response.content) > 0:
+                payload = response.json()
+        except (TypeError, ValueError) as exc:
+            raise APIException(
+                code=response.status_code,
+                message=response.reason,
+                details={"content": response.content},
+                correlation_id=correlation_id,
+            ) from exc
+
+        if not response.ok:
+            if not payload or "error" not in payload:
                 raise APIException(
-                    code=error["code"],
-                    message=error["message"],
-                    details=error.get("details"),
+                    code=response.status_code,
+                    message=response.reason,
+                    details={"content": response.content},
                     correlation_id=correlation_id,
                 )
 
-            return payload
+            error: dict = payload["error"]
+            raise APIException(
+                code=error["code"],
+                message=error["message"],
+                details=error.get("details"),
+                correlation_id=correlation_id,
+            )
+
+        return payload
+
+    def _retry_policy(self, exception: APIException) -> bool:
+        if isinstance(exception.code, str):
+            return exception.code in (
+                "rate_limit_exceeded",
+                "conflict",
+            )
+
+        if isinstance(exception.code, int):
+            return exception.code in (
+                HTTPStatus.BAD_GATEWAY,
+                HTTPStatus.GATEWAY_TIMEOUT,
+            )
+
+        return False
diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py
index 1e56bf2..4731336 100644
--- a/tests/unit/test_client.py
+++ b/tests/unit/test_client.py
@@ -109,12 +109,12 @@ def test_request_fails(self, client, fail_response):
 
     def test_request_fails_correlation_id(self, client, response):
         response.headers["X-Correlation-Id"] = "67ed842dc8bc8673"
-        response.status_code = 409
+        response.status_code = 422
         response._content = json.dumps(
             {
                 "error": {
-                    "code": "conflict",
-                    "message": "some conflict",
+                    "code": "service_error",
+                    "message": "Something crashed",
                 }
             }
         ).encode("utf-8")
@@ -125,11 +125,11 @@ def test_request_fails_correlation_id(self, client, response):
                 "POST", "http://url.com", params={"argument": "value"}, timeout=2
             )
         error = exception_info.value
-        assert error.code == "conflict"
-        assert error.message == "some conflict"
+        assert error.code == "service_error"
+        assert error.message == "Something crashed"
         assert error.details is None
         assert error.correlation_id == "67ed842dc8bc8673"
-        assert str(error) == "some conflict (conflict, 67ed842dc8bc8673)"
+        assert str(error) == "Something crashed (service_error, 67ed842dc8bc8673)"
 
     def test_request_500(self, client, fail_response):
         fail_response.status_code = 500
@@ -208,6 +208,42 @@ def test_request_limit_then_success(self, client, rate_limit_response):
         )
         assert client._requests_session.request.call_count == 2
 
+    @pytest.mark.parametrize(
+        ("exception", "expected"),
+        [
+            (
+                APIException(code="rate_limit_exceeded", message="Error", details=None),
+                True,
+            ),
+            (
+                APIException(code="conflict", message="Error", details=None),
+                True,
+            ),
+            (
+                APIException(code=409, message="Conflict", details=None),
+                False,
+            ),
+            (
+                APIException(code=429, message="Too Many Requests", details=None),
+                False,
+            ),
+            (
+                APIException(code=502, message="Bad Gateway", details=None),
+                True,
+            ),
+            (
+                APIException(code=503, message="Service Unavailable", details=None),
+                False,
+            ),
+            (
+                APIException(code=504, message="Gateway Timeout", details=None),
+                True,
+            ),
+        ],
+    )
+    def test_retry_policy(self, client, exception, expected):
+        assert client._retry_policy(exception) == expected
+
 
 def test_constant_backoff_function():
     backoff = constant_backoff_function(interval=1.0)