Skip to content

Commit

Permalink
fix: ensure that LTI 1.3 launches work
Browse files Browse the repository at this point in the history
  • Loading branch information
alangsto committed Jan 15, 2025
1 parent 46cfc92 commit c1ce35b
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 16 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ Please See the `releases tab <https://github.com/openedx/xblock-lti-consumer/rel
Unreleased
~~~~~~~~~~

9.13.1 - 2025-01-15
-------------------
* Fix broken LTI 1.3 launch

9.13.0 - 2025-01-08
-------------------
* Removed pyjwkset package and replace with pyjwt package
Expand Down
2 changes: 1 addition & 1 deletion lti_consumer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from .apps import LTIConsumerApp
from .lti_xblock import LtiConsumerXBlock

__version__ = '9.13.0'
__version__ = '9.13.1'
15 changes: 5 additions & 10 deletions lti_consumer/lti_1p3/key_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self, public_key=None, keyset_url=None):
)
raise exceptions.InvalidRsaKey() from err

def _get_keyset(self, kid=None):
def _get_keyset(self):
"""
Get keyset from available sources.
Expand All @@ -86,13 +86,6 @@ def _get_keyset(self, kid=None):
raise exceptions.NoSuitableKeys() from err
keyset.extend(keys.keys)

if self.public_key and kid:
# Fill in key id of stored key.
# This is needed because if the JWS is signed with a
# key with a kid, pyjwkest doesn't match them with
# keys without kid (kid=None) and fails verification
self.public_key.kid = kid

if self.public_key:
# Add to keyset
keyset.append(self.public_key)
Expand Down Expand Up @@ -185,7 +178,7 @@ def encode_and_sign(self, message, expiration=None):

# The class instance that sets up the signing operation
# An RS 256 key is required for LTI 1.3
return jwt.encode(_message, self.key.key, algorithm="RS256")
return jwt.encode(_message, self.key.key, algorithm="RS256", headers={"kid": self.key.key_id})

def get_public_jwk(self):
"""
Expand All @@ -197,7 +190,9 @@ def get_public_jwk(self):
if self.key:
algo_obj = jwt.get_algorithm_by_name('RS256')
public_key = algo_obj.prepare_key(self.key.key).public_key()
jwk['keys'].append(json.loads(algo_obj.to_jwk(public_key)))
public_jwk = json.loads(algo_obj.to_jwk(public_key))
public_jwk['kid'] = self.key.key_id
jwk['keys'].append(public_jwk)
return jwk

def validate_and_decode(self, token, iss=None, aud=None, exp=True):
Expand Down
1 change: 1 addition & 0 deletions lti_consumer/lti_1p3/tests/test_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def _decode_token(self, token):
keyset = PyJWKSet.from_dict(public_keyset).keys

for obj in keyset:
self.assertEqual(obj.key_id, RSA_KEY_ID)
message = jwt.decode(
token,
key=obj.key,
Expand Down
18 changes: 14 additions & 4 deletions lti_consumer/lti_1p3/tests/test_key_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import ddt
import jwt
from Cryptodome.PublicKey import RSA
from cryptography.hazmat.primitives import serialization
from django.test.testcases import TestCase
from jwt.api_jwk import PyJWK

Expand Down Expand Up @@ -58,6 +59,11 @@ def test_encode_and_sign(self):
message
)

self.assertEqual(
jwt.get_unverified_header(signed_token)['kid'],
self.rsa_key_id
)

# pylint: disable=unused-argument
@patch('time.time', return_value=1000)
def test_encode_and_sign_with_exp(self, mock_time):
Expand Down Expand Up @@ -233,16 +239,20 @@ def test_get_empty_keyset(self):

def test_get_keyset_with_pub_key(self):
"""
Check that getting a keyset from a RSA key.
Check that if there is a public key, it is returned in the keyset.
"""
self._setup_key_handler()

# pylint: disable=protected-access
keyset = self.key_handler._get_keyset(kid=self.rsa_key_id)
keyset = self.key_handler._get_keyset()
self.assertEqual(len(keyset), 1)
public_key = keyset[0].key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)[:-1]
self.assertEqual(
keyset[0].kid,
self.rsa_key_id
public_key,
self.public_key
)

def test_validate_and_decode(self):
Expand Down
1 change: 0 additions & 1 deletion lti_consumer/plugin/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,6 @@ def access_token_endpoint(
except Exception: # pylint: disable=broad-except
exc_info = sys.exc_info()

# import pdb; pdb.set_trace()
# Handle errors and return a proper response
if exc_info[0] == MissingRequiredClaim:
# Missing request attributes
Expand Down

0 comments on commit c1ce35b

Please sign in to comment.