From afc2d3acd46f8eff02e43bc5ceb481a84e8ba8b0 Mon Sep 17 00:00:00 2001 From: jo Date: Tue, 21 Nov 2023 15:48:34 +0100 Subject: [PATCH] feat: add client cached_session Use this cached session in short lived scopes. Prevent calling the API too often when iterating over a list of resource that contains the same objects. --- examples/use_cached_session.py | 19 ++++++++++++++ hcloud/_client.py | 45 +++++++++++++++++++++++++++++++++- tests/unit/test_client.py | 24 ++++++++++++++++++ 3 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 examples/use_cached_session.py diff --git a/examples/use_cached_session.py b/examples/use_cached_session.py new file mode 100644 index 0000000..c750a30 --- /dev/null +++ b/examples/use_cached_session.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from os import environ + +from hcloud import Client + +assert ( + "HCLOUD_TOKEN" in environ +), "Please export your API token in the HCLOUD_TOKEN environment variable" +token = environ["HCLOUD_TOKEN"] + +client = Client(token=token) + +with client.cached_session() as session: + # This will query the API only once + for i in range(100): + servers = session.locations.get_all() + +print(servers) diff --git a/hcloud/_client.py b/hcloud/_client.py index a12f204..485c2ab 100644 --- a/hcloud/_client.py +++ b/hcloud/_client.py @@ -1,7 +1,9 @@ from __future__ import annotations +import copy import time -from typing import NoReturn +from contextlib import contextmanager +from typing import Generator, NoReturn import requests @@ -231,3 +233,44 @@ def request( # type: ignore[no-untyped-def] # TODO: return an empty dict instead of an empty string when content == "". return content # type: ignore[return-value] + + def session(self, session: requests.Session) -> None: + """ + Configure a custom :class:`Session ` to use when calling the API. + + :param session: The session to use when making API requests. + """ + self._requests_session = session + + @contextmanager + def cached_session(self) -> Generator[Client, None, None]: + """ + Provide a copy of the :class:`Client ` as context manager that + will cache all GET requests. + + Cached response will not expire automatically, therefor the cached client must + not be used for long living scopes. + """ + client = copy.deepcopy(self) + client.session(CachedSession()) + yield client + + +class CachedSession(requests.Session): + cache: dict[str, requests.Response] = {} + + def send(self, request: requests.PreparedRequest, **kwargs) -> requests.Response: # type: ignore[no-untyped-def] + """ + Send a given PreparedRequest. + """ + if request.method != "GET" or request.url is None: + return super().send(request, **kwargs) + + if request.url in self.cache: + return self.cache[request.url] + + response = super().send(request, **kwargs) + if response.ok: + self.cache[request.url] = response + + return response diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 36ac928..63ce2d8 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1,12 +1,14 @@ from __future__ import annotations import json +from io import BytesIO from unittest.mock import MagicMock import pytest import requests from hcloud import APIException, Client +from hcloud._client import CachedSession class TestHetznerClient: @@ -182,3 +184,25 @@ def test_request_limit_then_success(self, client, rate_limit_response): "POST", "http://url.com", params={"argument": "value"}, timeout=2 ) assert client._requests_session.request.call_count == 2 + + +class TestCachedSession: + def test_cache(self): + response = requests.Response() + response.status_code = 200 + response.raw = BytesIO(json.dumps({"result": "data"}).encode("utf-8")) + + adapter = MagicMock() + adapter.send.return_value = response + + session = CachedSession() + session.get_adapter = MagicMock(return_value=adapter) + + resp1 = session.request("GET", "https://url.com", params={"argument": "value"}) + resp2 = session.request("GET", "https://url.com", params={"argument": "value"}) + + assert adapter.send.call_count == 1 + assert resp1 is session.cache["https://url.com/?argument=value"] + assert resp2 is session.cache["https://url.com/?argument=value"] + assert resp1.json() == {"result": "data"} + assert resp2.json() == {"result": "data"}