Skip to content

Commit

Permalink
✨ Feature: add cache strategy option (#161)
Browse files Browse the repository at this point in the history
  • Loading branch information
yanyongyu authored Nov 15, 2024
1 parent 08d5034 commit 28b8b65
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 35 deletions.
7 changes: 7 additions & 0 deletions docs/usage/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ github = GitHub(
user_agent="GitHubKit/Python",
follow_redirects=True,
timeout=None,
cache_strategy=None,
http_cache=True,
auto_retry=True,
rest_api_validate_body=True,
Expand All @@ -24,13 +25,15 @@ Or, you can pass the config object directly (not recommended):
import httpx
from githubkit import GitHub, Config
from githubkit.retry import RETRY_DEFAULT
from githubkit.cache import DEFAULT_CACHE_STRATEGY

config = Config(
base_url="https://api.github.com/",
accept="application/vnd.github+json",
user_agent="GitHubKit/Python",
follow_redirects=True,
timeout=httpx.Timeout(None),
cache_strategy=DEFAULT_CACHE_STRATEGY,
http_cache=True,
auto_retry=RETRY_DEFAULT,
rest_api_validate_body=True,
Expand Down Expand Up @@ -65,6 +68,10 @@ The `follow_redirects` option is used to enable or disable the HTTP redirect fol

The `timeout` option is used to set the request timeout. You can pass a float, `None`, or an `httpx.Timeout` to this field. By default, requests never time out. See [Timeout](https://www.python-httpx.org/advanced/timeouts/) for more information.

### `cache_strategy`

The `cache_strategy` option defines how tokens and HTTP responses are cached. You can provide one of githubkit's built-in cache strategies or a custom one that implements the `BaseCacheStrategy` interface. By default, githubkit uses `MemCacheStrategy`, which caches data in memory.

### `http_cache`

The `http_cache` option enables the HTTP caching feature powered by [Hishel](https://hishel.com/) for HTTPX. The GitHub API limits the number of requests that you can make within a specific amount of time. This feature helps reduce the number of requests sent to the GitHub API and avoid hitting the rate limit.
Expand Down
35 changes: 16 additions & 19 deletions githubkit/auth/app.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Union, Optional
from typing_extensions import LiteralString
from datetime import datetime, timezone, timedelta
from collections.abc import Generator, AsyncGenerator
from typing import TYPE_CHECKING, Union, ClassVar, Optional

import httpx

from githubkit.exception import AuthCredentialError
from githubkit.cache import DEFAULT_CACHE, BaseCache
from githubkit.utils import UNSET, Unset, exclude_unset
from githubkit.compat import model_dump, type_validate_python

Expand Down Expand Up @@ -38,10 +38,9 @@ class AppAuth(httpx.Auth):
repositories: Union[Unset, list[str]] = UNSET
repository_ids: Union[Unset, list[int]] = UNSET
permissions: Union[Unset, "AppPermissionsType"] = UNSET
cache: "BaseCache" = DEFAULT_CACHE

JWT_CACHE_KEY = "githubkit:auth:app:{issuer}:jwt"
INSTALLATION_CACHE_KEY = (
JWT_CACHE_KEY: ClassVar[LiteralString] = "githubkit:auth:app:{issuer}:jwt"
INSTALLATION_CACHE_KEY: ClassVar[LiteralString] = (
"githubkit:auth:app:{issuer}:installation:"
"{installation_id}:{permissions}:{repositories}:{repository_ids}"
)
Expand Down Expand Up @@ -89,17 +88,19 @@ def _get_jwt_cache_key(self) -> str:
return self.JWT_CACHE_KEY.format(issuer=self.issuer)

def get_jwt(self) -> str:
cache = self.github.config.cache_strategy.get_cache_storage()
cache_key = self._get_jwt_cache_key()
if not (token := self.cache.get(cache_key)):
if not (token := cache.get(cache_key)):
token = self._create_jwt()
self.cache.set(cache_key, token, timedelta(minutes=8))
cache.set(cache_key, token, timedelta(minutes=8))
return token

async def aget_jwt(self) -> str:
cache = self.github.config.cache_strategy.get_async_cache_storage()
cache_key = self._get_jwt_cache_key()
if not (token := await self.cache.aget(cache_key)):
if not (token := await cache.aget(cache_key)):
token = self._create_jwt()
await self.cache.aset(cache_key, token, timedelta(minutes=8))
await cache.aset(cache_key, token, timedelta(minutes=8))
return token

def _build_installation_auth_request(self) -> httpx.Request:
Expand Down Expand Up @@ -202,8 +203,9 @@ def sync_auth_flow(
).sync_auth_flow(request)
return

cache = self.github.config.cache_strategy.get_cache_storage()
key = self._get_installation_cache_key()
if not (token := self.cache.get(key)):
if not (token := cache.get(key)):
token_request = self._build_installation_auth_request()
token_request.headers["Authorization"] = f"Bearer {self.get_jwt()}"
response = yield token_request
Expand All @@ -213,7 +215,7 @@ def sync_auth_flow(
expire = datetime.strptime(
response.parsed_data.expires_at, "%Y-%m-%dT%H:%M:%SZ"
).replace(tzinfo=timezone.utc) - datetime.now(timezone.utc)
self.cache.set(key, token, expire)
cache.set(key, token, expire)
request.headers["Authorization"] = f"token {token}"
yield request

Expand All @@ -239,8 +241,9 @@ async def async_auth_flow(
yield request
return

cache = self.github.config.cache_strategy.get_async_cache_storage()
key = self._get_installation_cache_key()
if not (token := await self.cache.aget(key)):
if not (token := await cache.aget(key)):
token_request = self._build_installation_auth_request()
token_request.headers["Authorization"] = f"Bearer {await self.aget_jwt()}"
response = yield token_request
Expand All @@ -250,7 +253,7 @@ async def async_auth_flow(
expire = datetime.strptime(
response.parsed_data.expires_at, "%Y-%m-%dT%H:%M:%SZ"
).replace(tzinfo=timezone.utc) - datetime.now(timezone.utc)
await self.cache.aset(key, token, expire)
await cache.aset(key, token, expire)
request.headers["Authorization"] = f"token {token}"
yield request

Expand All @@ -263,7 +266,6 @@ class AppAuthStrategy(BaseAuthStrategy):
private_key: str
client_id: Optional[str] = None
client_secret: Optional[str] = None
cache: "BaseCache" = DEFAULT_CACHE

def __post_init__(self):
# either app_id or client_id must be provided
Expand All @@ -288,7 +290,6 @@ def as_installation(
repositories,
repository_ids,
permissions,
self.cache,
)

def as_oauth_app(self) -> OAuthAppAuthStrategy:
Expand All @@ -305,7 +306,6 @@ def get_auth_flow(self, github: "GitHubCore") -> httpx.Auth:
self.private_key,
self.client_id,
self.client_secret,
cache=self.cache,
)


Expand All @@ -321,7 +321,6 @@ class AppInstallationAuthStrategy(BaseAuthStrategy):
repositories: Union[Unset, list[str]] = UNSET
repository_ids: Union[Unset, list[int]] = UNSET
permissions: Union[Unset, "AppPermissionsType"] = UNSET
cache: "BaseCache" = DEFAULT_CACHE

def __post_init__(self):
# either app_id or client_id must be provided
Expand All @@ -336,7 +335,6 @@ def as_app(self) -> AppAuthStrategy:
self.private_key,
self.client_id,
self.client_secret,
self.cache,
)

def get_auth_flow(self, github: "GitHubCore") -> httpx.Auth:
Expand All @@ -350,5 +348,4 @@ def get_auth_flow(self, github: "GitHubCore") -> httpx.Auth:
self.repositories,
self.repository_ids,
self.permissions,
cache=self.cache,
)
5 changes: 4 additions & 1 deletion githubkit/cache/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from .base import BaseCache as BaseCache
from .mem_cache import MemCache as MemCache
from .base import AsyncBaseCache as AsyncBaseCache
from .base import BaseCacheStrategy as BaseCacheStrategy
from .mem_cache import MemCacheStrategy as MemCacheStrategy

DEFAULT_CACHE = MemCache()
DEFAULT_CACHE_STRATEGY = MemCacheStrategy()
26 changes: 24 additions & 2 deletions githubkit/cache/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,42 @@
from typing import Optional
from datetime import timedelta

from hishel import BaseStorage, AsyncBaseStorage


class BaseCache(abc.ABC):
@abc.abstractmethod
def get(self, key: str) -> Optional[str]:
raise NotImplementedError

@abc.abstractmethod
async def aget(self, key: str) -> Optional[str]:
def set(self, key: str, value: str, ex: timedelta) -> None:
raise NotImplementedError


class AsyncBaseCache(abc.ABC):
@abc.abstractmethod
def set(self, key: str, value: str, ex: timedelta) -> None:
async def aget(self, key: str) -> Optional[str]:
raise NotImplementedError

@abc.abstractmethod
async def aset(self, key: str, value: str, ex: timedelta) -> None:
raise NotImplementedError


class BaseCacheStrategy(abc.ABC):
    """Interface for objects that supply the cache storages used by githubkit.

    A strategy exposes four storages: a sync and an async key/value cache
    (``BaseCache`` / ``AsyncBaseCache``) and a sync and an async Hishel
    storage (``BaseStorage`` / ``AsyncBaseStorage``). Concrete strategies
    must implement every accessor below.
    """

    @abc.abstractmethod
    def get_cache_storage(self) -> BaseCache:
        """Return the synchronous key/value cache."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_async_cache_storage(self) -> AsyncBaseCache:
        """Return the asynchronous key/value cache."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_hishel_storage(self) -> BaseStorage:
        """Return the synchronous Hishel storage."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_async_hishel_storage(self) -> AsyncBaseStorage:
        """Return the asynchronous Hishel storage."""
        raise NotImplementedError
31 changes: 29 additions & 2 deletions githubkit/cache/mem_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from dataclasses import dataclass
from datetime import datetime, timezone, timedelta

from .base import BaseCache
from hishel import InMemoryStorage, AsyncInMemoryStorage

from .base import BaseCache, AsyncBaseCache, BaseCacheStrategy


@dataclass(frozen=True)
Expand All @@ -11,7 +13,7 @@ class _Item:
expire_at: Optional[datetime] = None


class MemCache(BaseCache):
class MemCache(AsyncBaseCache, BaseCache):
"""Simple Memory Cache with Expiration Support"""

def __init__(self):
Expand All @@ -36,3 +38,28 @@ def set(self, key: str, value: str, ex: timedelta) -> None:

async def aset(self, key: str, value: str, ex: timedelta) -> None:
return self.set(key, value, ex)


class MemCacheStrategy(BaseCacheStrategy):
    """Cache strategy whose storages all live in process memory.

    Each storage is created on first request and the same instance is
    returned on every later call to the corresponding accessor.
    """

    def __init__(self) -> None:
        # Storages start out unset; the accessors below create them on demand.
        self._cache: Optional[MemCache] = None
        self._hishel_storage: Optional[InMemoryStorage] = None
        self._hishel_async_storage: Optional[AsyncInMemoryStorage] = None

    def get_cache_storage(self) -> MemCache:
        """Return the ``MemCache`` instance, creating it on first use."""
        cache = self._cache
        if cache is None:
            cache = self._cache = MemCache()
        return cache

    def get_async_cache_storage(self) -> MemCache:
        """Return the same ``MemCache`` that ``get_cache_storage`` returns."""
        return self.get_cache_storage()

    def get_hishel_storage(self) -> InMemoryStorage:
        """Return the ``InMemoryStorage`` instance, creating it on first use."""
        storage = self._hishel_storage
        if storage is None:
            storage = self._hishel_storage = InMemoryStorage()
        return storage

    def get_async_hishel_storage(self) -> AsyncInMemoryStorage:
        """Return the ``AsyncInMemoryStorage`` instance, creating it on first use."""
        storage = self._hishel_async_storage
        if storage is None:
            storage = self._hishel_async_storage = AsyncInMemoryStorage()
        return storage
11 changes: 11 additions & 0 deletions githubkit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from .retry import RETRY_DEFAULT
from .typing import RetryDecisionFunc
from .cache import DEFAULT_CACHE_STRATEGY, BaseCacheStrategy


@dataclass(frozen=True)
Expand All @@ -15,6 +16,7 @@ class Config:
user_agent: str
follow_redirects: bool
timeout: httpx.Timeout
cache_strategy: BaseCacheStrategy
http_cache: bool
auto_retry: Optional[RetryDecisionFunc]
rest_api_validate_body: bool
Expand Down Expand Up @@ -64,6 +66,12 @@ def build_timeout(
return timeout if isinstance(timeout, httpx.Timeout) else httpx.Timeout(timeout)


def build_cache_strategy(
    cache_strategy: Optional[BaseCacheStrategy],
) -> BaseCacheStrategy:
    """Resolve the cache strategy to store in the config.

    Args:
        cache_strategy: A caller-supplied strategy, or ``None`` to use the
            default.

    Returns:
        ``cache_strategy`` when one was supplied, otherwise the module-level
        ``DEFAULT_CACHE_STRATEGY``.
    """
    # Compare against None explicitly instead of relying on truthiness: a
    # custom strategy that defines ``__bool__`` or ``__len__`` may evaluate
    # as falsy and would otherwise be silently replaced by the default.
    if cache_strategy is None:
        return DEFAULT_CACHE_STRATEGY
    return cache_strategy


def build_auto_retry(
auto_retry: Union[bool, RetryDecisionFunc] = True,
) -> Optional[RetryDecisionFunc]:
Expand All @@ -76,12 +84,14 @@ def build_auto_retry(


def get_config(
*,
base_url: Optional[Union[str, httpx.URL]] = None,
accept_format: Optional[str] = None,
previews: Optional[list[str]] = None,
user_agent: Optional[str] = None,
follow_redirects: bool = True,
timeout: Optional[Union[float, httpx.Timeout]] = None,
cache_strategy: Optional[BaseCacheStrategy] = None,
http_cache: bool = True,
auto_retry: Union[bool, RetryDecisionFunc] = True,
rest_api_validate_body: bool = True,
Expand All @@ -92,6 +102,7 @@ def get_config(
build_user_agent(user_agent),
follow_redirects,
build_timeout(timeout),
build_cache_strategy(cache_strategy),
http_cache,
build_auto_retry(auto_retry),
rest_api_validate_body,
Expand Down
Loading

0 comments on commit 28b8b65

Please sign in to comment.