diff --git a/jwt/utils.py b/jwt/utils.py index 81c5ee41..2a172e1b 100644 --- a/jwt/utils.py +++ b/jwt/utils.py @@ -133,24 +133,26 @@ def is_pem_format(key: bytes) -> bool: # Based on https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b/src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46 _CERT_SUFFIX = b"-cert-v01@openssh.com" _SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)") -_SSH_KEY_FORMATS = [ +_SSH_KEY_FORMATS = ( b"ssh-ed25519", b"ssh-rsa", b"ssh-dss", b"ecdsa-sha2-nistp256", b"ecdsa-sha2-nistp384", b"ecdsa-sha2-nistp521", -] +) def is_ssh_key(key: bytes) -> bool: - if any(string_value in key for string_value in _SSH_KEY_FORMATS): + if key.startswith(_SSH_KEY_FORMATS): return True - ssh_pubkey_match = _SSH_PUBKEY_RC.match(key) - if ssh_pubkey_match: - key_type = ssh_pubkey_match.group(1) - if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]: - return True + # Avoid the regex is the _CERT_SUFFIX is not in the key + if _CERT_SUFFIX in key: + ssh_pubkey_match = _SSH_PUBKEY_RC.match(key) + if ssh_pubkey_match: + key_type = ssh_pubkey_match.group(1) + if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]: + return True return False diff --git a/tests/test_utils.py b/tests/test_utils.py index 122dcb4e..a83aff02 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,6 @@ import pytest -from jwt.utils import force_bytes, from_base64url_uint, to_base64url_uint +from jwt.utils import force_bytes, from_base64url_uint, is_ssh_key, to_base64url_uint @pytest.mark.parametrize( @@ -37,3 +37,9 @@ def test_from_base64url_uint(inputval, expected): def test_force_bytes_raises_error_on_invalid_object(): with pytest.raises(TypeError): force_bytes({}) # type: ignore[arg-type] + + +def test_is_ssh_key(): + assert is_ssh_key(b"ecdsa-sha2-nistp256 any") is True + assert is_ssh_key(b"not a ssh key") is False + assert is_ssh_key(b"any-cert-v01@openssh.com any") is True