Skip to content

Commit

Permalink
Use JWT_IDENTITY_CLAIM for encoding too
Browse files Browse the repository at this point in the history
  • Loading branch information
psafont committed Jul 12, 2017
1 parent df68d65 commit f8d83f2
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 43 deletions.
2 changes: 1 addition & 1 deletion docs/options.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ General Options:
such as ``RS*`` or ``ES*``. PEM format expected.
``JWT_PRIVATE_KEY`` The private key needed for asymmetric based signing algorithms,
such as ``RS*`` or ``ES*``. PEM format expected.
``JWT_IDENTITY_CLAIM`` Claim in the tokens that is used on decoding as source of identity.
``JWT_IDENTITY_CLAIM`` Claim in the tokens that is used as source of identity.
For interoperativity, the JWT RFC recommends using ``'sub'``.
Defaults to ``'identity'``.
================================= =========================================
Expand Down
6 changes: 4 additions & 2 deletions flask_jwt_extended/jwt_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,8 @@ def create_refresh_token(self, identity, expires_delta=None):
secret=config.encode_key,
algorithm=config.algorithm,
expires_delta=expires_delta,
csrf=config.csrf_protect
csrf=config.csrf_protect,
identity_claim=config.identity_claim
)
return refresh_token

Expand Down Expand Up @@ -354,7 +355,8 @@ def create_access_token(self, identity, fresh=False, expires_delta=None):
expires_delta=expires_delta,
fresh=fresh,
user_claims=self._user_claims_callback(identity),
csrf=config.csrf_protect
csrf=config.csrf_protect,
identity_claim=config.identity_claim
)
return access_token

10 changes: 6 additions & 4 deletions flask_jwt_extended/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _encode_jwt(additional_token_data, expires_delta, secret, algorithm):


def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
user_claims, csrf):
user_claims, csrf, identity_claim):
"""
Creates a new encoded (utf-8) access token.
Expand All @@ -40,11 +40,12 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
be json serializable
:param csrf: Whether to include a csrf double submit claim in this token
(boolean)
:param identity_claim: Which claim should be used to store the identity in
:return: Encoded access token
"""
# Create the jwt
token_data = {
'identity': identity,
identity_claim: identity,
'fresh': fresh,
'type': 'access',
'user_claims': user_claims,
Expand All @@ -54,7 +55,7 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
return _encode_jwt(token_data, expires_delta, secret, algorithm)


def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf):
def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf, identity_claim):
"""
Creates a new encoded (utf-8) refresh token.
Expand All @@ -65,10 +66,11 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf):
(datetime.timedelta)
:param csrf: Whether to include a csrf double submit claim in this token
(boolean)
:param identity_claim: Which claim should be used to store the identity in
:return: Encoded refresh token
"""
token_data = {
'identity': identity,
identity_claim: identity,
'type': 'refresh',
}
if csrf:
Expand Down
82 changes: 50 additions & 32 deletions tests/test_jwt_encode_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def test_encode_access_token(self):
with self.app.test_request_context():
identity = 'user1'
token = encode_access_token(identity, secret, algorithm, token_expire_delta,
fresh=True, user_claims=user_claims, csrf=False)
fresh=True, user_claims=user_claims, csrf=False,
identity_claim='identity')
data = jwt.decode(token, secret, algorithms=[algorithm])
self.assertIn('exp', data)
self.assertIn('iat', data)
Expand All @@ -59,7 +60,8 @@ def test_encode_access_token(self):
# Check with a non-fresh token
identity = 12345 # identity can be anything json serializable
token = encode_access_token(identity, secret, algorithm, token_expire_delta,
fresh=False, user_claims=user_claims, csrf=True)
fresh=False, user_claims=user_claims, csrf=True,
identity_claim='identity')
data = jwt.decode(token, secret, algorithms=[algorithm])
self.assertIn('exp', data)
self.assertIn('iat', data)
Expand Down Expand Up @@ -87,33 +89,35 @@ def test_encode_invalid_access_token(self):
with self.assertRaises(Exception):
encode_access_token('user1', 'secret', 'HS256',
timedelta(hours=1), True, user_claims,
csrf=True)
csrf=True, identity_claim='identity')

user_claims = {'foo': timedelta(hours=4)}
with self.assertRaises(Exception):
encode_access_token('user1', 'secret', 'HS256',
timedelta(hours=1), True, user_claims,
csrf=True)
csrf=True, identity_claim='identity')

def test_encode_refresh_token(self):
secret = 'super-totally-secret-key'
algorithm = 'HS256'
token_expire_delta = timedelta(minutes=5)
identity_claim = 'sub'

