Skip to content

Commit

Permalink
Add type annotations to cache_handler.py (#1153)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackDyre authored Jan 18, 2025
1 parent 935e3c1 commit 3fe3eb8
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 67 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ Rebasing master onto v3 doesn't require a changelog update.
- featured_playlists
- category_playlists
- Added FAQ entry for inaccessible playlists
- Type annotations to `spotipy.cache_handler`

### Fixed

Expand Down
145 changes: 78 additions & 67 deletions spotipy/cache_handler.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,40 @@
__all__ = [
'CacheHandler',
'CacheFileHandler',
'DjangoSessionCacheHandler',
'FlaskSessionCacheHandler',
'MemoryCacheHandler',
'RedisCacheHandler',
'MemcacheCacheHandler']
from __future__ import annotations

import errno
import json
import logging
import os
from spotipy.util import CLIENT_CREDS_ENV_VARS

from redis import RedisError
from abc import ABC, abstractmethod
from json import JSONEncoder
from typing import TypedDict
import redis
from redis import RedisError
import redis.client

from .util import CLIENT_CREDS_ENV_VARS

__all__ = [
"CacheHandler",
"CacheFileHandler",
"DjangoSessionCacheHandler",
"FlaskSessionCacheHandler",
"MemoryCacheHandler",
"RedisCacheHandler",
"MemcacheCacheHandler",
]

logger = logging.getLogger(__name__)


class TokenInfo(TypedDict):
access_token: str
token_type: str
expires_in: int
scope: str
expires_at: int
refresh_token: str


class CacheHandler(ABC):
"""
An abstraction layer for handling the caching and retrieval of
Expand All @@ -30,49 +46,43 @@ class CacheHandler(ABC):
"""

@abstractmethod
def get_cached_token(self):
"""
Get and return a token_info dictionary object.
"""
def get_cached_token(self) -> TokenInfo | None:
"""Get and return a token_info dictionary object."""

@abstractmethod
def save_token_to_cache(self, token_info):
"""
Save a token_info dictionary object to the cache and return None.
"""
def save_token_to_cache(self, token_info: TokenInfo) -> None:
"""Save a token_info dictionary object to the cache and return None."""


class CacheFileHandler(CacheHandler):
"""
Handles reading and writing cached Spotify authorization tokens
as json files on disk.
"""

def __init__(self,
cache_path=None,
username=None,
encoder_cls=None):
"""Read and write cached Spotify authorization tokens as json files on disk."""

def __init__(
self,
cache_path: str | None = None,
username: str | None = None,
encoder_cls: type[JSONEncoder] | None = None,
) -> None:
"""
Parameters:
* cache_path: May be supplied, will otherwise be generated
(takes precedence over `username`)
* username: May be supplied or set as environment variable
(will set `cache_path` to `.cache-{username}`)
* encoder_cls: May be supplied as a means of overwriting the
default serializer used for writing tokens to disk
Initialize CacheFileHandler instance.
:param cache_path: (Optional) Path to cache. (Will override 'username')
:param username: (Optional) Client username. (Can also be supplied via env var.)
:param encoder_cls: (Optional) JSON encoder class to override default.
"""
self.encoder_cls = encoder_cls
if cache_path:
self.cache_path = cache_path
else:
cache_path = ".cache"
username = (username or os.getenv(CLIENT_CREDS_ENV_VARS["client_username"]))
username = username or os.getenv(CLIENT_CREDS_ENV_VARS["client_username"])
if username:
cache_path += f"-{username}"
self.cache_path = cache_path

def get_cached_token(self):
token_info = None
def get_cached_token(self) -> TokenInfo | None:
"""Get cached token from file."""
token_info: TokenInfo | None = None

try:
f = open(self.cache_path)
Expand All @@ -88,7 +98,8 @@ def get_cached_token(self):

return token_info

def save_token_to_cache(self, token_info):
def save_token_to_cache(self, token_info: TokenInfo) -> None:
"""Save token cache to file."""
try:
f = open(self.cache_path, "w")
f.write(json.dumps(token_info, cls=self.encoder_cls))
Expand All @@ -98,23 +109,22 @@ def save_token_to_cache(self, token_info):


class MemoryCacheHandler(CacheHandler):
"""
A cache handler that simply stores the token info in memory as an
instance attribute of this class. The token info will be lost when this
instance is freed.
"""
"""Cache handler that stores the token non-persistently as an instance attribute."""

def __init__(self, token_info=None):
def __init__(self, token_info: TokenInfo | None = None) -> None:
"""
Parameters:
* token_info: The token info to store in memory. Can be None.
Initialize MemoryCacheHandler instance.
:param token_info: Optional initial cached token
"""
self.token_info = token_info

def get_cached_token(self):
def get_cached_token(self) -> TokenInfo | None:
"""Retrieve the cached token from the instance."""
return self.token_info

def save_token_to_cache(self, token_info):
def save_token_to_cache(self, token_info: TokenInfo) -> None:
"""Cache the token in this instance."""
self.token_info = token_info


Expand All @@ -137,15 +147,15 @@ def __init__(self, request):
def get_cached_token(self):
token_info = None
try:
token_info = self.request.session['token_info']
token_info = self.request.session["token_info"]
except KeyError:
logger.debug("Token not found in the session")

return token_info

def save_token_to_cache(self, token_info):
try:
self.request.session['token_info'] = token_info
self.request.session["token_info"] = token_info
except Exception as e:
logger.warning(f"Error saving token to cache: {e}")

Expand Down Expand Up @@ -176,33 +186,32 @@ def save_token_to_cache(self, token_info):


class RedisCacheHandler(CacheHandler):
"""
A cache handler that stores the token info in the Redis.
"""
"""A cache handler that stores the token info in the Redis."""

def __init__(self, redis, key=None):
def __init__(self, redis_obj: redis.client.Redis, key: str | None = None) -> None:
"""
Parameters:
* redis: Redis object provided by redis-py library
(https://github.com/redis/redis-py)
* key: May be supplied, will otherwise be generated
(takes precedence over `token_info`)
Initialize RedisCacheHandler instance.
:param redis: The Redis object to function as the cache
:param key: (Optional) The key to used to store the token in the cache
"""
self.redis = redis
self.key = key if key else 'token_info'
self.redis = redis_obj
self.key = key or "token_info"

def get_cached_token(self):
def get_cached_token(self) -> TokenInfo | None:
"""Fetch cache token from the Redis."""
token_info = None
try:
token_info = self.redis.get(self.key)
if token_info:
return json.loads(token_info)
if token_info is not None:
token_info = json.loads(token_info)
except RedisError as e:
logger.warning(f"Error getting token from cache: {e}")

return token_info

def save_token_to_cache(self, token_info):
def save_token_to_cache(self, token_info: TokenInfo) -> None:
"""Cache token in the Redis."""
try:
self.redis.set(self.key, json.dumps(token_info))
except RedisError as e:
Expand All @@ -222,10 +231,11 @@ def __init__(self, memcache, key=None) -> None:
(takes precedence over `token_info`)
"""
self.memcache = memcache
self.key = key or 'token_info'
self.key = key or "token_info"

def get_cached_token(self):
from pymemcache import MemcacheError

try:
token_info = self.memcache.get(self.key)
if token_info:
Expand All @@ -235,6 +245,7 @@ def get_cached_token(self):

def save_token_to_cache(self, token_info):
from pymemcache import MemcacheError

try:
self.memcache.set(self.key, json.dumps(token_info))
except MemcacheError as e:
Expand Down

0 comments on commit 3fe3eb8

Please sign in to comment.