diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 708269ab..d591f332 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -16,6 +16,8 @@ Changed Fixed ~~~~~ +- Encode EC keys with a fixed bit length by @etianen in `#990 `__ + Added ~~~~~ diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 9be50b20..9fbdebb7 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -581,13 +581,20 @@ def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str: obj: dict[str, Any] = { "kty": "EC", "crv": crv, - "x": to_base64url_uint(public_numbers.x).decode(), - "y": to_base64url_uint(public_numbers.y).decode(), + "x": to_base64url_uint( + public_numbers.x, + bit_length=key_obj.curve.key_size, + ).decode(), + "y": to_base64url_uint( + public_numbers.y, + bit_length=key_obj.curve.key_size, + ).decode(), } if isinstance(key_obj, EllipticCurvePrivateKey): obj["d"] = to_base64url_uint( - key_obj.private_numbers().private_value + key_obj.private_numbers().private_value, + bit_length=key_obj.curve.key_size, ).decode() if as_dict: diff --git a/jwt/utils.py b/jwt/utils.py index d469139b..56e89bb7 100644 --- a/jwt/utils.py +++ b/jwt/utils.py @@ -1,7 +1,7 @@ import base64 import binascii import re -from typing import Union +from typing import Optional, Union try: from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve @@ -37,11 +37,11 @@ def base64url_encode(input: bytes) -> bytes: return base64.urlsafe_b64encode(input).replace(b"=", b"") -def to_base64url_uint(val: int) -> bytes: +def to_base64url_uint(val: int, *, bit_length: Optional[int] = None) -> bytes: if val < 0: raise ValueError("Must be a positive integer") - int_bytes = bytes_from_int(val) + int_bytes = bytes_from_int(val, bit_length=bit_length) if len(int_bytes) == 0: int_bytes = b"\x00" @@ -63,13 +63,10 @@ def bytes_to_number(string: bytes) -> int: return int(binascii.b2a_hex(string), 16) -def bytes_from_int(val: int) -> bytes: - remaining = val - byte_length = 0 - - while remaining != 0: - remaining >>= 8 - byte_length += 1 +def bytes_from_int(val: int, *, bit_length: Optional[int] = None) -> bytes: + if bit_length is None: + bit_length = val.bit_length() + byte_length = (bit_length + 7) // 8 return val.to_bytes(byte_length, "big", signed=False)