diff --git a/CHANGES.rst b/CHANGES.rst index 1f084c7c..5c67da11 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,11 @@ +Version 0.5.0 +------------- + +Unreleased + +- Cache types now have configurable serializers. :pr:`63` + + Version 0.4.1 ------------- diff --git a/setup.cfg b/setup.cfg index c2ff47d1..e9792477 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,6 +38,7 @@ testpaths = tests filterwarnings = error default::DeprecationWarning:cachelib.uwsgi + default::DeprecationWarning:cachelib.redis [coverage:run] branch = True diff --git a/src/cachelib/file.py b/src/cachelib/file.py index ef17fea4..bcaae405 100644 --- a/src/cachelib/file.py +++ b/src/cachelib/file.py @@ -1,7 +1,6 @@ import errno import logging import os -import pickle import tempfile import typing as _t from hashlib import md5 @@ -9,6 +8,7 @@ from time import time from cachelib.base import BaseCache +from cachelib.serializers import FileSystemSerializer class FileSystemCache(BaseCache): @@ -32,6 +32,8 @@ class FileSystemCache(BaseCache): #: keep amount of files in a cache element _fs_count_file = "__wz_cache_count" + serializer = FileSystemSerializer() + def __init__( self, cache_dir: str, @@ -96,7 +98,8 @@ def _remove_expired(self, now: float) -> None: for fname in self._list_dir(): try: with open(fname, "rb") as f: - expires = pickle.load(f) + expires = self.serializer.load(f) + print(expires) if expires != 0 and expires < now: os.remove(fname) self._update_count(delta=-1) @@ -114,7 +117,7 @@ def _remove_older(self) -> bool: for fname in self._list_dir(): try: with open(fname, "rb") as f: - exp_fname_tuples.append((pickle.load(f), fname)) + exp_fname_tuples.append((self.serializer.load(f), fname)) except FileNotFoundError: pass except (OSError, EOFError): @@ -181,12 +184,12 @@ def get(self, key: str) -> _t.Any: filename = self._get_filename(key) try: with open(filename, "rb") as f: - pickle_time = pickle.load(f) + pickle_time = self.serializer.load(f) if pickle_time == 0 or pickle_time >= time(): - return pickle.load(f) + return self.serializer.load(f) except FileNotFoundError: pass - except (OSError, EOFError, pickle.PickleError): + except (OSError, EOFError): logging.warning( "Exception raised while handling cache file '%s'", filename, @@ -223,8 +226,8 @@ def set( suffix=self._fs_transaction_suffix, dir=self._path ) with os.fdopen(fd, "wb") as f: - pickle.dump(timeout, f, 1) - pickle.dump(value, f, pickle.HIGHEST_PROTOCOL) + self.serializer.dump(timeout, f) # this returns bool + self.serializer.dump(value, f) os.replace(tmp, filename) os.chmod(filename, self._mode) fsize = Path(filename).stat().st_size @@ -259,14 +262,14 @@ def has(self, key: str) -> bool: filename = self._get_filename(key) try: with open(filename, "rb") as f: - pickle_time = pickle.load(f) + pickle_time = self.serializer.load(f) if pickle_time == 0 or pickle_time >= time(): return True else: return False except FileNotFoundError: # if there is no file there is no key return False - except (OSError, EOFError, pickle.PickleError): + except (OSError, EOFError): logging.warning( "Exception raised while handling cache file '%s'", filename, diff --git a/src/cachelib/redis.py b/src/cachelib/redis.py index 41afbf3e..ddc641f4 100644 --- a/src/cachelib/redis.py +++ b/src/cachelib/redis.py @@ -1,7 +1,8 @@ -import pickle import typing as _t +import warnings from cachelib.base import BaseCache +from cachelib.serializers import RedisSerializer class RedisCache(BaseCache): @@ -26,6 +27,8 @@ class RedisCache(BaseCache): Any additional keyword arguments will be passed to ``redis.Redis``. """ + serializer = RedisSerializer() + def __init__( self, host: _t.Any = "localhost", @@ -60,29 +63,22 @@ def _normalize_timeout(self, timeout: _t.Optional[int]) -> int: return timeout def dump_object(self, value: _t.Any) -> bytes: - """Dumps an object into a string for redis. By default it serializes - integers as regular string and pickle dumps everything else. - """ - if isinstance(type(value), int): - return str(value).encode("ascii") - return b"!" + pickle.dumps(value) - - def load_object(self, value: _t.Optional[bytes]) -> _t.Any: - """The reversal of :meth:`dump_object`. This might be called with - None. - """ - if value is None: - return None - if value.startswith(b"!"): - try: - return pickle.loads(value[1:]) - except pickle.PickleError: - return None - try: - return int(value) - except ValueError: - # before 0.8 we did not have serialization. Still support that. - return value + warnings.warn( + "'dump_object' is deprecated and will be removed in the future." + "This is a proxy call to 'RedisCache.serializer.dumps'", + DeprecationWarning, + stacklevel=2, + ) + return self.serializer.dumps(value) + + def load_object(self, value: _t.Any) -> _t.Any: + warnings.warn( + "'load_object' is deprecated and will be removed in the future." + "This is a proxy call to 'RedisCache.serializer.loads'", + DeprecationWarning, + stacklevel=2, + ) + return self.serializer.loads(value) def get(self, key: str) -> _t.Any: return self.load_object(self._client.get(self.key_prefix + key)) diff --git a/src/cachelib/serializers.py b/src/cachelib/serializers.py new file mode 100644 index 00000000..b3e0d55b --- /dev/null +++ b/src/cachelib/serializers.py @@ -0,0 +1,105 @@ +import logging +import pickle +import typing as _t + + +class BaseSerializer: + """This is the base interface for all default serializers. + + BaseSerializer.load and BaseSerializer.dump will + default to pickle.load and pickle.dump. This is currently + used only by FileSystemCache which dumps/loads to/from a file stream. + """ + + def _warn(self, e: pickle.PickleError) -> None: + logging.warning( + f"An exception has been raised during a pickling operation: {e}" + ) + + def dump( + self, value: int, f: _t.IO, protocol: int = pickle.HIGHEST_PROTOCOL + ) -> None: + try: + pickle.dump(value, f, protocol) + except (pickle.PickleError, pickle.PicklingError) as e: + self._warn(e) + + def load(self, f: _t.BinaryIO) -> _t.Any: + try: + data = pickle.load(f) + except pickle.PickleError as e: + self._warn(e) + return None + else: + return data + + """BaseSerializer.loads and BaseSerializer.dumps + work on top of pickle.loads and pickle.dumps. Dumping/loading + strings and byte strings is the default for most cache types. + """ + + def dumps(self, value: _t.Any, protocol: int = pickle.HIGHEST_PROTOCOL) -> bytes: + try: + serialized = pickle.dumps(value, protocol) + except (pickle.PickleError, pickle.PicklingError) as e: + self._warn(e) + return serialized + + def loads(self, bvalue: bytes) -> _t.Any: + try: + data = pickle.loads(bvalue) + except pickle.PickleError as e: + self._warn(e) + return None + else: + return data + + +"""Default serializers for each cache type. + +The following classes can be used to further customize +serialiation behaviour. Alternatively, any serializer can be +overriden in order to use a custom serializer with a different +strategy altogether. +""" + + +class UWSGISerializer(BaseSerializer): + """Default serializer for UWSGICache.""" + + +class SimpleSerializer(BaseSerializer): + """Default serializer for SimpleCache.""" + + +class FileSystemSerializer(BaseSerializer): + """Default serializer for FileSystemCache.""" + + +class RedisSerializer(BaseSerializer): + """Default serializer for RedisCache.""" + + def dumps(self, value: _t.Any, protocol: int = pickle.HIGHEST_PROTOCOL) -> bytes: + """Dumps an object into a string for redis. By default it serializes + integers as regular string and pickle dumps everything else. + """ + if isinstance(type(value), int): + return str(value).encode("ascii") + return b"!" + pickle.dumps(value, protocol) + + def loads(self, value: _t.Optional[bytes]) -> _t.Any: + """The reversal of :meth:`dump_object`. This might be called with + None. + """ + if value is None: + return None + if value.startswith(b"!"): + try: + return pickle.loads(value[1:]) + except pickle.PickleError: + return None + try: + return int(value) + except ValueError: + # before 0.8 we did not have serialization. Still support that. + return value diff --git a/src/cachelib/simple.py b/src/cachelib/simple.py index f3d459c1..14302cfc 100644 --- a/src/cachelib/simple.py +++ b/src/cachelib/simple.py @@ -1,8 +1,8 @@ -import pickle import typing as _t from time import time from cachelib.base import BaseCache +from cachelib.serializers import SimpleSerializer class SimpleCache(BaseCache): @@ -19,7 +19,13 @@ class SimpleCache(BaseCache): 0 indicates that the cache never expires. """ - def __init__(self, threshold: int = 500, default_timeout: int = 300): + serializer = SimpleSerializer() + + def __init__( + self, + threshold: int = 500, + default_timeout: int = 300, + ): BaseCache.__init__(self, default_timeout) self._cache: _t.Dict[str, _t.Any] = {} self._threshold = threshold or 500 # threshold = 0 @@ -62,8 +68,8 @@ def get(self, key: str) -> _t.Any: try: expires, value = self._cache[key] if expires == 0 or expires > time(): - return pickle.loads(value) - except (KeyError, pickle.PickleError): + return self.serializer.loads(value) + except KeyError: return None def set( @@ -71,13 +77,13 @@ def set( ) -> _t.Optional[bool]: expires = self._normalize_timeout(timeout) self._prune() - self._cache[key] = (expires, pickle.dumps(value, pickle.HIGHEST_PROTOCOL)) + self._cache[key] = (expires, self.serializer.dumps(value)) return True def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> bool: expires = self._normalize_timeout(timeout) self._prune() - item = (expires, pickle.dumps(value, pickle.HIGHEST_PROTOCOL)) + item = (expires, self.serializer.dumps(value)) if key in self._cache: return False self._cache.setdefault(key, item) diff --git a/src/cachelib/uwsgi.py b/src/cachelib/uwsgi.py index beab19b2..9d2600c9 100644 --- a/src/cachelib/uwsgi.py +++ b/src/cachelib/uwsgi.py @@ -1,8 +1,8 @@ -import pickle import platform import typing as _t from cachelib.base import BaseCache +from cachelib.serializers import UWSGISerializer class UWSGICache(BaseCache): @@ -20,7 +20,13 @@ class UWSGICache(BaseCache): the cache. """ - def __init__(self, default_timeout: int = 300, cache: str = ""): + serializer = UWSGISerializer() + + def __init__( + self, + default_timeout: int = 300, + cache: str = "", + ): BaseCache.__init__(self, default_timeout) if platform.python_implementation() == "PyPy": @@ -44,7 +50,7 @@ def get(self, key: str) -> _t.Any: rv = self._uwsgi.cache_get(key, self.cache) if rv is None: return - return pickle.loads(rv) + return self.serializer.loads(rv) def delete(self, key: str) -> bool: return bool(self._uwsgi.cache_del(key, self.cache)) @@ -53,14 +59,20 @@ def set( self, key: str, value: _t.Any, timeout: _t.Optional[int] = None ) -> _t.Optional[bool]: result = self._uwsgi.cache_update( - key, pickle.dumps(value), self._normalize_timeout(timeout), self.cache + key, + self.serializer.dumps(value), + self._normalize_timeout(timeout), + self.cache, ) # type: bool return result def add(self, key: str, value: _t.Any, timeout: _t.Optional[int] = None) -> bool: return bool( self._uwsgi.cache_set( - key, pickle.dumps(value), self._normalize_timeout(timeout), self.cache + key, + self.serializer.dumps(value), + self._normalize_timeout(timeout), + self.cache, ) ) diff --git a/tests/test_file_system_cache.py b/tests/test_file_system_cache.py index c05fa68f..a1993aaa 100644 --- a/tests/test_file_system_cache.py +++ b/tests/test_file_system_cache.py @@ -1,3 +1,4 @@ +import os from time import sleep import pytest @@ -8,10 +9,38 @@ from cachelib import FileSystemCache -@pytest.fixture(autouse=True) +class SillySerializer: + """A pointless serializer only for testing""" + + def dump(self, value, fs): + fs.write(f"{repr(value)}{os.linesep}".encode()) + + def load(self, fs): + try: + loaded = eval(fs.readline().decode()) + # When all file content has been read eval will + # turn the EOFError into SyntaxError wich is not + # handled by cachelib + except SyntaxError as e: + raise EOFError from e + return loaded + + +class CustomCache(FileSystemCache): + """Our custom cache client with non-default serializer""" + + # overwrite serializer + serializer = SillySerializer() + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +@pytest.fixture(autouse=True, params=[FileSystemCache, CustomCache]) def cache_factory(request, tmpdir): def _factory(self, *args, **kwargs): - return FileSystemCache(tmpdir, *args, **kwargs) + client = request.param(tmpdir, *args, **kwargs) + return client request.cls.cache_factory = _factory diff --git a/tests/test_redis_cache.py b/tests/test_redis_cache.py index f33978a0..25c3f94e 100644 --- a/tests/test_redis_cache.py +++ b/tests/test_redis_cache.py @@ -6,10 +6,32 @@ from cachelib import RedisCache -@pytest.fixture(autouse=True) +class SillySerializer: + """A pointless serializer only for testing""" + + def dumps(self, value): + return repr(value).encode() + + def loads(self, bvalue): + if bvalue is None: + return None + return eval(bvalue.decode()) + + +class CustomCache(RedisCache): + """Our custom cache client with non-default serializer""" + + # overwrite serializer + serializer = SillySerializer() + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +@pytest.fixture(autouse=True, params=[RedisCache, CustomCache]) def cache_factory(request): def _factory(self, *args, **kwargs): - rc = RedisCache(*args, port=6360, **kwargs) + rc = request.param(*args, port=6360, **kwargs) rc._client.flushdb() return rc diff --git a/tests/test_simple_cache.py b/tests/test_simple_cache.py index 1f6de3d8..4f4de68c 100644 --- a/tests/test_simple_cache.py +++ b/tests/test_simple_cache.py @@ -8,10 +8,30 @@ from cachelib import SimpleCache -@pytest.fixture(autouse=True) +class SillySerializer: + """A pointless serializer only for testing""" + + def dumps(self, value): + return repr(value).encode() + + def loads(self, bvalue): + return eval(bvalue.decode()) + + +class CustomCache(SimpleCache): + """Our custom cache client with non-default serializer""" + + # overwrite serializer + serializer = SillySerializer() + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +@pytest.fixture(autouse=True, params=[SimpleCache, CustomCache]) def cache_factory(request): def _factory(self, *args, **kwargs): - return SimpleCache(*args, **kwargs) + return request.param(*args, **kwargs) request.cls.cache_factory = _factory diff --git a/tests/test_uwsgi_cache.py b/tests/test_uwsgi_cache.py index da8c7dbc..c82a09a7 100644 --- a/tests/test_uwsgi_cache.py +++ b/tests/test_uwsgi_cache.py @@ -6,10 +6,30 @@ from cachelib import UWSGICache -@pytest.fixture(autouse=True) +class SillySerializer: + """A pointless serializer only for testing""" + + def dumps(self, value): + return repr(value).encode() + + def loads(self, bvalue): + return eval(bvalue.decode()) + + +class CustomCache(UWSGICache): + """Our custom cache client with non-default serializer""" + + # overwrite serializer + serializer = SillySerializer() + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +@pytest.fixture(autouse=True, params=[UWSGICache, CustomCache]) def cache_factory(request): def _factory(self, *args, **kwargs): - uwc = UWSGICache(*args, **kwargs) + uwc = request.param(*args, **kwargs) uwc.clear() return uwc