Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add client cached_session #327

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions examples/use_cached_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from os import environ

from hcloud import Client

assert (
"HCLOUD_TOKEN" in environ
), "Please export your API token in the HCLOUD_TOKEN environment variable"
token = environ["HCLOUD_TOKEN"]

client = Client(token=token)

with client.cached_session() as session:
# This will query the API only once
for i in range(100):
locations = session.locations.get_all()

print(locations)
44 changes: 43 additions & 1 deletion hcloud/_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import time
from typing import NoReturn
from contextlib import contextmanager
from typing import Generator, NoReturn

import requests

Expand Down Expand Up @@ -231,3 +232,44 @@ def request( # type: ignore[no-untyped-def]

# TODO: return an empty dict instead of an empty string when content == "".
return content # type: ignore[return-value]

def session(self, session: requests.Session) -> None:
"""
Configure a custom :class:`Session <requests.Session>` to use when calling the API.

:param session: The session to use when making API requests.
"""
self._requests_session = session

@contextmanager
def cached_session(self) -> Generator[Client, None, None]:
"""
Provide a copy of the :class:`Client <hcloud.Client>` as context manager that
will cache all GET requests.

Cached response will not expire, therefore the cached client must not be used
for long living scopes.
"""
self.session(CachedSession())
yield self
self.session(requests.Session())


class CachedSession(requests.Session):
cache: dict[str, requests.Response] = {}

def send(self, request: requests.PreparedRequest, **kwargs) -> requests.Response: # type: ignore[no-untyped-def]
"""
Send a given PreparedRequest.
"""
if request.method != "GET" or request.url is None:
return super().send(request, **kwargs)

if request.url in self.cache:
jooola marked this conversation as resolved.
Show resolved Hide resolved
return self.cache[request.url]

response = super().send(request, **kwargs)
if response.ok:
self.cache[request.url] = response

return response
24 changes: 24 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

import json
from io import BytesIO
from unittest.mock import MagicMock

import pytest
import requests

from hcloud import APIException, Client
from hcloud._client import CachedSession


class TestHetznerClient:
Expand Down Expand Up @@ -182,3 +184,25 @@ def test_request_limit_then_success(self, client, rate_limit_response):
"POST", "http://url.com", params={"argument": "value"}, timeout=2
)
assert client._requests_session.request.call_count == 2


class TestCachedSession:
def test_cache(self):
response = requests.Response()
response.status_code = 200
response.raw = BytesIO(json.dumps({"result": "data"}).encode("utf-8"))

adapter = MagicMock()
adapter.send.return_value = response

session = CachedSession()
session.get_adapter = MagicMock(return_value=adapter)

resp1 = session.request("GET", "https://url.com", params={"argument": "value"})
resp2 = session.request("GET", "https://url.com", params={"argument": "value"})

assert adapter.send.call_count == 1
assert resp1 is session.cache["https://url.com/?argument=value"]
assert resp2 is session.cache["https://url.com/?argument=value"]
assert resp1.json() == {"result": "data"}
assert resp2.json() == {"result": "data"}