# Check with a fresh token
with self.app.test_request_context():
identity = 'user1'
token = encode_refresh_token(identity, secret, algorithm,
token_expire_delta, csrf=False)
token_expire_delta, csrf=False,
identity_claim=identity_claim)
data = jwt.decode(token, secret, algorithms=[algorithm])
self.assertIn('exp', data)
self.assertIn('iat', data)
self.assertIn('nbf', data)
self.assertIn('jti', data)
self.assertIn('type', data)
self.assertIn('identity', data)
self.assertIn(identity_claim, data)
self.assertNotIn('csrf', data)
self.assertEqual(data['identity'], identity)
self.assertEqual(data[identity_claim], identity)
self.assertEqual(data['type'], 'refresh')
self.assertEqual(data['iat'], data['nbf'])
now_ts = calendar.timegm(datetime.utcnow().utctimetuple())
Expand All @@ -124,16 +128,17 @@ def test_encode_refresh_token(self):
# Check with a csrf token
identity = 12345 # identity can be anything json serializable
token = encode_refresh_token(identity, secret, algorithm,
token_expire_delta, csrf=True)
token_expire_delta, csrf=True,
identity_claim=identity_claim)
data = jwt.decode(token, secret, algorithms=[algorithm])
self.assertIn('exp', data)
self.assertIn('iat', data)
self.assertIn('nbf', data)
self.assertIn('jti', data)
self.assertIn('type', data)
self.assertIn('csrf', data)
self.assertIn('identity', data)
self.assertEqual(data['identity'], identity)
self.assertIn(identity_claim, data)
self.assertEqual(data[identity_claim], identity)
self.assertEqual(data['type'], 'refresh')
self.assertEqual(data['iat'], data['nbf'])
now_ts = calendar.timegm(datetime.utcnow().utctimetuple())
Expand All @@ -142,6 +147,7 @@ def test_encode_refresh_token(self):
self.assertGreater(exp_seconds, 60 * 4)

def test_decode_jwt(self):
identity_claim = 'sub'
# Test decoding a valid access token
with self.app.test_request_context():
now = datetime.utcnow()
Expand All @@ -151,26 +157,27 @@ def test_decode_jwt(self):
'iat': now,
'nbf': now,
'jti': 'banana',
'identity': 'banana',
identity_claim: 'banana',
'fresh': True,
'type': 'access',
'user_claims': {'foo': 'bar'},
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
data = decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
data = decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim=identity_claim)
self.assertIn('exp', data)
self.assertIn('iat', data)
self.assertIn('nbf', data)
self.assertIn('jti', data)
self.assertIn('identity', data)
self.assertIn(identity_claim, data)
self.assertIn('fresh', data)
self.assertIn('type', data)
self.assertIn('user_claims', data)
self.assertEqual(data['exp'], now_ts + (5 * 60))
self.assertEqual(data['iat'], now_ts)
self.assertEqual(data['nbf'], now_ts)
self.assertEqual(data['jti'], 'banana')
self.assertEqual(data['identity'], 'banana')
self.assertEqual(data[identity_claim], 'banana')
self.assertEqual(data['fresh'], True)
self.assertEqual(data['type'], 'access')
self.assertEqual(data['user_claims'], {'foo': 'bar'})
Expand All @@ -184,22 +191,23 @@ def test_decode_jwt(self):
'iat': now,
'nbf': now,
'jti': 'banana',
'identity': 'banana',
identity_claim: 'banana',
'type': 'refresh',
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
data = decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
data = decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim=identity_claim)
self.assertIn('exp', data)
self.assertIn('iat', data)
self.assertIn('nbf', data)
self.assertIn('jti', data)
self.assertIn('identity', data)
self.assertIn(identity_claim, data)
self.assertIn('type', data)
self.assertEqual(data['exp'], now_ts + (5 * 60))
self.assertEqual(data['iat'], now_ts)
self.assertEqual(data['nbf'], now_ts)
self.assertEqual(data['jti'], 'banana')
self.assertEqual(data['identity'], 'banana')
self.assertEqual(data[identity_claim], 'banana')
self.assertEqual(data['type'], 'refresh')

def test_decode_invalid_jwt(self):
Expand All @@ -210,7 +218,8 @@ def test_decode_invalid_jwt(self):
'exp': datetime.utcnow() - timedelta(minutes=5),
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')

# Missing jti
with self.assertRaises(JWTDecodeError):
Expand All @@ -220,7 +229,8 @@ def test_decode_invalid_jwt(self):
'type': 'refresh'
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')

# Missing identity
with self.assertRaises(JWTDecodeError):
Expand All @@ -230,7 +240,8 @@ def test_decode_invalid_jwt(self):
'type': 'refresh'
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')

