diff --git a/oauth2_provider_jwt/authentication.py b/oauth2_provider_jwt/authentication.py index 76feb52..91af25a 100644 --- a/oauth2_provider_jwt/authentication.py +++ b/oauth2_provider_jwt/authentication.py @@ -11,6 +11,47 @@ from .utils import decode_jwt +class JwtToken(dict): + """ + Mimics the structure of `AbstractAccessToken` so you can use standard + Django Oauth Toolkit permissions like `TokenHasScope`. + """ + def __init__(self, payload): + super(JwtToken, self).__init__(**payload) + + def __getattr__(self, item): + return self[item] + + def is_valid(self, scopes=None): + """ + Checks if the access token is valid. + + :param scopes: An iterable containing the scopes to check or None + """ + return not self.is_expired() and self.allow_scopes(scopes) + + def is_expired(self): + """ + Check token expiration with timezone awareness + """ + # Token expiration is already checked + return False + + def allow_scopes(self, scopes): + """ + Check if the token allows the provided scopes + + :param scopes: An iterable containing the scopes to check + """ + if not scopes: + return True + + provided_scopes = set(self.scope.split()) + resource_scopes = set(scopes) + + return resource_scopes.issubset(provided_scopes) + + class JWTAuthentication(BaseAuthentication): """ Token based authentication using the JSON Web Token standard. @@ -46,7 +87,7 @@ def authenticate(self, request): self._add_session_details(request, payload) user = self.authenticate_credentials(payload) - return user, payload + return user, JwtToken(payload) def authenticate_credentials(self, payload): """ diff --git a/oauth2_provider_jwt/views.py b/oauth2_provider_jwt/views.py index 464d094..5c7d669 100644 --- a/oauth2_provider_jwt/views.py +++ b/oauth2_provider_jwt/views.py @@ -12,16 +12,20 @@ class TokenView(views.TokenView): - def _get_access_token_jwt(self, request, expires_in): + def _get_access_token_jwt(self, request, content): extra_data = {} issuer = settings.JWT_ISSUER payload_enricher = getattr(settings, 'JWT_PAYLOAD_ENRICHER', None) if payload_enricher: fn = import_string(payload_enricher) extra_data = fn(request) + + if 'scope' in content: + extra_data['scope'] = content['scope'] + if request.POST.get('username'): extra_data['username'] = request.POST.get('username') - payload = generate_payload(issuer, expires_in, **extra_data) + payload = generate_payload(issuer, content['expires_in'], **extra_data) token = encode_jwt(payload) return token @@ -44,7 +48,7 @@ def post(self, request, *args, **kwargs): 'Missing JWT configuration, skipping token build') else: content['access_token_jwt'] = self._get_access_token_jwt( - request, content['expires_in']) + request, content) try: content = bytes(json.dumps(content), 'utf-8') except TypeError: diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 8906a0c..8aa45ac 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -98,3 +98,35 @@ def test_post_valid_jwt_with_auth(self): HTTP_AUTHORIZATION='JWT {}'.format(jwt_value), content_type='application/json') self.assertEqual(response.status_code, 403) + + def test_post_valid_jwt_with_auth_and_scope_not_valid(self): + now = datetime.utcnow() + payload = { + 'iss': 'issuer', + 'exp': now + timedelta(seconds=100), + 'iat': now, + 'username': 'temporary', + 'scope': 'read', # Incorrect scope + } + jwt_value = utils.encode_jwt(payload) + response = self.client.post( + '/jwt_auth_scope/', {'example': 'example'}, + HTTP_AUTHORIZATION='JWT {}'.format(jwt_value), + content_type='application/json') + self.assertEqual(response.status_code, 403) + + def test_post_valid_jwt_with_auth_and_scope_is_valid(self): + now = datetime.utcnow() + payload = { + 'iss': 'issuer', + 'exp': now + timedelta(seconds=100), + 'iat': now, + 'username': 'temporary', + 'scope': 'write', # Correct scope + } + jwt_value = utils.encode_jwt(payload) + response = self.client.post( + '/jwt_auth_scope/', {'example': 'example'}, + HTTP_AUTHORIZATION='JWT {}'.format(jwt_value), + content_type='application/json') + self.assertEqual(response.status_code, 200) diff --git a/tests/test_views.py b/tests/test_views.py index 2359018..212a7ec 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,6 +1,7 @@ import base64 import datetime import json + try: from urllib.parse import urlencode except ImportError: @@ -109,12 +110,14 @@ def test_get_token(self): self.assertEqual(response.status_code, 200) content = json.loads(response.content.decode("utf-8")) + jwt_token = content["access_token_jwt"] self.assertEqual(content["token_type"], "Bearer") - self.assertIn(type(content["access_token_jwt"]).__name__, - ('unicode', 'str')) + self.assertIn(type(jwt_token).__name__, ('unicode', 'str')) self.assertEqual(content["scope"], "read write") self.assertEqual(content["expires_in"], oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS) + self.assertDictContainsSubset({'scope': 'read write'}, + self.decode_jwt(jwt_token)) @patch('oauth2_provider_jwt.views.TokenView._is_jwt_config_set') def test_do_not_get_token_missing_conf(self, mock_is_jwt_config_set): @@ -159,10 +162,26 @@ def test_get_enriched_jwt(self): **auth_headers) content = json.loads(response.content.decode("utf-8")) access_token_jwt = content["access_token_jwt"] - headers, payload, verify_signature = access_token_jwt.split(".") - payload += '=' * (-len(payload) % 4) # add padding - payload_dict = json.loads(base64.b64decode(payload).decode("utf-8")) - self.assertDictContainsSubset({'sub': 'unique-user'}, payload_dict) + self.assertDictContainsSubset({'sub': 'unique-user'}, + self.decode_jwt(access_token_jwt)) + + def test_get_custom_scope_in_jwt(self): + token_request_data = { + "grant_type": "password", + "scope": "read", + "username": "test_user", + "password": "123456", + } + auth_headers = get_basic_auth_header(self.application.client_id, + self.application.client_secret) + + response = self.client.post( + reverse("oauth2_provider_jwt:token"), data=token_request_data, + **auth_headers) + content = json.loads(response.content.decode("utf-8")) + access_token_jwt = content["access_token_jwt"] + self.assertDictContainsSubset({'scope': 'read'}, + self.decode_jwt(access_token_jwt)) def test_refresh_token(self): access_token = AccessToken.objects.create( @@ -190,3 +209,8 @@ def test_refresh_token(self): content = json.loads(response.content.decode("utf-8")) self.assertIn(type(content["access_token_jwt"]).__name__, ('unicode', 'str')) + + def decode_jwt(self, access_token_jwt): + headers, payload, verify_signature = access_token_jwt.split(".") + payload += '=' * (-len(payload) % 4) # add padding + return json.loads(base64.b64decode(payload).decode("utf-8")) diff --git a/tests/urls.py b/tests/urls.py index c0cf6d2..68aa3c6 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -5,6 +5,7 @@ from django.http import HttpResponse from rest_framework import permissions from rest_framework.views import APIView +from oauth2_provider.contrib.rest_framework import TokenHasScope admin.autodiscover() @@ -31,11 +32,24 @@ def post(self, request): return HttpResponse(response) +class MockForAuthScopeView(APIView): + permission_classes = (TokenHasScope,) + required_scopes = ['write'] + + def get(self, _request): + return HttpResponse('mockforauthscopeview-get') + + def post(self, request): + response = json.dumps({"username": request.user.username}) + return HttpResponse(response) + + urlpatterns = [ url(r"^o/", include("oauth2_provider_jwt.urls", namespace="oauth2_provider_jwt")), url(r'^jwt/$', MockView.as_view()), url(r'^jwt_auth/$', MockForAuthView.as_view()), + url(r'^jwt_auth_scope/$', MockForAuthScopeView.as_view()), ]