Skip to content

Commit

Permalink
Allow changing subject claim when decoding
Browse files Browse the repository at this point in the history
Related to issue vimalloc#65
  • Loading branch information
psafont committed Jul 12, 2017
1 parent 17c3254 commit df68d65
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 22 deletions.
3 changes: 3 additions & 0 deletions docs/options.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ General Options:
such as ``RS*`` or ``ES*``. PEM format expected.
``JWT_PRIVATE_KEY`` The private key needed for asymmetric based signing algorithms,
such as ``RS*`` or ``ES*``. PEM format expected.
``JWT_IDENTITY_CLAIM`` Claim in the tokens that is used on decoding as source of identity.
For interoperativity, the JWT RFC recommends using ``'sub'``.
Defaults to ``'identity'``.
================================= =========================================


Expand Down
4 changes: 4 additions & 0 deletions flask_jwt_extended/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@ def cookie_max_age(self):
# seconds a long ways in the future
return None if self.session_cookie else 2147483647 # 2^31

@property
def identity_claim(self):
return current_app.config['JWT_IDENTITY_CLAIM']

config = _Config()


2 changes: 2 additions & 0 deletions flask_jwt_extended/jwt_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ def _set_default_configuration_options(app):
app.config.setdefault('JWT_BLACKLIST_ENABLED', False)
app.config.setdefault('JWT_BLACKLIST_TOKEN_CHECKS', ['access', 'refresh'])

app.config.setdefault('JWT_IDENTITY_CLAIM', 'identity')

def user_claims_loader(self, callback):
"""
This sets the callback method for adding custom user claims to a JWT.
Expand Down
7 changes: 4 additions & 3 deletions flask_jwt_extended/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf):
return _encode_jwt(token_data, expires_delta, secret, algorithm)


def decode_jwt(encoded_token, secret, algorithm, csrf):
def decode_jwt(encoded_token, secret, algorithm, csrf, identity_claim):
"""
Decodes an encoded JWT
Expand All @@ -85,6 +85,7 @@ def decode_jwt(encoded_token, secret, algorithm, csrf):
:param algorithm: Algorithm used to encode the JWT
:param csrf: If this token is expected to have a CSRF double submit
value present (boolean)
:param identity_claim: expected claim that is used to identify the subject
:return: Dictionary containing contents of the JWT
"""
# This call verifies the ext, iat, and nbf claims
Expand All @@ -93,8 +94,8 @@ def decode_jwt(encoded_token, secret, algorithm, csrf):
# Make sure that any custom claims we expect in the token are present
if 'jti' not in data:
raise JWTDecodeError("Missing claim: jti")
if 'identity' not in data:
raise JWTDecodeError("Missing claim: identity")
if identity_claim not in data:
raise JWTDecodeError("Missing claim: {}".format(identity_claim))
if 'type' not in data or data['type'] not in ('refresh', 'access'):
raise JWTDecodeError("Missing or invalid claim: type")
if data['type'] == 'access':
Expand Down
13 changes: 10 additions & 3 deletions flask_jwt_extended/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_jwt_identity():
Returns the identity of the JWT in this context. If no JWT is present,
None is returned.
"""
return get_raw_jwt().get('identity', None)
return get_raw_jwt().get(config.identity_claim, None)


def get_jwt_claims():
Expand Down Expand Up @@ -63,7 +63,8 @@ def decode_token(encoded_token):
encoded_token=encoded_token,
secret=config.decode_key,
algorithm=config.algorithm,
csrf=config.csrf_protect
csrf=config.csrf_protect,
identity_claim=config.identity_claim
)


Expand Down Expand Up @@ -106,7 +107,13 @@ def token_in_blacklist(*args, **kwargs):


def get_csrf_token(encoded_token):
token = decode_jwt(encoded_token, config.decode_key, config.algorithm, csrf=True)
token = decode_jwt(
encoded_token,
config.decode_key,
config.algorithm,
csrf=True,
identity_claim=config.identity_claim
)
return token['csrf']


Expand Down
11 changes: 9 additions & 2 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,13 @@ def _decode_jwt_from_headers():
raise InvalidHeaderError(msg)
token = parts[1]

return decode_jwt(token, config.decode_key, config.algorithm, csrf=False)
return decode_jwt(
encoded_token=token,
secret=config.decode_key,
algorithm=config.algorithm,
csrf=False,
identity_claim=config.identity_claim
)


def _decode_jwt_from_cookies(request_type):
Expand All @@ -163,7 +169,8 @@ def _decode_jwt_from_cookies(request_type):
encoded_token=encoded_token,
secret=config.decode_key,
algorithm=config.algorithm,
csrf=config.csrf_protect
csrf=config.csrf_protect,
identity_claim=config.identity_claim
)

# Verify csrf double submit tokens match if required
Expand Down
6 changes: 6 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def test_default_configs(self):
self.assertEqual(config.decode_key, self.app.secret_key)
self.assertEqual(config.cookie_max_age, None)

