Skip to content

Commit

Permalink
Add coverage and improve performance of is_ssh_key
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Jan 20, 2024
1 parent f86b8b6 commit 38dd800
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
18 changes: 10 additions & 8 deletions jwt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,24 +133,26 @@ def is_pem_format(key: bytes) -> bool:
# Based on https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b/src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46
_CERT_SUFFIX = b"[email protected]"
_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
_SSH_KEY_FORMATS = [
_SSH_KEY_FORMATS = (
b"ssh-ed25519",
b"ssh-rsa",
b"ssh-dss",
b"ecdsa-sha2-nistp256",
b"ecdsa-sha2-nistp384",
b"ecdsa-sha2-nistp521",
]
)


def is_ssh_key(key: bytes) -> bool:
if any(string_value in key for string_value in _SSH_KEY_FORMATS):
if key.startswith(_SSH_KEY_FORMATS):
return True

ssh_pubkey_match = _SSH_PUBKEY_RC.match(key)
if ssh_pubkey_match:
key_type = ssh_pubkey_match.group(1)
if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]:
return True
# Avoid the regex is the _CERT_SUFFIX is not in the key
if _CERT_SUFFIX in key:
ssh_pubkey_match = _SSH_PUBKEY_RC.match(key)
if ssh_pubkey_match:
key_type = ssh_pubkey_match.group(1)
if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]:
return True

return False
8 changes: 7 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from jwt.utils import force_bytes, from_base64url_uint, to_base64url_uint
from jwt.utils import force_bytes, from_base64url_uint, is_ssh_key, to_base64url_uint


@pytest.mark.parametrize(
Expand Down Expand Up @@ -37,3 +37,9 @@ def test_from_base64url_uint(inputval, expected):
def test_force_bytes_raises_error_on_invalid_object():
with pytest.raises(TypeError):
force_bytes({}) # type: ignore[arg-type]


def test_is_ssh_key():
assert is_ssh_key(b"ecdsa-sha2-nistp256 any") is True
assert is_ssh_key(b"not a ssh key") is False
assert is_ssh_key(b"[email protected] any") is True

0 comments on commit 38dd800

Please sign in to comment.