Skip to content

Commit

Permalink
Add as_dict option to Algorithm.to_jwk (#881)
Browse files Browse the repository at this point in the history
* Add `as_dict` option to `Algorithm.to_jwt`

* Update unit tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixup! Add `as_dict` option to `Algorithm.to_jwt`

* fixup! Add `as_dict` option to `Algorithm.to_jwt`

* fixup! Update unit tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix type errors

* Fix tox test errors

* Fix typing for Python 3.7

* Add OKP jwk tests

* Add `pragma: no cover` to method overloads

* Add pragma: no cover to exclude lines

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
fluxth and pre-commit-ci[bot] authored May 9, 2023
1 parent 6a27341 commit c35e59b
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 59 deletions.
139 changes: 109 additions & 30 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import hashlib
import hmac
import json
import sys
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, cast
from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, Union, cast, overload

from .exceptions import InvalidKeyError
from .types import HashlibHash, JWKDict
Expand All @@ -20,6 +21,12 @@
to_base64url_uint,
)

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal


try:
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
Expand Down Expand Up @@ -184,9 +191,21 @@ def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
for the specified message and key values.
"""

@overload
@staticmethod
@abstractmethod
def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover

@overload
@staticmethod
@abstractmethod
def to_jwk(key_obj) -> str:
def to_jwk(key_obj, as_dict: Literal[False] = False) -> str:
... # pragma: no cover

@staticmethod
@abstractmethod
def to_jwk(key_obj, as_dict: bool = False) -> Union[JWKDict, str]:
"""
Serializes a given key into a JWK
"""
Expand Down Expand Up @@ -221,7 +240,7 @@ def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
return False

@staticmethod
def to_jwk(key_obj: Any) -> NoReturn:
def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
raise NotImplementedError()

@staticmethod
Expand Down Expand Up @@ -253,14 +272,27 @@ def prepare_key(self, key: str | bytes) -> bytes:

return key_bytes

@overload
@staticmethod
def to_jwk(key_obj: str | bytes) -> str:
return json.dumps(
{
"k": base64url_encode(force_bytes(key_obj)).decode(),
"kty": "oct",
}
)
def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover

@overload
@staticmethod
def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str:
... # pragma: no cover

@staticmethod
def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> Union[JWKDict, str]:
jwk = {
"k": base64url_encode(force_bytes(key_obj)).decode(),
"kty": "oct",
}

if as_dict:
return jwk
else:
return json.dumps(jwk)

@staticmethod
def from_jwk(jwk: str | JWKDict) -> bytes:
Expand Down Expand Up @@ -320,8 +352,20 @@ def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
except ValueError:
return cast(RSAPublicKey, load_pem_public_key(key_bytes))

@overload
@staticmethod
def to_jwk(key_obj: AllowedRSAKeys) -> str:
def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover

@overload
@staticmethod
def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str:
... # pragma: no cover

@staticmethod
def to_jwk(
key_obj: AllowedRSAKeys, as_dict: bool = False
) -> Union[JWKDict, str]:
obj: dict[str, Any] | None = None

if hasattr(key_obj, "private_numbers"):
Expand Down Expand Up @@ -354,7 +398,10 @@ def to_jwk(key_obj: AllowedRSAKeys) -> str:
else:
raise InvalidKeyError("Not a public or private key")

return json.dumps(obj)
if as_dict:
return obj
else:
return json.dumps(obj)

@staticmethod
def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
Expand Down Expand Up @@ -503,8 +550,20 @@ def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool:
except InvalidSignature:
return False

@overload
@staticmethod
def to_jwk(key_obj: AllowedECKeys) -> str:
def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover

@overload
@staticmethod
def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str:
... # pragma: no cover

@staticmethod
def to_jwk(
key_obj: AllowedECKeys, as_dict: bool = False
) -> Union[JWKDict, str]:
if isinstance(key_obj, EllipticCurvePrivateKey):
public_numbers = key_obj.public_key().public_numbers()
elif isinstance(key_obj, EllipticCurvePublicKey):
Expand Down Expand Up @@ -535,7 +594,10 @@ def to_jwk(key_obj: AllowedECKeys) -> str:
key_obj.private_numbers().private_value
).decode()

return json.dumps(obj)
if as_dict:
return obj
else:
return json.dumps(obj)

@staticmethod
def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
Expand Down Expand Up @@ -707,21 +769,35 @@ def verify(
except InvalidSignature:
return False

@overload
@staticmethod
def to_jwk(key: AllowedOKPKeys) -> str:
def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover

@overload
@staticmethod
def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str:
... # pragma: no cover

@staticmethod
def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> Union[JWKDict, str]:
if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
x = key.public_bytes(
encoding=Encoding.Raw,
format=PublicFormat.Raw,
)
crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
return json.dumps(
{
"x": base64url_encode(force_bytes(x)).decode(),
"kty": "OKP",
"crv": crv,
}
)

obj = {
"x": base64url_encode(force_bytes(x)).decode(),
"kty": "OKP",
"crv": crv,
}

if as_dict:
return obj
else:
return json.dumps(obj)

if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
d = key.private_bytes(
Expand All @@ -736,14 +812,17 @@ def to_jwk(key: AllowedOKPKeys) -> str:
)

crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
return json.dumps(
{
"x": base64url_encode(force_bytes(x)).decode(),
"d": base64url_encode(force_bytes(d)).decode(),
"kty": "OKP",
"crv": crv,
}
)
obj = {
"x": base64url_encode(force_bytes(x)).decode(),
"d": base64url_encode(force_bytes(d)).decode(),
"kty": "OKP",
"crv": crv,
}

if as_dict:
return obj
else:
return json.dumps(obj)

raise InvalidKeyError("Not a public or private key")

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ source = ["jwt", ".tox/*/site-packages"]

[tool.coverage.report]
show_missing = true
exclude_lines = ["if TYPE_CHECKING:"]
exclude_lines = ["if TYPE_CHECKING:", "pragma: no cover"]

[tool.isort]
profile = "black"
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ zip_safe = false
include_package_data = true
python_requires = >=3.7
packages = find:
install_requires =
typing_extensions; python_version<="3.7"

[options.package_data]
* = py.typed
Expand Down
Loading

0 comments on commit c35e59b

Please sign in to comment.