Skip to content

Commit

Permalink
Add JWT_DECODE_ISSUER option
Browse files Browse the repository at this point in the history
Closes #259
  • Loading branch information
vimalloc committed Aug 3, 2019
1 parent f39a679 commit 05a802a
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 6 deletions.
4 changes: 4 additions & 0 deletions flask_jwt_extended/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,10 @@ def json_encoder(self):
def audience(self):
return current_app.config['JWT_DECODE_AUDIENCE']

@property
def issuer(self):
return current_app.config['JWT_DECODE_ISSUER']

@property
def leeway(self):
return current_app.config['JWT_DECODE_LEEWAY']
Expand Down
10 changes: 9 additions & 1 deletion flask_jwt_extended/jwt_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import datetime
from warnings import warn

from jwt import ExpiredSignatureError, InvalidTokenError, InvalidAudienceError
from jwt import (
ExpiredSignatureError, InvalidTokenError, InvalidAudienceError,
InvalidIssuerError
)
try:
from flask import _app_ctx_stack as ctx_stack
except ImportError: # pragma: no cover
Expand Down Expand Up @@ -126,6 +129,10 @@ def handle_wrong_token_error(e):
def handle_invalid_audience_error(e):
return self._invalid_token_callback(str(e))

@app.errorhandler(InvalidIssuerError)
def handle_invalid_issuer_error(e):
return self._invalid_token_callback(str(e))

@app.errorhandler(RevokedTokenError)
def handle_revoked_token_error(e):
return self._revoked_token_callback()
Expand Down Expand Up @@ -214,6 +221,7 @@ def _set_default_configuration_options(app):
app.config.setdefault('JWT_IDENTITY_CLAIM', 'identity')
app.config.setdefault('JWT_USER_CLAIMS', 'user_claims')
app.config.setdefault('JWT_DECODE_AUDIENCE', None)
app.config.setdefault('JWT_DECODE_ISSUER', None)
app.config.setdefault('JWT_DECODE_LEEWAY', 0)

app.config.setdefault('JWT_CLAIMS_IN_REFRESH_TOKEN', False)
Expand Down
5 changes: 3 additions & 2 deletions flask_jwt_extended/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims

def decode_jwt(encoded_token, secret, algorithms, identity_claim_key,
user_claims_key, csrf_value=None, audience=None,
leeway=0, allow_expired=False):
leeway=0, allow_expired=False, issuer=None):
"""
Decodes an encoded JWT
Expand All @@ -125,6 +125,7 @@ def decode_jwt(encoded_token, secret, algorithms, identity_claim_key,
:param user_claims_key: expected key that contains the user claims
:param csrf_value: Expected double submit csrf value
:param audience: expected audience in the JWT
:param issuer: expected issuer in the JWT
:param leeway: optional leeway to add some margin around expiration times
:param allow_expired: Options to ignore exp claim validation in token
:return: Dictionary containing contents of the JWT
Expand All @@ -135,7 +136,7 @@ def decode_jwt(encoded_token, secret, algorithms, identity_claim_key,

# This call verifies the ext, iat, nbf, and aud claims
data = jwt.decode(encoded_token, secret, algorithms=algorithms, audience=audience,
leeway=leeway, options=options)
leeway=leeway, options=options, issuer=issuer)

# Make sure that any custom claims we expect in the token are present
if 'jti' not in data:
Expand Down
2 changes: 2 additions & 0 deletions flask_jwt_extended/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def decode_token(encoded_token, csrf_value=None, allow_expired=False):
user_claims_key=config.user_claims_key,
csrf_value=csrf_value,
audience=config.audience,
issuer=config.issuer,
leeway=config.leeway,
allow_expired=allow_expired
)
Expand All @@ -115,6 +116,7 @@ def decode_token(encoded_token, csrf_value=None, allow_expired=False):
user_claims_key=config.user_claims_key,
csrf_value=csrf_value,
audience=config.audience,
issuer=config.issuer,
leeway=config.leeway,
allow_expired=True
)
Expand Down
24 changes: 21 additions & 3 deletions tests/test_decode_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from jwt import (
ExpiredSignatureError, InvalidSignatureError, InvalidAudienceError,
ImmatureSignatureError
ImmatureSignatureError, InvalidIssuerError
)

from flask_jwt_extended import (
Expand Down Expand Up @@ -246,9 +246,9 @@ def test_valid_aud(app, default_access_token, token_aud):
app.config['JWT_DECODE_AUDIENCE'] = ['foo', 'bar']

default_access_token['aud'] = token_aud
invalid_token = encode_token(app, default_access_token)
valid_token = encode_token(app, default_access_token)
with app.test_request_context():
decoded = decode_token(invalid_token)
decoded = decode_token(valid_token)
assert decoded['aud'] == token_aud


Expand All @@ -261,3 +261,21 @@ def test_invalid_aud(app, default_access_token, token_aud):
with pytest.raises(InvalidAudienceError):
with app.test_request_context():
decode_token(invalid_token)

def test_valid_iss(app, default_access_token):
app.config['JWT_DECODE_ISSUER'] = 'foobar'

default_access_token['iss'] = 'foobar'
valid_token = encode_token(app, default_access_token)
with app.test_request_context():
decoded = decode_token(valid_token)
assert decoded['iss'] == 'foobar'

def test_invalid_iss(app, default_access_token):
app.config['JWT_DECODE_ISSUER'] = 'baz'

default_access_token['iss'] = 'foobar'
invalid_token = encode_token(app, default_access_token)
with pytest.raises(InvalidIssuerError):
with app.test_request_context():
decode_token(invalid_token)
23 changes: 23 additions & 0 deletions tests/test_view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,29 @@ def test_jwt_invalid_audience(app):
assert response.status_code == 422
assert response.get_json() == {'msg': 'Invalid audience'}

def test_jwt_invalid_issuer(app):
url = '/protected'
jwtM = get_jwt_manager(app)
test_client = app.test_client()

# No issuer claim expected or provided - OK
access_token = encode_token(app, {'identity': 'me'})
response = test_client.get(url, headers=make_headers(access_token))
assert response.status_code == 200

# Issuer claim expected and not provided - not OK
app.config['JWT_DECODE_ISSUER'] = 'my_issuer'
access_token = encode_token(app, {'identity': 'me'})
response = test_client.get(url, headers=make_headers(access_token))
assert response.status_code == 422
assert response.get_json() == {'msg': 'Token is missing the "iss" claim'}

# Issuer claim still expected and wrong one provided - not OK
access_token = encode_token(app, {'iss': 'different_issuer', 'identity': 'me'})
response = test_client.get(url, headers=make_headers(access_token))
assert response.status_code == 422
assert response.get_json() == {'msg': 'Invalid issuer'}


@pytest.mark.parametrize("delta_func", [timedelta, relativedelta])
def test_expired_token(app, delta_func):
Expand Down

0 comments on commit 05a802a

Please sign in to comment.