Skip to content

Commit

Permalink
#654 added oauth token refresh support
Browse files Browse the repository at this point in the history
  • Loading branch information
bugy committed Jun 27, 2023
1 parent f0fc7cf commit 9e3abfe
Show file tree
Hide file tree
Showing 20 changed files with 995 additions and 133 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ or [how to configure the server](https://github.com/bugy/script-server/wiki/Serv

### Server-side

Python 3.6 or higher with the following modules:
Python 3.7 or higher with the following modules:

* Tornado 5 / 6

Expand Down
123 changes: 78 additions & 45 deletions src/auth/auth_abstract_oauth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
import asyncio
import datetime
import json
import logging
import os
Expand All @@ -15,11 +16,12 @@

from auth import auth_base
from auth.auth_base import AuthFailureError, AuthBadRequestException, AuthRejectedError
from auth.oauth_token_manager import OAuthTokenManager
from auth.oauth_token_response import OAuthTokenResponse
from model import model_helper
from model.model_helper import read_bool_from_config, read_int_from_config
from model.server_conf import InvalidServerConfigException
from utils import file_utils
from utils.tornado_utils import get_secure_cookie

LOGGER = logging.getLogger('script_server.AbstractOauthAuthenticator')

Expand Down Expand Up @@ -90,6 +92,12 @@ def __init__(self, oauth_authorize_url, oauth_token_url, oauth_scope, params_dic

self._schedule_dump_task()

self._token_manager = OAuthTokenManager(
enabled=bool(self.auth_info_ttl),
fetch_token_callback=self._fetch_token_by_refresh)

self.ioloop = tornado.ioloop.IOLoop.current()

@staticmethod
def _validate_dump_file(dump_file):
if os.path.isdir(dump_file):
Expand All @@ -105,8 +113,8 @@ async def authenticate(self, request_handler):
LOGGER.error('Code is not specified')
raise AuthBadRequestException('Missing authorization information. Please contact your administrator')

(access_token, refresh_token) = await self.fetch_access_token(code, request_handler)
user_info = await self.fetch_user_info(access_token)
token_response = await self.fetch_access_token_by_code(code, request_handler)
user_info = await self.fetch_user_info(token_response.access_token)

username = user_info.username
if not username:
Expand All @@ -124,12 +132,13 @@ async def authenticate(self, request_handler):
self._users[username] = user_state

if self.group_support:
await self.load_groups(access_token, username, user_info, user_state)
await self.load_groups(token_response.access_token, username, user_info, user_state)

now = time.time()

self._token_manager.update_tokens(token_response, username, request_handler)

if self.auth_info_ttl:
request_handler.set_secure_cookie('token', access_token)
user_state.last_auth_update = now

user_state.last_visit = now
Expand All @@ -144,23 +153,28 @@ async def load_groups(self, access_token, username, user_info, user_state):
user_state.groups = user_groups
LOGGER.info('Loaded groups for ' + username + ': ' + str(user_state.groups))

def validate_user(self, user, request_handler):
async def validate_user(self, user, request_handler):
if not user:
LOGGER.warning('Username is not available')
return False

now = time.time()

user_state = self._users.get(user)
validate_expiration = True
if not user_state:
# if nothing is enabled, it's ok not to have user state (e.g. after server restart)
if self.session_expire <= 0 and not self.auth_info_ttl and not self.group_support:
return True
elif self._token_manager.can_restore_state(request_handler):
validate_expiration = False
user_state = _UserState(user)
self._users[user] = user_state
else:
LOGGER.info('User %s state is missing', user)
return False

if self.session_expire > 0:
if (self.session_expire > 0) and validate_expiration:
last_visit = user_state.last_visit
if (last_visit is None) or ((last_visit + self.session_expire) < now):
LOGGER.info('User %s state is expired', user)
Expand All @@ -169,9 +183,10 @@ def validate_user(self, user, request_handler):
user_state.last_visit = now

if self.auth_info_ttl:
access_token = get_secure_cookie(request_handler, 'token')
access_token = await self._token_manager.synchronize_user_tokens(user, request_handler)
if access_token is None:
LOGGER.info('User %s token is not available', user)
self._remove_user(user)
return False

self.update_user_auth(user, user_state, access_token)
Expand All @@ -186,57 +201,40 @@ def get_groups(self, user, known_groups=None):
return user_state.groups

def logout(self, user, request_handler):
request_handler.clear_cookie('token')
self._token_manager.logout(user, request_handler)
self._remove_user(user)

self._dump_state()

def _remove_user(self, user):
if user in self._users:
del self._users[user]
self._token_manager.remove_user(user)

async def fetch_access_token(self, code, request_handler):
body = urllib_parse.urlencode({
async def fetch_access_token_by_code(self, code, request_handler):
return await self._fetch_token({
'redirect_uri': get_path_for_redirect(request_handler),
'code': code,
'client_id': self.client_id,
'client_secret': self.secret,
'grant_type': 'authorization_code',
})

response = await self.http_client.fetch(
self.oauth_token_url,
method='POST',
headers={'Content-Type': 'application/x-www-form-urlencoded'},
body=body,
raise_error=False)

response_values = {}
if response.body:
response_values = escape.json_decode(response.body)

if response.error:
if response_values.get('error_description'):
error_text = response_values.get('error_description')
elif response_values.get('error'):
error_text = response_values.get('error')
else:
error_text = str(response.error)

error_message = 'Failed to load access_token: ' + error_text
LOGGER.error(error_message)
raise AuthFailureError(error_message)

response_values = escape.json_decode(response.body)
access_token = response_values.get('access_token')
refresh_token = response_values.get('refresh_token')

if not access_token:
message = 'No access token in response: ' + str(response.body)
LOGGER.error(message)
raise AuthFailureError(message)

return access_token, refresh_token
async def _fetch_token_by_refresh(self, refresh_token, username):
if username not in self._users:
return None

try:
return await self._fetch_token({
'refresh_token': refresh_token,
'client_id': self.client_id,
'client_secret': self.secret,
'grant_type': 'refresh_token',
})
except AuthFailureError:
LOGGER.info(f'Failed to refresh token for user {username}. Logging out')
self._remove_user(username)
return None

def update_user_auth(self, username, user_state, access_token):
now = time.time()
Expand All @@ -246,7 +244,7 @@ def update_user_auth(self, username, user_state, access_token):
if not ttl_expired:
return

tornado.ioloop.IOLoop.current().spawn_callback(
self.ioloop.spawn_callback(
self._do_update_user_auth_async,
username,
user_state,
Expand Down Expand Up @@ -342,6 +340,41 @@ def _cleanup(self):
if self.timer:
self.timer.cancel()

async def _fetch_token(self, body):
encoded_body = urllib_parse.urlencode(body)

response = await self.http_client.fetch(
self.oauth_token_url,
method='POST',
headers={'Content-Type': 'application/x-www-form-urlencoded'},
body=encoded_body,
raise_error=False)

response_values = {}
if response.body:
response_values = escape.json_decode(response.body)

if response.error:
if response_values.get('error_description'):
error_text = response_values.get('error_description')
elif response_values.get('error'):
error_text = response_values.get('error')
else:
error_text = str(response.error)

error_message = 'Failed to refresh access_token: ' + error_text
LOGGER.error(error_message)
raise AuthFailureError(error_message)

token_response = OAuthTokenResponse.create(response_values, datetime.datetime.now())

if not token_response.access_token:
message = 'No access token in response: ' + str(response.body)
LOGGER.error(message)
raise AuthFailureError(message)

return token_response


def get_path_for_redirect(request_handler):
referer = request_handler.request.headers.get('Referer')
Expand Down
2 changes: 1 addition & 1 deletion src/auth/auth_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def get_client_visible_config(self):
def get_groups(self, user, known_groups=None):
return []

def validate_user(self, user, request_handler):
async def validate_user(self, user, request_handler):
return True

def perform_basic_auth(self, user, password):
Expand Down
2 changes: 1 addition & 1 deletion src/auth/auth_gitlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, params_dict):
async def fetch_user_info(self, access_token) -> _OauthUserInfo:
user = await self.oauth2_request(
_OAUTH_GITLAB_USERINFO % self.gitlab_host,
access_token)
access_token=access_token)
if user is None:
return None

Expand Down
5 changes: 2 additions & 3 deletions src/auth/identification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
import logging
import uuid

import tornado.websocket

from model.trusted_ips import TrustedIpValidator
from utils import tornado_utils, date_utils, audit_utils
from utils.date_utils import days_to_ms
from utils.tornado_utils import can_write_secure_cookie

LOGGER = logging.getLogger('identification')

Expand Down Expand Up @@ -120,4 +119,4 @@ def _write_client_token(self, client_id, request_handler):
request_handler.set_secure_cookie(self.COOKIE_KEY, new_token, expires_days=self.EXPIRES_DAYS)

def _can_write(self, request_handler):
return not isinstance(request_handler, tornado.websocket.WebSocketHandler)
return can_write_secure_cookie(request_handler)
Loading

0 comments on commit 9e3abfe

Please sign in to comment.