diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 9be50b20..9fbdebb7 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -581,13 +581,20 @@ def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str: obj: dict[str, Any] = { "kty": "EC", "crv": crv, - "x": to_base64url_uint(public_numbers.x).decode(), - "y": to_base64url_uint(public_numbers.y).decode(), + "x": to_base64url_uint( + public_numbers.x, + bit_length=key_obj.curve.key_size, + ).decode(), + "y": to_base64url_uint( + public_numbers.y, + bit_length=key_obj.curve.key_size, + ).decode(), } if isinstance(key_obj, EllipticCurvePrivateKey): obj["d"] = to_base64url_uint( - key_obj.private_numbers().private_value + key_obj.private_numbers().private_value, + bit_length=key_obj.curve.key_size, ).decode() if as_dict: diff --git a/jwt/utils.py b/jwt/utils.py index d469139b..632f88b4 100644 --- a/jwt/utils.py +++ b/jwt/utils.py @@ -37,11 +37,11 @@ def base64url_encode(input: bytes) -> bytes: return base64.urlsafe_b64encode(input).replace(b"=", b"") -def to_base64url_uint(val: int) -> bytes: +def to_base64url_uint(val: int, *, bit_length: int | None = None) -> bytes: if val < 0: raise ValueError("Must be a positive integer") - int_bytes = bytes_from_int(val) + int_bytes = bytes_from_int(val, bit_length=bit_length) if len(int_bytes) == 0: int_bytes = b"\x00" @@ -63,13 +63,10 @@ def bytes_to_number(string: bytes) -> int: return int(binascii.b2a_hex(string), 16) -def bytes_from_int(val: int) -> bytes: - remaining = val - byte_length = 0 - - while remaining != 0: - remaining >>= 8 - byte_length += 1 +def bytes_from_int(val: int, *, bit_length: int | None = None) -> bytes: + if bit_length is None: + bit_length = val.bit_length() + byte_length = (bit_length + 7) // 8 return val.to_bytes(byte_length, "big", signed=False)