diff --git a/hcloud/_client.py b/hcloud/_client.py index d198990..105b7b9 100644 --- a/hcloud/_client.py +++ b/hcloud/_client.py @@ -1,6 +1,7 @@ from __future__ import annotations import time +from http import HTTPStatus from random import uniform from typing import Protocol @@ -256,50 +257,71 @@ def request( # type: ignore[no-untyped-def] retries = 0 while True: - response = self._requests_session.request( - method=method, - url=url, - headers=headers, - **kwargs, - ) - - correlation_id = response.headers.get("X-Correlation-Id") - payload = {} try: - if len(response.content) > 0: - payload = response.json() - except (TypeError, ValueError) as exc: - raise APIException( - code=response.status_code, - message=response.reason, - details={"content": response.content}, - correlation_id=correlation_id, - ) from exc - - if not response.ok: - if not payload or "error" not in payload: - raise APIException( - code=response.status_code, - message=response.reason, - details={"content": response.content}, - correlation_id=correlation_id, - ) - - error: dict = payload["error"] - - if ( - error["code"] == "rate_limit_exceeded" - and retries < self._retry_max_retries - ): + response = self._requests_session.request( + method=method, + url=url, + headers=headers, + **kwargs, + ) + return self._read_response(response) + except APIException as exception: + if retries < self._retry_max_retries and self._retry_policy(exception): time.sleep(self._retry_interval(retries)) retries += 1 continue - + raise + except requests.exceptions.Timeout: + if retries < self._retry_max_retries: + time.sleep(self._retry_interval(retries)) + retries += 1 + continue + raise + + def _read_response(self, response: requests.Response) -> dict: + correlation_id = response.headers.get("X-Correlation-Id") + payload = {} + try: + if len(response.content) > 0: + payload = response.json() + except (TypeError, ValueError) as exc: + raise APIException( + code=response.status_code, + message=response.reason, + details={"content": response.content}, + correlation_id=correlation_id, + ) from exc + + if not response.ok: + if not payload or "error" not in payload: raise APIException( - code=error["code"], - message=error["message"], - details=error.get("details"), + code=response.status_code, + message=response.reason, + details={"content": response.content}, correlation_id=correlation_id, ) - return payload + error: dict = payload["error"] + raise APIException( + code=error["code"], + message=error["message"], + details=error.get("details"), + correlation_id=correlation_id, + ) + + return payload + + def _retry_policy(self, exception: APIException) -> bool: + if isinstance(exception.code, str): + return exception.code in ( + "rate_limit_exceeded", + "conflict", + ) + + if isinstance(exception.code, int): + return exception.code in ( + HTTPStatus.BAD_GATEWAY, + HTTPStatus.GATEWAY_TIMEOUT, + ) + + return False diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 1e56bf2..4731336 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -109,12 +109,12 @@ def test_request_fails(self, client, fail_response): def test_request_fails_correlation_id(self, client, response): response.headers["X-Correlation-Id"] = "67ed842dc8bc8673" - response.status_code = 409 + response.status_code = 422 response._content = json.dumps( { "error": { - "code": "conflict", - "message": "some conflict", + "code": "service_error", + "message": "Something crashed", } } ).encode("utf-8") @@ -125,11 +125,11 @@ def test_request_fails_correlation_id(self, client, response): "POST", "http://url.com", params={"argument": "value"}, timeout=2 ) error = exception_info.value - assert error.code == "conflict" - assert error.message == "some conflict" + assert error.code == "service_error" + assert error.message == "Something crashed" assert error.details is None assert error.correlation_id == "67ed842dc8bc8673" - assert str(error) == "some conflict (conflict, 67ed842dc8bc8673)" + assert str(error) == "Something crashed (service_error, 67ed842dc8bc8673)" def test_request_500(self, client, fail_response): fail_response.status_code = 500 @@ -208,6 +208,42 @@ def test_request_limit_then_success(self, client, rate_limit_response): ) assert client._requests_session.request.call_count == 2 + @pytest.mark.parametrize( + ("exception", "expected"), + [ + ( + APIException(code="rate_limit_exceeded", message="Error", details=None), + True, + ), + ( + APIException(code="conflict", message="Error", details=None), + True, + ), + ( + APIException(code=409, message="Conflict", details=None), + False, + ), + ( + APIException(code=429, message="Too Many Requests", details=None), + False, + ), + ( + APIException(code=502, message="Bad Gateway", details=None), + True, + ), + ( + APIException(code=503, message="Service Unavailable", details=None), + False, + ), + ( + APIException(code=504, message="Gateway Timeout", details=None), + True, + ), + ], + ) + def test_retry_policy(self, client, exception, expected): + assert client._retry_policy(exception) == expected + def test_constant_backoff_function(): backoff = constant_backoff_function(interval=1.0)