diff --git a/src/itsdangerous/serializer.py b/src/itsdangerous/serializer.py index 362bf79..8705dd0 100644 --- a/src/itsdangerous/serializer.py +++ b/src/itsdangerous/serializer.py @@ -10,13 +10,20 @@ from .signer import _make_keys_list from .signer import Signer +_TAnyStr = t.TypeVar("_TAnyStr", str, bytes, covariant=True) -def is_text_serializer(serializer: t.Any) -> bool: + +class _PDataSerializer(t.Protocol[_TAnyStr]): + def loads(self, payload: str | bytes) -> t.Any: ... + def dumps(self, obj: t.Any, **kwargs: t.Any) -> _TAnyStr: ... + + +def is_text_serializer(serializer: _PDataSerializer[t.Any]) -> bool: """Checks whether a serializer generates text or binary.""" return isinstance(serializer.dumps({}), str) -class Serializer: +class Serializer(t.Generic[_TAnyStr]): """A serializer wraps a :class:`~itsdangerous.signer.Signer` to enable serializing and securely signing data other than bytes. It can unsign to verify that the data hasn't been changed. @@ -71,7 +78,7 @@ class Serializer: #: The default serialization module to use to serialize data to a #: string internally. The default is :mod:`json`, but can be changed #: to any object that provides ``dumps`` and ``loads`` methods. - default_serializer: t.Any = json + default_serializer: _PDataSerializer[_TAnyStr] = json # type: ignore[assignment] #: The default ``Signer`` class to instantiate when signing data. #: The default is :class:`itsdangerous.signer.Signer`. @@ -82,11 +89,43 @@ class Serializer: dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer] ] = [] + # Tell type checkers that the default type is Serializer[str] if no + # data serializer is provided. + @t.overload + def __init__( + self: Serializer[str], + secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes], + salt: str | bytes | None = b"itsdangerous", + serializer: None = None, + serializer_kwargs: dict[str, t.Any] | None = None, + signer: type[Signer] | None = None, + signer_kwargs: dict[str, t.Any] | None = None, + fallback_signers: list[ + dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer] + ] + | None = None, + ): ... + + @t.overload + def __init__( + self: Serializer[_TAnyStr], + secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes], + salt: str | bytes | None = b"itsdangerous", + serializer: _PDataSerializer[_TAnyStr] = ..., + serializer_kwargs: dict[str, t.Any] | None = None, + signer: type[Signer] | None = None, + signer_kwargs: dict[str, t.Any] | None = None, + fallback_signers: list[ + dict[str, t.Any] | tuple[type[Signer], dict[str, t.Any]] | type[Signer] + ] + | None = None, + ): ... + def __init__( self, secret_key: str | bytes | cabc.Iterable[str] | cabc.Iterable[bytes], salt: str | bytes | None = b"itsdangerous", - serializer: t.Any = None, + serializer: _PDataSerializer[_TAnyStr] | None = None, serializer_kwargs: dict[str, t.Any] | None = None, signer: type[Signer] | None = None, signer_kwargs: dict[str, t.Any] | None = None, @@ -111,7 +150,7 @@ def __init__( if serializer is None: serializer = self.default_serializer - self.serializer: t.Any = serializer + self.serializer: _PDataSerializer[_TAnyStr] = serializer self.is_text_serializer: bool = is_text_serializer(serializer) if signer is None: @@ -135,7 +174,9 @@ def secret_key(self) -> bytes: """ return self.secret_keys[-1] - def load_payload(self, payload: bytes, serializer: t.Any | None = None) -> t.Any: + def load_payload( + self, payload: bytes, serializer: _PDataSerializer[_TAnyStr] | None = None + ) -> t.Any: """Loads the encoded object. This function raises :class:`.BadPayload` if the payload is not valid. The ``serializer`` parameter can be used to override the serializer @@ -199,7 +240,7 @@ def iter_unsigners(self, salt: str | bytes | None = None) -> cabc.Iterator[Signe for secret_key in self.secret_keys: yield fallback(secret_key, salt=salt, **kwargs) - def dumps(self, obj: t.Any, salt: str | bytes | None = None) -> str | bytes: + def dumps(self, obj: t.Any, salt: str | bytes | None = None) -> _TAnyStr: """Returns a signed string serialized with the internal serializer. The return value can be either a byte or unicode string depending on the format of the internal serializer. @@ -208,9 +249,9 @@ def dumps(self, obj: t.Any, salt: str | bytes | None = None) -> str | bytes: rv = self.make_signer(salt).sign(payload) if self.is_text_serializer: - return rv.decode("utf-8") + return rv.decode("utf-8") # type: ignore[return-value] - return rv + return rv # type: ignore[return-value] def dump(self, obj: t.Any, f: t.IO[t.Any], salt: str | bytes | None = None) -> None: """Like :meth:`dumps` but dumps into a file. The file handle has diff --git a/src/itsdangerous/timed.py b/src/itsdangerous/timed.py index 8188b97..51dac0e 100644 --- a/src/itsdangerous/timed.py +++ b/src/itsdangerous/timed.py @@ -17,6 +17,8 @@ from .serializer import Serializer from .signer import Signer +_TAnyStr = t.TypeVar("_TAnyStr", str, bytes, covariant=True) + class TimestampSigner(Signer): """Works like the regular :class:`.Signer` but also records the time @@ -166,7 +168,7 @@ def validate(self, signed_value: str | bytes, max_age: int | None = None) -> boo return False -class TimedSerializer(Serializer): +class TimedSerializer(Serializer[_TAnyStr]): """Uses :class:`TimestampSigner` instead of the default :class:`.Signer`. """ diff --git a/src/itsdangerous/url_safe.py b/src/itsdangerous/url_safe.py index e33b241..56a0793 100644 --- a/src/itsdangerous/url_safe.py +++ b/src/itsdangerous/url_safe.py @@ -7,17 +7,18 @@ from .encoding import base64_decode from .encoding import base64_encode from .exc import BadPayload +from .serializer import _PDataSerializer from .serializer import Serializer from .timed import TimedSerializer -class URLSafeSerializerMixin(Serializer): +class URLSafeSerializerMixin(Serializer[str]): """Mixed in with a regular serializer it will attempt to zlib compress the string to make it shorter if necessary. It will also base64 encode the string so that it can safely be placed in a URL. """ - default_serializer = _CompactJSON + default_serializer: _PDataSerializer[str] = _CompactJSON def load_payload( self, @@ -68,14 +69,14 @@ def dump_payload(self, obj: t.Any) -> bytes: return base64d -class URLSafeSerializer(URLSafeSerializerMixin, Serializer): +class URLSafeSerializer(URLSafeSerializerMixin, Serializer[str]): """Works like :class:`.Serializer` but dumps and loads into a URL safe string consisting of the upper and lowercase character of the alphabet as well as ``'_'``, ``'-'`` and ``'.'``. """ -class URLSafeTimedSerializer(URLSafeSerializerMixin, TimedSerializer): +class URLSafeTimedSerializer(URLSafeSerializerMixin, TimedSerializer[str]): """Works like :class:`.TimedSerializer` but dumps and loads into a URL safe string consisting of the upper and lowercase character of the alphabet as well as ``'_'``, ``'-'`` and ``'.'``.