self.assertEqual(config.identity_claim, 'identity')

def test_override_configs(self):
self.app.config['JWT_TOKEN_LOCATION'] = ['cookies']
self.app.config['JWT_HEADER_NAME'] = 'TestHeader'
Expand Down Expand Up @@ -86,6 +88,8 @@ def test_override_configs(self):

self.app.secret_key = 'banana'

self.app.config['JWT_IDENTITY_CLAIM'] = 'foo'

with self.app.test_request_context():
self.assertEqual(config.token_location, ['cookies'])
self.assertEqual(config.jwt_in_cookies, True)
Expand Down Expand Up @@ -122,6 +126,8 @@ def test_override_configs(self):

self.assertEqual(config.cookie_max_age, 2147483647)

self.assertEqual(config.identity_claim, 'foo')

def test_invalid_config_options(self):
with self.app.test_request_context():
self.app.config['JWT_TOKEN_LOCATION'] = 'banana'
Expand Down
41 changes: 27 additions & 14 deletions tests/test_jwt_encode_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_decode_jwt(self):
'user_claims': {'foo': 'bar'},
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
data = decode_jwt(encoded_token, 'secret', 'HS256', csrf=False)
data = decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
self.assertIn('exp', data)
self.assertIn('iat', data)
self.assertIn('nbf', data)
Expand Down Expand Up @@ -188,7 +188,7 @@ def test_decode_jwt(self):
'type': 'refresh',
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
data = decode_jwt(encoded_token, 'secret', 'HS256', csrf=False)
data = decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
self.assertIn('exp', data)
self.assertIn('iat', data)
self.assertIn('nbf', data)
Expand All @@ -210,7 +210,7 @@ def test_decode_invalid_jwt(self):
'exp': datetime.utcnow() - timedelta(minutes=5),
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False)
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')

# Missing jti
with self.assertRaises(JWTDecodeError):
Expand All @@ -220,7 +220,7 @@ def test_decode_invalid_jwt(self):
'type': 'refresh'
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False)
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')

# Missing identity
with self.assertRaises(JWTDecodeError):
Expand All @@ -230,7 +230,17 @@ def test_decode_invalid_jwt(self):
'type': 'refresh'
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False)
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')

# Non-matching identity claim
with self.assertRaises(JWTDecodeError):
token_data = {
'exp': datetime.utcnow() + timedelta(minutes=5),
'identity': 'banana',
'type': 'refresh'
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='sub')

# Missing type
with self.assertRaises(JWTDecodeError):
Expand All @@ -240,7 +250,7 @@ def test_decode_invalid_jwt(self):
'exp': datetime.utcnow() + timedelta(minutes=5),
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False)
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')

# Missing fresh in access token
with self.assertRaises(JWTDecodeError):
Expand All @@ -252,7 +262,7 @@ def test_decode_invalid_jwt(self):
'user_claims': {}
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False)
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')

# Missing user claims in access token
with self.assertRaises(JWTDecodeError):
Expand All @@ -264,7 +274,7 @@ def test_decode_invalid_jwt(self):
'fresh': True
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False)
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')

# Bad token type
with self.assertRaises(JWTDecodeError):
Expand All @@ -277,7 +287,7 @@ def test_decode_invalid_jwt(self):
'user_claims': 'banana'
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False)
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')

# Missing csrf in csrf enabled token
with self.assertRaises(JWTDecodeError):
Expand All @@ -290,7 +300,7 @@ def test_decode_invalid_jwt(self):
'user_claims': 'banana'
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=True)
decode_jwt(encoded_token, 'secret', 'HS256', csrf=True, identity_claim='identity')

def test_create_jwt_with_object(self):
# Complex object to test building a JWT from. Normally if you are using
Expand Down Expand Up @@ -322,12 +332,15 @@ def user_identity_lookup(user):
user = TestUser(username='foo', roles=['bar', 'baz'])
access_token = create_access_token(identity=user)
refresh_token = create_refresh_token(identity=user)
identity = 'identity'

# Decode the tokens and make sure the values are set properly
access_token_data = decode_jwt(access_token, app.secret_key,
app.config['JWT_ALGORITHM'], csrf=False)
app.config['JWT_ALGORITHM'], csrf=False,
identity_claim=identity)
refresh_token_data = decode_jwt(refresh_token, app.secret_key,
app.config['JWT_ALGORITHM'], csrf=False)
self.assertEqual(access_token_data['identity'], 'foo')
app.config['JWT_ALGORITHM'], csrf=False,
identity_claim=identity)
self.assertEqual(access_token_data[identity], 'foo')
self.assertEqual(access_token_data['user_claims']['roles'], ['bar', 'baz'])
self.assertEqual(refresh_token_data['identity'], 'foo')
self.assertEqual(refresh_token_data[identity], 'foo')

0 comments on commit df68d65

Please sign in to comment.