Skip to content

Commit

Permalink
Add additional parameters to jwt enrichment call
Browse files Browse the repository at this point in the history
- Adds Oauth token `content`, oauth2_provider's `token_obj` model, and
  django-oauth-toolkit-jwt's `current_claims` as parameters to
  `JWT_PAYLOAD_ENRICHER`. This will give the user more control over the
  data that is included in their JWT
- Add `JWT_PAYLOAD_ENRICHER_OVERWRITE` setting (default is
  False). Setting this to true will allow the user total control over
  the claims included.

Fixes humanitec#29
  • Loading branch information
com4 committed Jan 8, 2021
1 parent 9c93dc4 commit ab46e22
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 24 deletions.
35 changes: 32 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ The payload of messages will be by default something like:
{
"iss": "OneIssuer",
"exp": 1234567890,
"iat": 1234567789
"iat": 1234567789,
"email": "[email protected]",
"scope": "read write"
}
```

Expand All @@ -157,14 +159,41 @@ function that will enrich the payload and set the location to it in the
```python
# settings.py

# Define the function to be called when creating a new JWT
JWT_PAYLOAD_ENRICHER = 'myapp.jwt_utils.payload_enricher'

# Ovewrite all of the toolkit's default JWT claims with those provided by the function
# Useful if you want to design your own token payload (Default: False which
# performs a `dict().update(payload_enricher(...))`)
JWT_PAYLOAD_ENRICHER_OVERWRITE = False
```

```python
# myproject/myapp/jwt_utils.py

def payload_enricher(request):
def payload_enricher(**kwargs):
# Keyword Args: request, token_content, token_obj, current_claims

# The Django HTTPRequest object
request = kwargs.pop('request', None)

# Dictionary of the content of the Oauth response. Includes values like
# access_token, expires_in, token_type, refresh_token, scope
content = kwargs.pop('token_content', None)

# The oauth2_provider access token (by default:
# oauth2_provider.models.AccessToken)
token = kwargs.pop('token_obj', None)

# The automatically generated claims. This usually includes your
# JWT_ID_ATTRIBUTE and scope. This can be useful if you want to use
# JWT_PAYLOAD_ENRICHER_OVERWRITE mode.
current_claims = kwargs.pop('current_claims', None)

