Skip to content

Commit

Permalink
Add PyJWT._{de,en}code_payload hooks (#829)
Browse files Browse the repository at this point in the history
* Add PyJWT._decode_payload hook

* Add PyJWT._encode_payload hook
  • Loading branch information
akx authored Dec 8, 2022
1 parent dd35ede commit 2fc6aa3
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 9 deletions.
49 changes: 40 additions & 9 deletions jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,32 @@ def encode(
if isinstance(payload.get(time_claim), datetime):
payload[time_claim] = timegm(payload[time_claim].utctimetuple())

json_payload = json.dumps(
payload, separators=(",", ":"), cls=json_encoder
).encode("utf-8")
json_payload = self._encode_payload(
payload,
headers=headers,
json_encoder=json_encoder,
)

return api_jws.encode(json_payload, key, algorithm, headers, json_encoder)

def _encode_payload(
self,
payload: Dict[str, Any],
headers: Optional[Dict[str, Any]] = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None,
) -> bytes:
"""
Encode a given payload to the bytes to be signed.
This method is intended to be overridden by subclasses that need to
encode the payload in a different way, e.g. compress the payload.
"""
return json.dumps(
payload,
separators=(",", ":"),
cls=json_encoder,
).encode("utf-8")

def decode_complete(
self,
jwt: str,
Expand Down Expand Up @@ -125,12 +145,7 @@ def decode_complete(
detached_payload=detached_payload,
)

try:
payload = json.loads(decoded["payload"])
except ValueError as e:
raise DecodeError(f"Invalid payload string: {e}")
if not isinstance(payload, dict):
raise DecodeError("Invalid payload string: must be a json object")
payload = self._decode_payload(decoded)

merged_options = {**self.options, **options}
self._validate_claims(
Expand All @@ -140,6 +155,22 @@ def decode_complete(
decoded["payload"] = payload
return decoded

def _decode_payload(self, decoded: Dict[str, Any]) -> Any:
"""
Decode the payload from a JWS dictionary (payload, signature, header).
This method is intended to be overridden by subclasses that need to
decode the payload in a different way, e.g. decompress compressed
payloads.
"""
try:
payload = json.loads(decoded["payload"])
except ValueError as e:
raise DecodeError(f"Invalid payload string: {e}")
if not isinstance(payload, dict):
raise DecodeError("Invalid payload string: must be a json object")
return payload

def decode(
self,
jwt: str,
Expand Down
37 changes: 37 additions & 0 deletions tests/test_compressed_jwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import json
import zlib

from jwt import PyJWT


class CompressedPyJWT(PyJWT):
def _decode_payload(self, decoded):
return json.loads(
# wbits=-15 has zlib not worry about headers of crc's
zlib.decompress(decoded["payload"], wbits=-15).decode("utf-8")
)


def test_decodes_complete_valid_jwt_with_compressed_payload():
# Test case from https://github.com/jpadilla/pyjwt/pull/753/files
example_payload = {"hello": "world"}
example_secret = "secret"
# payload made with the pako (https://nodeca.github.io/pako/) library in Javascript:
# Buffer.from(pako.deflateRaw('{"hello": "world"}')).toString('base64')
example_jwt = (
b"eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9"
b".q1bKSM3JyVeyUlAqzy/KSVGqBQA="
b".08wHYeuh1rJXmcBcMrz6NxmbxAnCQp2rGTKfRNIkxiw="
)
decoded = CompressedPyJWT().decode_complete(
example_jwt, example_secret, algorithms=["HS256"]
)

assert decoded == {
"header": {"alg": "HS256", "typ": "JWT"},
"payload": example_payload,
"signature": (
b"\xd3\xcc\x07a\xeb\xa1\xd6\xb2W\x99\xc0\\2\xbc\xfa7"
b"\x19\x9b\xc4\t\xc2B\x9d\xab\x192\x9fD\xd2$\xc6,"
),
}

0 comments on commit 2fc6aa3

Please sign in to comment.