Skip to content

Commit

Permalink
Remove ecdsa dependency (#403)
Browse files Browse the repository at this point in the history
* refactored jwt module to remove python-jose

* updated tests for new jwt module

---------

Co-authored-by: Bryan Apellanes <[email protected]>
  • Loading branch information
agburch and bryanapellanes-okta authored May 23, 2024
1 parent 4227d17 commit 51e9c96
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 41 deletions.
61 changes: 33 additions & 28 deletions okta/jwt.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import json
from Cryptodome.PublicKey import RSA
from ast import literal_eval
import jose.jwk as jwk
import jose.jwt as jwt
import os
import time
import uuid
import os

from ast import literal_eval
from Cryptodome.PublicKey import RSA
from jwcrypto.jwk import JWK, InvalidJWKType
from jwt import encode as jwt_encode


class JWT():
Expand Down Expand Up @@ -63,32 +64,36 @@ def get_PEM_JWK(private_key):
# if string repr, convert to dict object
if isinstance(private_key, str):
private_key = literal_eval(private_key)
# Create JWK using dict obj
my_jwk = jwk.construct(private_key, JWT.HASH_ALGORITHM)
# remove whitespace from key vaules
private_key = {k: ''.join(private_key[k].split()) for k in private_key}
# ensure private_key is JSON formatted
try:
json.loads(private_key)
except TypeError:
private_key = json.dumps(private_key)
try:
my_jwk = JWK.from_json(private_key)
except InvalidJWKType:
raise ValueError(
"JWK given is of the wrong type")
else: # it's a PEM
# check for filepath or explicit private key
if isinstance(private_key, (str, bytes, os.PathLike)) and os.path.exists(private_key):
# open file if exists and import key
# open file if exists and read
pem_file = open(private_key, 'r')
my_pem = RSA.import_key(pem_file.read())
private_key = pem_file.read()
pem_file.close()
else:
# convert given string to bytes and import key
private_key_bytes = bytes(private_key, 'ascii')
my_pem = RSA.import_key(private_key_bytes)

if not my_pem:
# return error if import failed
return (None, ValueError(
"RSA Private Key given is of the wrong type"))

if my_jwk: # was JWK provided
# get PEM using JWK
pem_bytes = my_jwk.to_pem(JWT.PEM_FORMAT)
my_pem = RSA.import_key(pem_bytes)
else: # was pem provided
# get JWK using PEM
my_jwk = jwk.construct(my_pem.export_key(), JWT.HASH_ALGORITHM)
# remove leading whitespaces from each line
my_pem = '\n'.join([line.strip() for line in private_key.splitlines()])
my_pem = bytes(my_pem, 'ascii')
try:
my_jwk = JWK.from_pem(my_pem)
except ValueError:
raise ValueError(
"RSA Private Key given is of the wrong type")

my_pem = my_jwk.export_to_pem(private_key=True, password=None)
my_pem = RSA.import_key(my_pem)

return (my_pem, my_jwk)

Expand All @@ -108,7 +113,7 @@ def create_token(org_url, client_id, private_key, kid=None):
str: Generated JWT
"""
# Generate PEM and JWK
my_pem, my_jwk = JWT.get_PEM_JWK(private_key)
my_pem, _ = JWT.get_PEM_JWK(private_key)
# Get current time and expiry time for token
issued_time = int(time.time())
expiry_time = issued_time + JWT.ONE_HOUR
Expand Down Expand Up @@ -142,5 +147,5 @@ def create_token(org_url, client_id, private_key, kid=None):
if "kid" in headers:
del headers["kid"]

token = jwt.encode(claims, my_jwk.to_dict(), JWT.HASH_ALGORITHM, headers=headers)
token = jwt_encode(claims, my_pem.export_key(), JWT.HASH_ALGORITHM, headers)
return token
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ pyyaml
xmltodict
yarl
pycryptodomex
python-jose[cryptography]
jwcrypto
pyjwt
aenum
pydash
flake8
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def get_version():
"xmltodict",
"yarl",
"pycryptodomex",
"python-jose",
"jwcrypto",
"pyjwt",
"aenum==3.1.11",
"pydash"
]
Expand Down
3 changes: 3 additions & 0 deletions tests/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,6 @@ def mock_next_link(self_url: URL):
KLElmMvzocvFaWKvup_a3vPaBi6y4K5kBiq60o-IDMGQ''',
"kid": "5ashWt3LP1zkYwMGbfMsVizRfx52QTyky4GTHd9MykE"
}

SAMPLE_INVALID_JWK = {'foo':'bar'}
SAMPLE_INVALID_RSA = 'foobar'
16 changes: 8 additions & 8 deletions tests/unit/test_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@


def test_private_key_with_kid_in_private_key(mocker):
mocked_encode = mocker.patch('jose.jwt.encode')
mocked_encode = mocker.patch('okta.jwt.jwt_encode')
JWT.create_token("test.com", "test-client-id", mocks.SAMPLE_JWK_WITH_KID)
expected_kid = mocks.SAMPLE_JWK_WITH_KID["kid"]
_, kwargs = mocked_encode.call_args
args = mocked_encode.call_args.args
mocked_encode.assert_called_once()
assert "kid" in kwargs["headers"]
assert kwargs["headers"]["kid"] == expected_kid
assert "kid" in args[-1]
assert args[-1]["kid"] == expected_kid


def test_private_key_with_kid_in_config(mocker):
mocked_encode = mocker.patch('jose.jwt.encode')
mocked_encode = mocker.patch('okta.jwt.jwt_encode')
expected_kid = "test-kid"
JWT.create_token("test.com", "test-client-id", mocks.SAMPLE_JWK, kid=expected_kid)
_, kwargs = mocked_encode.call_args
args = mocked_encode.call_args.args
mocked_encode.assert_called_once()
assert "kid" in kwargs["headers"]
assert kwargs["headers"]["kid"] == expected_kid
assert "kid" in args[-1]
assert args[-1]["kid"] == expected_kid
13 changes: 10 additions & 3 deletions tests/unit/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_private_key_PEM_JWK_dict(jwk_input):
generated_pem, generated_jwk = JWT.get_PEM_JWK(jwk_input)

assert generated_pem is not None and generated_jwk is not None
assert not generated_jwk.is_public()
assert generated_jwk.has_private


def test_private_key_PEM_JWK_file(fs):
Expand All @@ -24,11 +24,18 @@ def test_private_key_PEM_JWK_file(fs):
generated_pem, generated_jwk = JWT.get_PEM_JWK(file_path)

assert generated_pem is not None and generated_jwk is not None
assert not generated_jwk.is_public()
assert generated_jwk.has_private


def test_private_key_PEM_JWK_explicit_string():
generated_pem, generated_jwk = JWT.get_PEM_JWK(mocks.SAMPLE_RSA)

assert generated_pem is not None and generated_jwk is not None
assert not generated_jwk.is_public()
assert generated_jwk.has_private


@pytest.mark.parametrize("private_key",
[mocks.SAMPLE_INVALID_JWK, str(mocks.SAMPLE_INVALID_JWK), mocks.SAMPLE_INVALID_RSA])
def test_invalid_private_key_PEM_JWK(private_key):
with pytest.raises(ValueError):
generated_pem, generated_jwk = JWT.get_PEM_JWK(private_key)

0 comments on commit 51e9c96

Please sign in to comment.