# Values returned here must be serializable by json.dumps
return {
'sub': 'mysubject',
'sub': token.user.pk,
'preferred_username': token.user.username,
...
}
```
Expand Down
24 changes: 19 additions & 5 deletions oauth2_provider_jwt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,23 @@ def generate_payload(issuer, expires_in, **extra_data):
return payload


def encode_jwt(payload, headers=None):
"""
def encode_jwt(payload, issuer=None, headers=None):
"""Sign and encode the provided ``payload`` as ``issuer``.
:type payload: dict
:type issuer: str, None
:type headers: dict, None
:rtype: str
"""
if not issuer and 'iss' in payload:
issuer = payload['iss']
elif not issuer and 'iss' not in payload:
raise ValueError(
'Unable to determine issuer. Token missing iss claim')

# RS256 in default, because hardcoded legacy
algorithm = getattr(settings, 'JWT_ENC_ALGORITHM', 'RS256')

private_key_name = 'JWT_PRIVATE_KEY_{}'.format(payload['iss'].upper())
private_key_name = 'JWT_PRIVATE_KEY_{}'.format(issuer.upper())
private_key = getattr(settings, private_key_name, None)
if not private_key:
raise ImproperlyConfigured('Missing setting {}'.format(
Expand All @@ -51,9 +58,10 @@ def encode_jwt(payload, headers=None):
return encoded.decode("utf-8")


def decode_jwt(jwt_value):
def decode_jwt(jwt_value, issuer=None):
"""
:type jwt_value: str
:type issuer: str, None
"""
try:
headers_enc, payload_enc, verify_signature = jwt_value.split(".")
Expand All @@ -63,8 +71,14 @@ def decode_jwt(jwt_value):
payload_enc += '=' * (-len(payload_enc) % 4) # add padding
payload = json.loads(base64.b64decode(payload_enc).decode("utf-8"))

if not issuer and 'iss' in payload:
issuer = payload['iss']
elif not issuer and 'iss' not in payload:
raise ValueError(
'Unable to determine issuer. Token missing iss claim')

algorithms = getattr(settings, 'JWT_JWS_ALGORITHMS', ['HS256', 'RS256'])
public_key_name = 'JWT_PUBLIC_KEY_{}'.format(payload['iss'].upper())
public_key_name = 'JWT_PUBLIC_KEY_{}'.format(issuer.upper())
public_key = getattr(settings, public_key_name, None)
if not public_key:
raise ImproperlyConfigured('Missing setting {}'.format(
Expand Down
28 changes: 20 additions & 8 deletions oauth2_provider_jwt/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,26 +45,38 @@ class TokenView(views.TokenView):
def _get_access_token_jwt(self, request, content):
extra_data = {}
issuer = settings.JWT_ISSUER
payload_enricher = getattr(settings, 'JWT_PAYLOAD_ENRICHER', None)
if payload_enricher:
fn = import_string(payload_enricher)
extra_data = fn(request)

token = get_access_token_model().objects.get(
token=content['access_token']
)

if 'scope' in content:
extra_data['scope'] = content['scope']

id_attribute = getattr(settings, 'JWT_ID_ATTRIBUTE', None)
if id_attribute:
token = get_access_token_model().objects.get(
token=content['access_token']
)
id_value = getattr(token.user, id_attribute, None)
if not id_value:
raise MissingIdAttribute()
extra_data[id_attribute] = str(id_value)

payload = generate_payload(issuer, content['expires_in'], **extra_data)
token = encode_jwt(payload)

payload_enricher = getattr(settings, 'JWT_PAYLOAD_ENRICHER', None)
if payload_enricher:
fn = import_string(payload_enricher)
enriched_data = fn(
request=request,
token_content=content,
token_obj=token,
current_claims=payload)

if getattr(settings, 'JWT_PAYLOAD_ENRICHER_OVERWRITE', False):
payload = enriched_data
else:
payload.update(enriched_data)

token = encode_jwt(payload, issuer)
return token

@staticmethod
Expand Down
21 changes: 21 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,20 @@ def test_encode_jwt_rs256(self):
json.loads(base64.b64decode(payload).decode("utf-8")),
payload_in)

def test_encode_jwt_explicit_issuer(self):
payload_in = self._get_payload()
payload_in['iss'] = 'different-issuer'
encoded = utils.encode_jwt(payload_in, 'issuer')
self.assertIn(type(encoded).__name__, ('unicode', 'str'))
headers, payload, verify_signature = encoded.split(".")
self.assertDictEqual(
json.loads(base64.b64decode(headers)),
{"typ": "JWT", "alg": "RS256"})
payload += '=' * (-len(payload) % 4) # add padding
self.assertEqual(
json.loads(base64.b64decode(payload).decode("utf-8")),
payload_in)

@override_settings(JWT_PRIVATE_KEY_ISSUER='test')
@override_settings(JWT_ENC_ALGORITHM='HS256')
def test_encode_jwt_hs256(self):
Expand Down Expand Up @@ -143,6 +157,13 @@ def test_decode_jwt_rs256(self):
payload_out = utils.decode_jwt(jwt_value)
self.assertDictEqual(payload, payload_out)

def test_decode_jwt_explicit_issuer(self):
payload = self._get_payload()
payload['iss'] = 'different-issuer'
jwt_value = utils.encode_jwt(payload, 'issuer')
payload_out = utils.decode_jwt(jwt_value, 'issuer')
self.assertDictEqual(payload, payload_out)

@override_settings(JWT_PRIVATE_KEY_ISSUER='test')
@override_settings(JWT_PUBLIC_KEY_ISSUER='test')
@override_settings(JWT_ENC_ALGORITHM='HS256')
Expand Down
75 changes: 67 additions & 8 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,22 @@ def get_basic_auth_header(user, password):
return auth_headers


def payload_enricher(request):
def payload_enricher(request, token_content, token_obj, current_claims):
# Explicit parameters used here to validate call
from django.http import HttpRequest
from oauth2_provider.models import AccessToken

assert isinstance(request, HttpRequest), \
'payload enrichment function expecting HttpResponse object'
assert 'access_token' in token_content, \
'payload enrichment function expecting oauth token data'
assert isinstance(token_obj, AccessToken), \
'payload enrichment function expecting oauth2 AccessToken model'
assert 'iss' in current_claims, \
'payload enrichment function expecting default current_claims'

return {
'sub': 'unique-user',
'sub': token_obj.user.pk,
}


Expand Down Expand Up @@ -110,14 +123,25 @@ def test_get_token(self):

content = json.loads(response.content.decode("utf-8"))
jwt_token = content["access_token_jwt"]
jwt = self.decode_jwt(jwt_token)

self.assertEqual(content["token_type"], "Bearer")
self.assertIn(type(jwt_token).__name__, ('unicode', 'str'))
self.assertEqual(content["scope"], "read write")
self.assertEqual(content["expires_in"],
oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS)
self.assertTrue('scope' in self.decode_jwt(jwt_token))
self.assertEqual(self.decode_jwt(jwt_token).get('scope'),
'read write')

# Validate no unexpected data was included
self.assertEqual(len(jwt), 5)

self.assertTrue('iss' in jwt)
self.assertTrue('exp' in jwt)
self.assertTrue('iat' in jwt)
self.assertTrue('scope' in jwt)
self.assertTrue('username' in jwt)
self.assertEqual(jwt.get('iss'), settings.JWT_ISSUER)
self.assertEqual(jwt.get('username'), self.test_user.username)
self.assertEqual(jwt.get('scope'), 'read write')

def test_get_token_authorization_code(self):
"""
Expand Down Expand Up @@ -262,6 +286,8 @@ def test_do_not_get_token_missing_conf(self, mock_is_jwt_config_set):

@override_settings(
JWT_PAYLOAD_ENRICHER='tests.test_views.payload_enricher')
@override_settings(
JWT_PAYLOAD_ENRICHER_OVERWRITE=False)
def test_get_enriched_jwt(self):
token_request_data = {
"grant_type": "password",
Expand All @@ -276,9 +302,42 @@ def test_get_enriched_jwt(self):
**auth_headers)
content = json.loads(response.content.decode("utf-8"))
access_token_jwt = content["access_token_jwt"]
self.assertTrue('sub' in self.decode_jwt(access_token_jwt))
self.assertEqual(self.decode_jwt(access_token_jwt).get('sub'),
'unique-user')
jwt = self.decode_jwt(access_token_jwt)

# Validate the token was enriched
self.assertTrue('sub' in jwt)
self.assertEqual(jwt.get('sub'), self.test_user.pk)

# Validate the token was extended rather than overwritten
self.assertTrue('username' in jwt)
self.assertEqual(jwt.get('username'), self.test_user.username)

@override_settings(
JWT_PAYLOAD_ENRICHER='tests.test_views.payload_enricher')
@override_settings(
JWT_PAYLOAD_ENRICHER_OVERWRITE=True)
def test_overwrite_enriched_jwt(self):
token_request_data = {
"grant_type": "password",
"username": "test_user",
"password": "123456",
}
auth_headers = get_basic_auth_header(self.application.client_id,
self.application.client_secret)

response = self.client.post(
reverse("oauth2_provider_jwt:token"), data=token_request_data,
**auth_headers)
content = json.loads(response.content.decode("utf-8"))
access_token_jwt = content["access_token_jwt"]
jwt = self.decode_jwt(access_token_jwt)

# Validate the token was enriched
self.assertTrue('sub' in jwt)
self.assertEqual(jwt.get('sub'), self.test_user.pk)

# Validate the token was overwritten
self.assertTrue(len(jwt) == 1)

def test_get_custom_scope_in_jwt(self):
token_request_data = {
Expand Down

0 comments on commit ab46e22

Please sign in to comment.