Skip to content
This repository has been archived by the owner on Dec 6, 2024. It is now read-only.

Add additional parameters to jwt enrichment call; Allow complete overwriting of token #30

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ The payload of messages will be by default something like:
{
"iss": "OneIssuer",
"exp": 1234567890,
"iat": 1234567789
"iat": 1234567789,
"email": "[email protected]",
"scope": "read write"
}
```

Expand All @@ -157,14 +159,41 @@ function that will enrich the payload and set the location to it in the
```python
# settings.py

# Define the function to be called when creating a new JWT
JWT_PAYLOAD_ENRICHER = 'myapp.jwt_utils.payload_enricher'

# Ovewrite all of the toolkit's default JWT claims with those provided by the function
# Useful if you want to design your own token payload (Default: False which
# performs a `dict().update(payload_enricher(...))`)
JWT_PAYLOAD_ENRICHER_OVERWRITE = False
```

```python
# myproject/myapp/jwt_utils.py

def payload_enricher(request):
def payload_enricher(**kwargs):
# Keyword Args: request, token_content, token_obj, current_claims

# The Django HTTPRequest object
request = kwargs.pop('request', None)

# Dictionary of the content of the Oauth response. Includes values like
# access_token, expires_in, token_type, refresh_token, scope
content = kwargs.pop('token_content', None)

# The oauth2_provider access token (by default:
# oauth2_provider.models.AccessToken)
token = kwargs.pop('token_obj', None)

# The automatically generated claims. This usually includes your
# JWT_ID_ATTRIBUTE and scope. This can be useful if you want to use
# JWT_PAYLOAD_ENRICHER_OVERWRITE mode.
current_claims = kwargs.pop('current_claims', None)