# Non-matching identity claim
with self.assertRaises(JWTDecodeError):
Expand All @@ -240,7 +251,8 @@ def test_decode_invalid_jwt(self):
'type': 'refresh'
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='sub')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='sub')

# Missing type
with self.assertRaises(JWTDecodeError):
Expand All @@ -250,7 +262,8 @@ def test_decode_invalid_jwt(self):
'exp': datetime.utcnow() + timedelta(minutes=5),
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')

# Missing fresh in access token
with self.assertRaises(JWTDecodeError):
Expand All @@ -262,7 +275,8 @@ def test_decode_invalid_jwt(self):
'user_claims': {}
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')

# Missing user claims in access token
with self.assertRaises(JWTDecodeError):
Expand All @@ -274,7 +288,8 @@ def test_decode_invalid_jwt(self):
'fresh': True
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')

# Bad token type
with self.assertRaises(JWTDecodeError):
Expand All @@ -287,7 +302,8 @@ def test_decode_invalid_jwt(self):
'user_claims': 'banana'
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=False, identity_claim='identity')
decode_jwt(encoded_token, 'secret', 'HS256',
csrf=False, identity_claim='identity')

# Missing csrf in csrf enabled token
with self.assertRaises(JWTDecodeError):
Expand All @@ -300,7 +316,8 @@ def test_decode_invalid_jwt(self):
'user_claims': 'banana'
}
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=True, identity_claim='identity')
decode_jwt(encoded_token, 'secret', 'HS256', csrf=True,
identity_claim='identity')

def test_create_jwt_with_object(self):
# Complex object to test building a JWT from. Normally if you are using
Expand Down Expand Up @@ -329,18 +346,19 @@ def user_identity_lookup(user):

# Create the token using the complex object
with app.test_request_context():
identity_claim = 'sub'
app.config['JWT_IDENTITY_CLAIM'] = identity_claim
user = TestUser(username='foo', roles=['bar', 'baz'])
access_token = create_access_token(identity=user)
refresh_token = create_refresh_token(identity=user)
identity = 'identity'

# Decode the tokens and make sure the values are set properly
access_token_data = decode_jwt(access_token, app.secret_key,
app.config['JWT_ALGORITHM'], csrf=False,
identity_claim=identity)
identity_claim=identity_claim)
refresh_token_data = decode_jwt(refresh_token, app.secret_key,
app.config['JWT_ALGORITHM'], csrf=False,
identity_claim=identity)
self.assertEqual(access_token_data[identity], 'foo')
identity_claim=identity_claim)
self.assertEqual(access_token_data[identity_claim], 'foo')
self.assertEqual(access_token_data['user_claims']['roles'], ['bar', 'baz'])
self.assertEqual(refresh_token_data[identity], 'foo')
self.assertEqual(refresh_token_data[identity_claim], 'foo')
11 changes: 7 additions & 4 deletions tests/test_protected_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,8 @@ def test_bad_tokens(self):
# Test token that was signed with a different key
with self.app.test_request_context():
token = encode_access_token('foo', 'newsecret', 'HS256',
timedelta(minutes=5), True, {}, csrf=False)
timedelta(minutes=5), True, {}, csrf=False,
identity_claim='identity')
auth_header = "Bearer {}".format(token)
response = self.client.get('/protected', headers={'Authorization': auth_header})
data = json.loads(response.get_data(as_text=True))
Expand Down Expand Up @@ -397,7 +398,7 @@ def test_optional_jwt_bad_tokens(self):
with self.app.test_request_context():
token = encode_access_token('foo', 'newsecret', 'HS256',
timedelta(minutes=5), True, {},
csrf=False)
csrf=False, identity_claim='identity')
auth_header = "Bearer {}".format(token)
response = self.client.get('/partially-protected',
headers={'Authorization': auth_header})
Expand Down Expand Up @@ -584,7 +585,8 @@ def test_jwt_with_different_algorithm(self):
expires_delta=timedelta(minutes=5),
fresh=True,
user_claims={},
csrf=False
csrf=False,
identity_claim='identity'
)
status, data = self._jwt_get('/protected', access_token)
self.assertEqual(status, 422)
Expand All @@ -600,7 +602,8 @@ def test_optional_jwt_with_different_algorithm(self):
expires_delta=timedelta(minutes=5),
fresh=True,
user_claims={},
csrf=False
csrf=False,
identity_claim='identity'
)
status, data = self._jwt_get('/partially-protected', access_token)
self.assertEqual(status, 422)
Expand Down

0 comments on commit f8d83f2

Please sign in to comment.