From 4703f8780e532ef932f4a493b642ac5a6c1be53c Mon Sep 17 00:00:00 2001 From: Collin MacDonald Date: Mon, 10 Jun 2024 14:24:32 -0500 Subject: [PATCH] Handle load_pem_public_key ValueError (#952) * Handle load_pem_public_key ValueError * Add test for invalid key errors on prepare_key of an invalid key --------- Co-authored-by: MVRA <61065@icf.com> --- jwt/algorithms.py | 7 +++++-- tests/test_algorithms.py | 11 +++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/jwt/algorithms.py b/jwt/algorithms.py index c348a7dc..9be50b20 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -21,7 +21,7 @@ ) try: - from cryptography.exceptions import InvalidSignature + from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import padding @@ -343,7 +343,10 @@ def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys: RSAPrivateKey, load_pem_private_key(key_bytes, password=None) ) except ValueError: - return cast(RSAPublicKey, load_pem_public_key(key_bytes)) + try: + return cast(RSAPublicKey, load_pem_public_key(key_bytes)) + except (ValueError, UnsupportedAlgorithm): + raise InvalidKeyError("Could not parse the provided public key.") @overload @staticmethod diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 1a395527..337de96a 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -1100,3 +1100,14 @@ def test_hmac_can_compute_digest(self): algo = HMACAlgorithm(HMACAlgorithm.SHA256) computed_hash = algo.compute_hash_digest(b"foo") assert computed_hash == foo_hash + + @crypto_required + def test_rsa_prepare_key_raises_invalid_key_error_on_invalid_pem(self): + algo = RSAAlgorithm(RSAAlgorithm.SHA256) + invalid_key = "invalid key" + + with pytest.raises(InvalidKeyError) as excinfo: + algo.prepare_key(invalid_key) + + # Check that the exception message is correct + assert "Could not parse the provided public key." in str(excinfo.value)