# Values returned here must be serializable by json.dumps
return {
'sub': 'mysubject',
'sub': token.user.pk,
'preferred_username': token.user.username,
...
}
```
Expand Down
24 changes: 19 additions & 5 deletions oauth2_provider_jwt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,23 @@ def generate_payload(issuer, expires_in, **extra_data):
return payload


def encode_jwt(payload, headers=None):
"""
def encode_jwt(payload, issuer=None, headers=None):
"""Sign and encode the provided ``payload`` as ``issuer``.
:type payload: dict
:type issuer: str, None
:type headers: dict, None
:rtype: str
"""
if not issuer and 'iss' in payload:
issuer = payload['iss']
elif not issuer and 'iss' not in payload:
raise ValueError(
'Unable to determine issuer. Token missing iss claim')

# RS256 in default, because hardcoded legacy
algorithm = getattr(settings, 'JWT_ENC_ALGORITHM', 'RS256')

private_key_name = 'JWT_PRIVATE_KEY_{}'.format(payload['iss'].upper())
private_key_name = 'JWT_PRIVATE_KEY_{}'.format(issuer.upper())
private_key = getattr(settings, private_key_name, None)
if not private_key:
raise ImproperlyConfigured('Missing setting {}'.format(
Expand All @@ -51,9 +58,10 @@ def encode_jwt(payload, headers=None):
return encoded.decode("utf-8")


def decode_jwt(jwt_value):
def decode_jwt(jwt_value, issuer=None):
"""
:type jwt_value: str
:type issuer: str, None
"""
try:
headers_enc, payload_enc, verify_signature = jwt_value.split(".")
Expand All @@ -63,8 +71,14 @@ def decode_jwt(jwt_value):
payload_enc += '=' * (-len(payload_enc) % 4) # add padding
payload = json.loads(base64.b64decode(payload_enc).decode("utf-8"))

if not issuer and 'iss' in payload:
issuer = payload['iss']
elif not issuer and 'iss' not in payload:
raise ValueError(
'Unable to determine issuer. Token missing iss claim')

algorithms = getattr(settings, 'JWT_JWS_ALGORITHMS', ['HS256', 'RS256'])
public_key_name = 'JWT_PUBLIC_KEY_{}'.format(payload['iss'].upper())
public_key_name = 'JWT_PUBLIC_KEY_{}'.format(issuer.upper())
public_key = getattr(settings, public_key_name, None)
if not public_key:
raise ImproperlyConfigured('Missing setting {}'.format(
Expand Down
28 changes: 20 additions & 8 deletions oauth2_provider_jwt/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,26 +45,38 @@ class TokenView(views.TokenView):
def _get_access_token_jwt(self, request, content):
extra_data = {}
issuer = settings.JWT_ISSUER
payload_enricher = getattr(settings, 'JWT_PAYLOAD_ENRICHER', None)
if payload_enricher:
fn = import_string(payload_enricher)
extra_data = fn(request)

token = get_access_token_model().objects.get(
token=content['access_token']
)

if 'scope' in content:
extra_data['scope'] = content['scope']

id_attribute = getattr(settings, 'JWT_ID_ATTRIBUTE', None)
if id_attribute:
token = get_access_token_model().objects.get(
token=content['access_token']
)
id_value = getattr(token.user, id_attribute, None)
if not id_value:
raise MissingIdAttribute()
extra_data[id_attribute] = str(id_value)

payload = generate_payload(issuer, content['expires_in'], **extra_data)
token = encode_jwt(payload)

payload_enricher = getattr(settings, 'JWT_PAYLOAD_ENRICHER', None)
if payload_enricher:
fn = import_string(payload_enricher)
enriched_data = fn(
request=request,
token_content=content,
token_obj=token,
current_claims=payload)

if getattr(settings, 'JWT_PAYLOAD_ENRICHER_OVERWRITE', False):
payload = enriched_data
else:
payload.update(enriched_data)

token = encode_jwt(payload, issuer)
return token

@staticmethod
Expand Down
21 changes: 21 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,20 @@ def test_encode_jwt_rs256(self):
json.loads(base64.b64decode(payload).decode("utf-8")),
payload_in)

def test_encode_jwt_explicit_issuer(self):
payload_in = self._get_payload()
payload_in['iss'] = 'different-issuer'
encoded = utils.encode_jwt(payload_in, 'issuer')
self.assertIn(type(encoded).__name__, ('unicode', 'str'))
headers, payload, verify_signature = encoded.split(".")
self.assertDictEqual(
json.loads(base64.b64decode(headers)),
{"typ": "JWT", "alg": "RS256"})
payload += '=' * (-len(payload) % 4) # add padding
self.assertEqual(
json.loads(base64.b64decode(payload).decode("utf-8")),
payload_in)

@override_settings(JWT_PRIVATE_KEY_ISSUER='test')
@override_settings(JWT_ENC_ALGORITHM='HS256')
def test_encode_jwt_hs256(self):
Expand Down Expand Up @@ -143,6 +157,13 @@ def test_decode_jwt_rs256(self):
payload_out = utils.decode_jwt(jwt_value)
self.assertDictEqual(payload, payload_out)

def test_decode_jwt_explicit_issuer(self):
payload = self._get_payload()
payload['iss'] = 'different-issuer'
jwt_value = utils.encode_jwt(payload, 'issuer')
payload_out = utils.decode_jwt(jwt_value, 'issuer')
self.assertDictEqual(payload, payload_out)

@override_settings(JWT_PRIVATE_KEY_ISSUER='test')
@override_settings(JWT_PUBLIC_KEY_ISSUER='test')
@override_settings(JWT_ENC_ALGORITHM='HS256')
Expand Down
75 changes: 67 additions & 8 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,22 @@ def get_basic_auth_header(user, password):
return auth_headers


def payload_enricher(request):
def payload_enricher(request, token_content, token_obj, current_claims):
# Explicit parameters used here to validate call
from django.http import HttpRequest
from oauth2_provider.models import AccessToken

assert isinstance(request, HttpRequest), \
'payload enrichment function expecting HttpResponse object'
assert 'access_token' in token_content, \
'payload enrichment function expecting oauth token data'
assert isinstance(token_obj, AccessToken), \
'payload enrichment function expecting oauth2 AccessToken model'
assert 'iss' in current_claims, \
'payload enrichment function expecting default current_claims'

return {
'sub': 'unique-user',
'sub': token_obj.user.pk,
}


Expand Down Expand Up @@ -110,14 +123,25 @@ def test_get_token(self):

content = json.loads(response.content.decode("utf-8"))
jwt_token = content["access_token_jwt"]
jwt = self.decode_jwt(jwt_token)

self.assertEqual(content["token_type"], "Bearer")
self.assertIn(type(jwt_token).__name__, ('unicode', 'str'))
self.assertEqual(content["scope"], "read write")
self.assertEqual(content["expires_in"],
oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS)
self.assertTrue('scope' in self.decode_jwt(jwt_token))
self.assertEqual(self.decode_jwt(jwt_token).get('scope'),
'read write')

# Validate no unexpected data was included
self.assertEqual(len(jwt), 5)

self.assertTrue('iss' in jwt)
self.assertTrue('exp' in jwt)
self.assertTrue('iat' in jwt)
self.assertTrue('scope' in jwt)
self.assertTrue('username' in jwt)
self.assertEqual(jwt.get('iss'), settings.JWT_ISSUER)
self.assertEqual(jwt.get('username'), self.test_user.username)
self.assertEqual(jwt.get('scope'), 'read write')

def test_get_token_authorization_code(self):
"""
Expand Down Expand Up @@ -262,6 +286,8 @@ def test_do_not_get_token_missing_conf(self, mock_is_jwt_config_set):

@override_settings(
JWT_PAYLOAD_ENRICHER='tests.test_views.payload_enricher')
@override_settings(
JWT_PAYLOAD_ENRICHER_OVERWRITE=False)
def test_get_enriched_jwt(self):
token_request_data = {
"grant_type": "password",
Expand All @@ -276,9 +302,42 @@ def test_get_enriched_jwt(self):
**auth_headers)
content = json.loads(response.content.decode("utf-8"))
access_token_jwt = content["access_token_jwt"]
self.assertTrue('sub' in self.decode_jwt(access_token_jwt))
self.assertEqual(self.decode_jwt(access_token_jwt).get('sub'),
'unique-user')
jwt = self.decode_jwt(access_token_jwt)

# Validate the token was enriched
self.assertTrue('sub' in jwt)
self.assertEqual(jwt.get('sub'), self.test_user.pk)

# Validate the token was extended rather than overwritten
self.assertTrue('username' in jwt)
self.assertEqual(jwt.get('username'), self.test_user.username)

@override_settings(
JWT_PAYLOAD_ENRICHER='tests.test_views.payload_enricher')
@override_settings(
JWT_PAYLOAD_ENRICHER_OVERWRITE=True)
def test_overwrite_enriched_jwt(self):
token_request_data = {
"grant_type": "password",
"username": "test_user",
"password": "123456",
}
auth_headers = get_basic_auth_header(self.application.client_id,
self.application.client_secret)

response = self.client.post(
reverse("oauth2_provider_jwt:token"), data=token_request_data,
**auth_headers)
content = json.loads(response.content.decode("utf-8"))
access_token_jwt = content["access_token_jwt"]
jwt = self.decode_jwt(access_token_jwt)

# Validate the token was enriched
self.assertTrue('sub' in jwt)
self.assertEqual(jwt.get('sub'), self.test_user.pk)

# Validate the token was overwritten
self.assertTrue(len(jwt) == 1)

def test_get_custom_scope_in_jwt(self):
token_request_data = {
Expand Down