diff --git a/README.md b/README.md index 296e48a..883542c 100644 --- a/README.md +++ b/README.md @@ -146,7 +146,9 @@ The payload of messages will be by default something like: { "iss": "OneIssuer", "exp": 1234567890, - "iat": 1234567789 + "iat": 1234567789, + "email": "user@domain.com", + "scope": "read write" } ``` @@ -157,14 +159,41 @@ function that will enrich the payload and set the location to it in the ```python # settings.py +# Define the function to be called when creating a new JWT JWT_PAYLOAD_ENRICHER = 'myapp.jwt_utils.payload_enricher' +# Ovewrite all of the toolkit's default JWT claims with those provided by the function +# Useful if you want to design your own token payload (Default: False which +# performs a `dict().update(payload_enricher(...))`) +JWT_PAYLOAD_ENRICHER_OVERWRITE = False +``` +```python # myproject/myapp/jwt_utils.py -def payload_enricher(request): +def payload_enricher(**kwargs): + # Keyword Args: request, token_content, token_obj, current_claims + + # The Django HTTPRequest object + request = kwargs.pop('request', None) + + # Dictionary of the content of the Oauth response. Includes values like + # access_token, expires_in, token_type, refresh_token, scope + content = kwargs.pop('token_content', None) + + # The oauth2_provider access token (by default: + # oauth2_provider.models.AccessToken) + token = kwargs.pop('token_obj', None) + + # The automatically generated claims. This usually includes your + # JWT_ID_ATTRIBUTE and scope. This can be useful if you want to use + # JWT_PAYLOAD_ENRICHER_OVERWRITE mode. + current_claims = kwargs.pop('current_claims', None) + + # Values returned here must be serializable by json.dumps return { - 'sub': 'mysubject', + 'sub': token.user.pk, + 'preferred_username': token.user.username, ... } ``` diff --git a/oauth2_provider_jwt/utils.py b/oauth2_provider_jwt/utils.py index fbd8c8c..f34c88b 100644 --- a/oauth2_provider_jwt/utils.py +++ b/oauth2_provider_jwt/utils.py @@ -32,16 +32,23 @@ def generate_payload(issuer, expires_in, **extra_data): return payload -def encode_jwt(payload, headers=None): - """ +def encode_jwt(payload, issuer=None, headers=None): + """Sign and encode the provided ``payload`` as ``issuer``. :type payload: dict + :type issuer: str, None :type headers: dict, None :rtype: str """ + if not issuer and 'iss' in payload: + issuer = payload['iss'] + elif not issuer and 'iss' not in payload: + raise ValueError( + 'Unable to determine issuer. Token missing iss claim') + # RS256 in default, because hardcoded legacy algorithm = getattr(settings, 'JWT_ENC_ALGORITHM', 'RS256') - private_key_name = 'JWT_PRIVATE_KEY_{}'.format(payload['iss'].upper()) + private_key_name = 'JWT_PRIVATE_KEY_{}'.format(issuer.upper()) private_key = getattr(settings, private_key_name, None) if not private_key: raise ImproperlyConfigured('Missing setting {}'.format( @@ -51,9 +58,10 @@ def encode_jwt(payload, headers=None): return encoded.decode("utf-8") -def decode_jwt(jwt_value): +def decode_jwt(jwt_value, issuer=None): """ :type jwt_value: str + :type issuer: str, None """ try: headers_enc, payload_enc, verify_signature = jwt_value.split(".") @@ -63,8 +71,14 @@ def decode_jwt(jwt_value): payload_enc += '=' * (-len(payload_enc) % 4) # add padding payload = json.loads(base64.b64decode(payload_enc).decode("utf-8")) + if not issuer and 'iss' in payload: + issuer = payload['iss'] + elif not issuer and 'iss' not in payload: + raise ValueError( + 'Unable to determine issuer. Token missing iss claim') + algorithms = getattr(settings, 'JWT_JWS_ALGORITHMS', ['HS256', 'RS256']) - public_key_name = 'JWT_PUBLIC_KEY_{}'.format(payload['iss'].upper()) + public_key_name = 'JWT_PUBLIC_KEY_{}'.format(issuer.upper()) public_key = getattr(settings, public_key_name, None) if not public_key: raise ImproperlyConfigured('Missing setting {}'.format( diff --git a/oauth2_provider_jwt/views.py b/oauth2_provider_jwt/views.py index 5178c39..e8ab753 100644 --- a/oauth2_provider_jwt/views.py +++ b/oauth2_provider_jwt/views.py @@ -45,26 +45,38 @@ class TokenView(views.TokenView): def _get_access_token_jwt(self, request, content): extra_data = {} issuer = settings.JWT_ISSUER - payload_enricher = getattr(settings, 'JWT_PAYLOAD_ENRICHER', None) - if payload_enricher: - fn = import_string(payload_enricher) - extra_data = fn(request) + + token = get_access_token_model().objects.get( + token=content['access_token'] + ) if 'scope' in content: extra_data['scope'] = content['scope'] id_attribute = getattr(settings, 'JWT_ID_ATTRIBUTE', None) if id_attribute: - token = get_access_token_model().objects.get( - token=content['access_token'] - ) id_value = getattr(token.user, id_attribute, None) if not id_value: raise MissingIdAttribute() extra_data[id_attribute] = str(id_value) payload = generate_payload(issuer, content['expires_in'], **extra_data) - token = encode_jwt(payload) + + payload_enricher = getattr(settings, 'JWT_PAYLOAD_ENRICHER', None) + if payload_enricher: + fn = import_string(payload_enricher) + enriched_data = fn( + request=request, + token_content=content, + token_obj=token, + current_claims=payload) + + if getattr(settings, 'JWT_PAYLOAD_ENRICHER_OVERWRITE', False): + payload = enriched_data + else: + payload.update(enriched_data) + + token = encode_jwt(payload, issuer) return token @staticmethod diff --git a/tests/test_utils.py b/tests/test_utils.py index 7024d53..e4a664f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -90,6 +90,20 @@ def test_encode_jwt_rs256(self): json.loads(base64.b64decode(payload).decode("utf-8")), payload_in) + def test_encode_jwt_explicit_issuer(self): + payload_in = self._get_payload() + payload_in['iss'] = 'different-issuer' + encoded = utils.encode_jwt(payload_in, 'issuer') + self.assertIn(type(encoded).__name__, ('unicode', 'str')) + headers, payload, verify_signature = encoded.split(".") + self.assertDictEqual( + json.loads(base64.b64decode(headers)), + {"typ": "JWT", "alg": "RS256"}) + payload += '=' * (-len(payload) % 4) # add padding + self.assertEqual( + json.loads(base64.b64decode(payload).decode("utf-8")), + payload_in) + @override_settings(JWT_PRIVATE_KEY_ISSUER='test') @override_settings(JWT_ENC_ALGORITHM='HS256') def test_encode_jwt_hs256(self): @@ -143,6 +157,13 @@ def test_decode_jwt_rs256(self): payload_out = utils.decode_jwt(jwt_value) self.assertDictEqual(payload, payload_out) + def test_decode_jwt_explicit_issuer(self): + payload = self._get_payload() + payload['iss'] = 'different-issuer' + jwt_value = utils.encode_jwt(payload, 'issuer') + payload_out = utils.decode_jwt(jwt_value, 'issuer') + self.assertDictEqual(payload, payload_out) + @override_settings(JWT_PRIVATE_KEY_ISSUER='test') @override_settings(JWT_PUBLIC_KEY_ISSUER='test') @override_settings(JWT_ENC_ALGORITHM='HS256') diff --git a/tests/test_views.py b/tests/test_views.py index 7ebf8ff..c391ea9 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -36,9 +36,22 @@ def get_basic_auth_header(user, password): return auth_headers -def payload_enricher(request): +def payload_enricher(request, token_content, token_obj, current_claims): + # Explicit parameters used here to validate call + from django.http import HttpRequest + from oauth2_provider.models import AccessToken + + assert isinstance(request, HttpRequest), \ + 'payload enrichment function expecting HttpResponse object' + assert 'access_token' in token_content, \ + 'payload enrichment function expecting oauth token data' + assert isinstance(token_obj, AccessToken), \ + 'payload enrichment function expecting oauth2 AccessToken model' + assert 'iss' in current_claims, \ + 'payload enrichment function expecting default current_claims' + return { - 'sub': 'unique-user', + 'sub': token_obj.user.pk, } @@ -110,14 +123,25 @@ def test_get_token(self): content = json.loads(response.content.decode("utf-8")) jwt_token = content["access_token_jwt"] + jwt = self.decode_jwt(jwt_token) + self.assertEqual(content["token_type"], "Bearer") self.assertIn(type(jwt_token).__name__, ('unicode', 'str')) self.assertEqual(content["scope"], "read write") self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) - self.assertTrue('scope' in self.decode_jwt(jwt_token)) - self.assertEqual(self.decode_jwt(jwt_token).get('scope'), - 'read write') + + # Validate no unexpected data was included + self.assertEqual(len(jwt), 5) + + self.assertTrue('iss' in jwt) + self.assertTrue('exp' in jwt) + self.assertTrue('iat' in jwt) + self.assertTrue('scope' in jwt) + self.assertTrue('username' in jwt) + self.assertEqual(jwt.get('iss'), settings.JWT_ISSUER) + self.assertEqual(jwt.get('username'), self.test_user.username) + self.assertEqual(jwt.get('scope'), 'read write') def test_get_token_authorization_code(self): """ @@ -262,6 +286,8 @@ def test_do_not_get_token_missing_conf(self, mock_is_jwt_config_set): @override_settings( JWT_PAYLOAD_ENRICHER='tests.test_views.payload_enricher') + @override_settings( + JWT_PAYLOAD_ENRICHER_OVERWRITE=False) def test_get_enriched_jwt(self): token_request_data = { "grant_type": "password", @@ -276,9 +302,42 @@ def test_get_enriched_jwt(self): **auth_headers) content = json.loads(response.content.decode("utf-8")) access_token_jwt = content["access_token_jwt"] - self.assertTrue('sub' in self.decode_jwt(access_token_jwt)) - self.assertEqual(self.decode_jwt(access_token_jwt).get('sub'), - 'unique-user') + jwt = self.decode_jwt(access_token_jwt) + + # Validate the token was enriched + self.assertTrue('sub' in jwt) + self.assertEqual(jwt.get('sub'), self.test_user.pk) + + # Validate the token was extended rather than overwritten + self.assertTrue('username' in jwt) + self.assertEqual(jwt.get('username'), self.test_user.username) + + @override_settings( + JWT_PAYLOAD_ENRICHER='tests.test_views.payload_enricher') + @override_settings( + JWT_PAYLOAD_ENRICHER_OVERWRITE=True) + def test_overwrite_enriched_jwt(self): + token_request_data = { + "grant_type": "password", + "username": "test_user", + "password": "123456", + } + auth_headers = get_basic_auth_header(self.application.client_id, + self.application.client_secret) + + response = self.client.post( + reverse("oauth2_provider_jwt:token"), data=token_request_data, + **auth_headers) + content = json.loads(response.content.decode("utf-8")) + access_token_jwt = content["access_token_jwt"] + jwt = self.decode_jwt(access_token_jwt) + + # Validate the token was enriched + self.assertTrue('sub' in jwt) + self.assertEqual(jwt.get('sub'), self.test_user.pk) + + # Validate the token was overwritten + self.assertTrue(len(jwt) == 1) def test_get_custom_scope_in_jwt(self): token_request_data = {