From ec562671b18d6674aa944cdb90f522876056a724 Mon Sep 17 00:00:00 2001 From: Yoann Moranville Date: Wed, 11 Jan 2023 15:38:11 +0100 Subject: [PATCH 1/8] Fix problem when not using pkce --- dash_auth_external/auth.py | 5 +++-- dash_auth_external/routes.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/dash_auth_external/auth.py b/dash_auth_external/auth.py index d7b1500..28b3115 100644 --- a/dash_auth_external/auth.py +++ b/dash_auth_external/auth.py @@ -41,7 +41,7 @@ def __init__( home_suffix="/home", _token_field_name: str = "access_token", _secret_key: str = None, - auth_request_headers: dict = None, + auth_request_headers: dict = {}, token_request_headers: dict = None, scope: str = None, ): @@ -58,7 +58,7 @@ def __init__( home_suffix (str, optional): The route your dash application will sit, relative to your url. Defaults to "/home". _token_field_name (str, optional): The key for the token returned in JSON from the token endpoint. Defaults to "access_token". _secret_key (str, optional): Secret key for flask app, normally generated at runtime. Defaults to None. - auth_request_params (dict, optional): Additional params to send to the authorization endpoint. Defaults to None. + auth_request_headers (dict, optional): Additional params to send to the authorization endpoint. Defaults to {}. token_request_headers (dict, optional): Additional headers to send to the access token endpoint. Defaults to None. scope (str, optional): Header required by most Oauth2 Providers. Defaults to None. @@ -93,6 +93,7 @@ def __init__( _home_suffix=home_suffix, token_request_headers=token_request_headers, _token_field_name=_token_field_name, + with_pkce=with_pkce ) self.server = app diff --git a/dash_auth_external/routes.py b/dash_auth_external/routes.py index 3187ece..91a31a7 100644 --- a/dash_auth_external/routes.py +++ b/dash_auth_external/routes.py @@ -65,7 +65,7 @@ def get_auth_code(): def build_token_body( - url: str, redirect_uri: str, client_id: str, with_pkce: bool, client_secret: str + url: str, redirect_uri: str, client_id: str, with_pkce: bool ): query = urllib.parse.urlparse(url).query redirect_params = urllib.parse.parse_qs(query) From af5ac8fb66980af8a2ed9ac172d4802e238073cc Mon Sep 17 00:00:00 2001 From: Yoann Moranville Date: Thu, 12 Jan 2023 09:11:25 +0100 Subject: [PATCH 2/8] Revert a default to None, never use a function with mutable default arguments --- dash_auth_external/auth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dash_auth_external/auth.py b/dash_auth_external/auth.py index 28b3115..d6227c8 100644 --- a/dash_auth_external/auth.py +++ b/dash_auth_external/auth.py @@ -41,7 +41,7 @@ def __init__( home_suffix="/home", _token_field_name: str = "access_token", _secret_key: str = None, - auth_request_headers: dict = {}, + auth_request_headers: dict = None, token_request_headers: dict = None, scope: str = None, ): @@ -58,7 +58,7 @@ def __init__( home_suffix (str, optional): The route your dash application will sit, relative to your url. Defaults to "/home". _token_field_name (str, optional): The key for the token returned in JSON from the token endpoint. Defaults to "access_token". _secret_key (str, optional): Secret key for flask app, normally generated at runtime. Defaults to None. - auth_request_headers (dict, optional): Additional params to send to the authorization endpoint. Defaults to {}. + auth_request_headers (dict, optional): Additional params to send to the authorization endpoint. Defaults to None. token_request_headers (dict, optional): Additional headers to send to the access token endpoint. Defaults to None. scope (str, optional): Header required by most Oauth2 Providers. Defaults to None. From 8ec609e91b3d017b2edbb012b49b5fda774ef2f5 Mon Sep 17 00:00:00 2001 From: James Holcombe Date: Sun, 15 Jan 2023 13:45:02 +0000 Subject: [PATCH 3/8] add optional kwargs, ensure optional args only at top level methods --- dash_auth_external/auth.py | 13 ++++++++++--- dash_auth_external/routes.py | 18 +++++++----------- setup.py | 10 ++++------ 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/dash_auth_external/auth.py b/dash_auth_external/auth.py index fd0ae5e..42f6ac8 100644 --- a/dash_auth_external/auth.py +++ b/dash_auth_external/auth.py @@ -45,7 +45,8 @@ def __init__( token_request_headers: dict = None, scope: str = None, _server_name: str = __name__, - _static_folder: str = './assets/' + _static_folder: str = "./assets/", + **kwargs: dict, ): """The interface for obtaining access tokens from 3rd party OAuth2 Providers. @@ -65,11 +66,17 @@ def __init__( scope (str, optional): Header required by most Oauth2 Providers. Defaults to None. _server_name (str, optional): The name of the Flask Server. Defaults to __name__, so the name of this library. _static_folder (str, optional): The folder with static assets. Defaults to "./assets/". + **kwargs: Additional keyword arguments to pass to the Flask server. See Flask documentation for more information. Returns: DashAuthExternal: Main package class """ - app = Flask(_server_name, instance_relative_config=False, static_folder=_static_folder) + app = Flask( + _server_name, + instance_relative_config=False, + static_folder=_static_folder, + **kwargs, + ) if _secret_key is None: app.secret_key = self.generate_secret_key() @@ -97,7 +104,7 @@ def __init__( _home_suffix=home_suffix, token_request_headers=token_request_headers, _token_field_name=_token_field_name, - with_pkce=with_pkce + with_pkce=with_pkce, ) self.server = app diff --git a/dash_auth_external/routes.py b/dash_auth_external/routes.py index 91a31a7..e346ee8 100644 --- a/dash_auth_external/routes.py +++ b/dash_auth_external/routes.py @@ -12,8 +12,7 @@ def make_code_challenge(length: int = 40): - code_verifier = base64.urlsafe_b64encode( - os.urandom(length)).decode("utf-8") + code_verifier = base64.urlsafe_b64encode(os.urandom(length)).decode("utf-8") code_verifier = re.sub("[^a-zA-Z0-9]+", "", code_verifier) code_challenge = hashlib.sha256(code_verifier.encode("utf-8")).digest() code_challenge = base64.urlsafe_b64encode(code_challenge).decode("utf-8") @@ -27,9 +26,9 @@ def make_auth_route( client_id: str, auth_suffix: str, redirect_uri: str, - with_pkce: bool = True, - scope: str = None, - auth_request_params: dict = None, + with_pkce: bool, + scope: str, + auth_request_params: dict, ): @app.route(auth_suffix) def get_auth_code(): @@ -64,9 +63,7 @@ def get_auth_code(): return app -def build_token_body( - url: str, redirect_uri: str, client_id: str, with_pkce: bool -): +def build_token_body(url: str, redirect_uri: str, client_id: str, with_pkce: bool): query = urllib.parse.urlparse(url).query redirect_params = urllib.parse.parse_qs(query) code = redirect_params["code"][0] @@ -94,8 +91,8 @@ def make_access_token_route( redirect_uri: str, client_id: str, _token_field_name: str, - with_pkce: bool = True, - token_request_headers: dict = None, + with_pkce: bool, + token_request_headers: dict, ): @app.route(redirect_suffix, methods=["GET", "POST"]) def get_token(): @@ -105,7 +102,6 @@ def get_token(): redirect_uri=redirect_uri, with_pkce=with_pkce, client_id=client_id, - ) response_data = get_token_response_data( diff --git a/setup.py b/setup.py index f2652c0..ea2432c 100644 --- a/setup.py +++ b/setup.py @@ -5,12 +5,10 @@ NAME = "dash-auth-external" this_directory = Path(__file__).parent -long_description = "Integrate your dashboards with 3rd party APIs and external OAuth providers." -requires = [ -"dash >= 2.0.0", -"requests >= 1.0.0", -"requests-oauthlib >= 0.3.0" -] +long_description = ( + "Integrate your dashboards with 3rd party APIs and external OAuth providers." +) +requires = ["dash >= 2.0.0", "requests >= 1.0.0", "requests-oauthlib >= 0.3.0"] setup( name=NAME, From b6469682576a89fc9adbbf083398957421c57696 Mon Sep 17 00:00:00 2001 From: James Holcombe Date: Sun, 15 Jan 2023 18:39:47 +0000 Subject: [PATCH 4/8] add refresh token --- dash_auth_external/auth.py | 45 +++++++++++++++++++++++++++----- dash_auth_external/config.py | 1 + dash_auth_external/exceptions.py | 4 +++ dash_auth_external/routes.py | 35 ++++++++++++++++++++----- dash_auth_external/token.py | 16 ++++++++++++ tests/test_app.py | 2 ++ 6 files changed, 89 insertions(+), 14 deletions(-) create mode 100644 dash_auth_external/config.py create mode 100644 dash_auth_external/exceptions.py create mode 100644 dash_auth_external/token.py diff --git a/dash_auth_external/auth.py b/dash_auth_external/auth.py index 42f6ac8..48de7ba 100644 --- a/dash_auth_external/auth.py +++ b/dash_auth_external/auth.py @@ -1,9 +1,13 @@ from flask import Flask import flask from werkzeug.routing import RoutingException, ValidationError -from .routes import make_access_token_route, make_auth_route +from .routes import make_access_token_route, make_auth_route, refresh_token from urllib.parse import urljoin import os +from dash_auth_external.token import OAuth2Token +from dash_auth_external.config import FLASK_HEADER_TOKEN_KEY +from dash_auth_external.exceptions import TokenExpiredError +import json class DashAuthExternal: @@ -17,17 +21,38 @@ def generate_secret_key(length: int = 24) -> str: return os.urandom(length) def get_token(self) -> str: - """Retrieves the access token from flask request headers, using the token cookie given on __init__. + """Attempts to get a valid access token. Returns: str: Bearer Access token from your OAuth2 Provider """ - token = flask.request.headers.get(self._token_field_name) - if token is None: - raise KeyError( - f"Header with name {self._token_field_name} not found in the flask request headers." + + if self.token_data is not None: + if not self.token_data.is_expired(): + return self.token_data.access_token + + if not self.token_data.refresh_token: + raise TokenExpiredError( + "Token is expired and no refresh token available to refresh token." + ) + + self.token_data = refresh_token( + self.external_token_url, self.token_data, self.token_request_headers ) - return token + return self.token_data.access_token + + token_data = flask.request.headers.get(FLASK_HEADER_TOKEN_KEY) + token_data = json.loads(token_data) + if token_data is None: + raise ValueError("No token found in request headers.") + self.token_data = OAuth2Token( + access_token=token_data["access_token"], + refresh_token=token_data.get("refresh_token"), + expires_in=token_data.get("expires_in"), + token_type=token_data.get("token_type"), + ) + + return self.token_data.access_token def __init__( self, @@ -71,6 +96,9 @@ def __init__( Returns: DashAuthExternal: Main package class """ + + self.token_data: OAuth2Token = None + app = Flask( _server_name, instance_relative_config=False, @@ -112,3 +140,6 @@ def __init__( self.redirect_suffix = redirect_suffix self.auth_suffix = auth_suffix self._token_field_name = _token_field_name + self.client_id = client_id + self.external_token_url = external_token_url + self.token_request_headers = token_request_headers diff --git a/dash_auth_external/config.py b/dash_auth_external/config.py new file mode 100644 index 0000000..9869d88 --- /dev/null +++ b/dash_auth_external/config.py @@ -0,0 +1 @@ +FLASK_HEADER_TOKEN_KEY = "TokenDataDashAuthExternal" diff --git a/dash_auth_external/exceptions.py b/dash_auth_external/exceptions.py new file mode 100644 index 0000000..26066bb --- /dev/null +++ b/dash_auth_external/exceptions.py @@ -0,0 +1,4 @@ +class TokenExpiredError(Exception): + """Exception raised when an expired token is encountered.""" + + pass diff --git a/dash_auth_external/routes.py b/dash_auth_external/routes.py index e346ee8..e3feef1 100644 --- a/dash_auth_external/routes.py +++ b/dash_auth_external/routes.py @@ -8,6 +8,9 @@ import hashlib from requests_oauthlib import OAuth2Session +from dash_auth_external.config import FLASK_HEADER_TOKEN_KEY +from dash_auth_external.token import OAuth2Token + os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1" @@ -104,13 +107,15 @@ def get_token(): client_id=client_id, ) - response_data = get_token_response_data( - external_token_url, body, token_request_headers - ) + response_data = token_request( + url=external_token_url, + body=body, + headers=token_request_headers, + ).json() token = response_data[_token_field_name] response = redirect(_home_suffix) - response.headers.add(_token_field_name, token) + response.headers.add(FLASK_HEADER_TOKEN_KEY, token) return response return app @@ -126,6 +131,22 @@ def token_request(url: str, body: dict, headers: dict): return r -def get_token_response_data(*args): - r = token_request(*args) - return r.json() +def refresh_token(url: str, token_data: OAuth2Token, headers: dict) -> OAuth2Token: + + body = { + "grant_type": "refresh_token", + "refresh_token": token_data.refresh_token, + } + r = token_request(url, body, headers) + r.raise_for_status() + data = r.json() + token_data.access_token = data["access_token"] + + # If the provider does not return a new refresh token, use the old one. + if "refresh_token" in data: + token_data.refresh_token = data["refresh_token"] + + if "expires_in" in data: + token_data.expires_in = data["expires_in"] + + return token_data diff --git a/dash_auth_external/token.py b/dash_auth_external/token.py new file mode 100644 index 0000000..d3e30bc --- /dev/null +++ b/dash_auth_external/token.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass +import time + + +@dataclass +class OAuth2Token: + access_token: str + refresh_token: str + token_type: str + expires_in: int + + def __post_init__(self): + self.expires_at = time.time() + self.expires_in + + def is_expired(self): + return time.time() > self.expires_at if self.expires_at else False diff --git a/tests/test_app.py b/tests/test_app.py index 2a4d6cf..67475b0 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -5,6 +5,7 @@ import unittest from flask import request from .test_config import EXERNAL_TOKEN_URL, EXTERNAL_AUTH_URL, CLIENT_ID +from pytest_mock import mocker """Module for integation tests """ @@ -18,6 +19,7 @@ def test_get_token(mock_post, mock_body): mock_post.return_value = {auth._token_field_name: "ey.asdfasdfasfd"} mock_body.return_value = dict() + # mocking the two helper functions called within the view function for the redirect to home suffix. with app.test_client() as client: From 6b8e558e98a18ab597b853d3d89f46551a2b7f36 Mon Sep 17 00:00:00 2001 From: James Holcombe Date: Sun, 15 Jan 2023 18:39:47 +0000 Subject: [PATCH 5/8] add refresh token --- dash_auth_external/auth.py | 45 +++++++++++++++++++++++++++----- dash_auth_external/config.py | 1 + dash_auth_external/exceptions.py | 4 +++ dash_auth_external/routes.py | 40 ++++++++++++++++++++++------ dash_auth_external/token.py | 16 ++++++++++++ tests/test_app.py | 12 +++------ 6 files changed, 95 insertions(+), 23 deletions(-) create mode 100644 dash_auth_external/config.py create mode 100644 dash_auth_external/exceptions.py create mode 100644 dash_auth_external/token.py diff --git a/dash_auth_external/auth.py b/dash_auth_external/auth.py index 42f6ac8..48de7ba 100644 --- a/dash_auth_external/auth.py +++ b/dash_auth_external/auth.py @@ -1,9 +1,13 @@ from flask import Flask import flask from werkzeug.routing import RoutingException, ValidationError -from .routes import make_access_token_route, make_auth_route +from .routes import make_access_token_route, make_auth_route, refresh_token from urllib.parse import urljoin import os +from dash_auth_external.token import OAuth2Token +from dash_auth_external.config import FLASK_HEADER_TOKEN_KEY +from dash_auth_external.exceptions import TokenExpiredError +import json class DashAuthExternal: @@ -17,17 +21,38 @@ def generate_secret_key(length: int = 24) -> str: return os.urandom(length) def get_token(self) -> str: - """Retrieves the access token from flask request headers, using the token cookie given on __init__. + """Attempts to get a valid access token. Returns: str: Bearer Access token from your OAuth2 Provider """ - token = flask.request.headers.get(self._token_field_name) - if token is None: - raise KeyError( - f"Header with name {self._token_field_name} not found in the flask request headers." + + if self.token_data is not None: + if not self.token_data.is_expired(): + return self.token_data.access_token + + if not self.token_data.refresh_token: + raise TokenExpiredError( + "Token is expired and no refresh token available to refresh token." + ) + + self.token_data = refresh_token( + self.external_token_url, self.token_data, self.token_request_headers ) - return token + return self.token_data.access_token + + token_data = flask.request.headers.get(FLASK_HEADER_TOKEN_KEY) + token_data = json.loads(token_data) + if token_data is None: + raise ValueError("No token found in request headers.") + self.token_data = OAuth2Token( + access_token=token_data["access_token"], + refresh_token=token_data.get("refresh_token"), + expires_in=token_data.get("expires_in"), + token_type=token_data.get("token_type"), + ) + + return self.token_data.access_token def __init__( self, @@ -71,6 +96,9 @@ def __init__( Returns: DashAuthExternal: Main package class """ + + self.token_data: OAuth2Token = None + app = Flask( _server_name, instance_relative_config=False, @@ -112,3 +140,6 @@ def __init__( self.redirect_suffix = redirect_suffix self.auth_suffix = auth_suffix self._token_field_name = _token_field_name + self.client_id = client_id + self.external_token_url = external_token_url + self.token_request_headers = token_request_headers diff --git a/dash_auth_external/config.py b/dash_auth_external/config.py new file mode 100644 index 0000000..9869d88 --- /dev/null +++ b/dash_auth_external/config.py @@ -0,0 +1 @@ +FLASK_HEADER_TOKEN_KEY = "TokenDataDashAuthExternal" diff --git a/dash_auth_external/exceptions.py b/dash_auth_external/exceptions.py new file mode 100644 index 0000000..26066bb --- /dev/null +++ b/dash_auth_external/exceptions.py @@ -0,0 +1,4 @@ +class TokenExpiredError(Exception): + """Exception raised when an expired token is encountered.""" + + pass diff --git a/dash_auth_external/routes.py b/dash_auth_external/routes.py index e346ee8..abfba3a 100644 --- a/dash_auth_external/routes.py +++ b/dash_auth_external/routes.py @@ -8,11 +8,15 @@ import hashlib from requests_oauthlib import OAuth2Session +from dash_auth_external.config import FLASK_HEADER_TOKEN_KEY +from dash_auth_external.token import OAuth2Token + os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1" def make_code_challenge(length: int = 40): - code_verifier = base64.urlsafe_b64encode(os.urandom(length)).decode("utf-8") + code_verifier = base64.urlsafe_b64encode( + os.urandom(length)).decode("utf-8") code_verifier = re.sub("[^a-zA-Z0-9]+", "", code_verifier) code_challenge = hashlib.sha256(code_verifier.encode("utf-8")).digest() code_challenge = base64.urlsafe_b64encode(code_challenge).decode("utf-8") @@ -42,6 +46,8 @@ def get_auth_code(): scope=scope, ) + print("with_pkce", with_pkce) + if with_pkce: code_challenge, code_verifier = make_code_challenge() session["cv"] = code_verifier @@ -104,13 +110,15 @@ def get_token(): client_id=client_id, ) - response_data = get_token_response_data( - external_token_url, body, token_request_headers - ) + response_data = token_request( + url=external_token_url, + body=body, + headers=token_request_headers, + ).json() token = response_data[_token_field_name] response = redirect(_home_suffix) - response.headers.add(_token_field_name, token) + response.headers.add(FLASK_HEADER_TOKEN_KEY, token) return response return app @@ -126,6 +134,22 @@ def token_request(url: str, body: dict, headers: dict): return r -def get_token_response_data(*args): - r = token_request(*args) - return r.json() +def refresh_token(url: str, token_data: OAuth2Token, headers: dict) -> OAuth2Token: + + body = { + "grant_type": "refresh_token", + "refresh_token": token_data.refresh_token, + } + r = token_request(url, body, headers) + r.raise_for_status() + data = r.json() + token_data.access_token = data["access_token"] + + # If the provider does not return a new refresh token, use the old one. + if "refresh_token" in data: + token_data.refresh_token = data["refresh_token"] + + if "expires_in" in data: + token_data.expires_in = data["expires_in"] + + return token_data diff --git a/dash_auth_external/token.py b/dash_auth_external/token.py new file mode 100644 index 0000000..d3e30bc --- /dev/null +++ b/dash_auth_external/token.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass +import time + + +@dataclass +class OAuth2Token: + access_token: str + refresh_token: str + token_type: str + expires_in: int + + def __post_init__(self): + self.expires_at = time.time() + self.expires_in + + def is_expired(self): + return time.time() > self.expires_at if self.expires_at else False diff --git a/tests/test_app.py b/tests/test_app.py index 2a4d6cf..5c11eb7 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -5,25 +5,21 @@ import unittest from flask import request from .test_config import EXERNAL_TOKEN_URL, EXTERNAL_AUTH_URL, CLIENT_ID +from pytest_mock import mocker """Module for integation tests """ -@patch("dash_auth_external.routes.build_token_body") -@patch("dash_auth_external.routes.get_token_response_data") -def test_get_token(mock_post, mock_body): - auth = DashAuthExternal(EXTERNAL_AUTH_URL, EXERNAL_TOKEN_URL, CLIENT_ID) +def test_pkce_true(mocker): + auth = DashAuthExternal( + EXTERNAL_AUTH_URL, EXERNAL_TOKEN_URL, CLIENT_ID, with_pkce=True) app = auth.server - mock_post.return_value = {auth._token_field_name: "ey.asdfasdfasfd"} - mock_body.return_value = dict() # mocking the two helper functions called within the view function for the redirect to home suffix. with app.test_client() as client: response = client.get(auth.redirect_suffix) assert response.status_code == 302 - mock_post.assert_called_once() - mock_body.assert_called_once() assert auth._token_field_name in response.headers From 0f269def4beb4572d26af8823a7bde93ca4fe81b Mon Sep 17 00:00:00 2001 From: James Holcombe Date: Sat, 4 Feb 2023 18:37:53 +0000 Subject: [PATCH 6/8] more extensive testing --- README.md | 19 ++--- dash_auth_external/auth.py | 74 +++++++++++++------ dash_auth_external/config.py | 2 +- dash_auth_external/routes.py | 35 ++------- dash_auth_external/token.py | 2 +- examples/usage.py | 4 +- tests/test_app.py | 115 ++++++++++++++++++++++++++--- tests/test_auth.py | 136 +++++++++++++++++++++++++++++++++++ tests/test_routes.py | 39 ---------- 9 files changed, 314 insertions(+), 112 deletions(-) create mode 100644 tests/test_auth.py delete mode 100644 tests/test_routes.py diff --git a/README.md b/README.md index d71190a..f8da722 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,8 @@ That's it! You can now define your layout and callbacks as usual. > **NOTE** This can **ONLY** be done in the context of a dash callback. ```python +... + app.layout = html.Div( [ html.Div(id="example-output"), @@ -51,12 +53,17 @@ Output("example-output", "children"), Input("example-input", "value") ) def example_callback(value): - token = ( - auth.get_token() - ) ##The token can only be retrieved in the context of a dash callback + token = auth.get_token() + ##The token can only be retrieved in the context of a dash callback return token ``` +## Refresh Tokens + +If your OAuth provider supports refresh tokens, these are automatically checked and handled in the _get_token_ method. + +> Check if your OAuth provider requires any additional scopes to support refresh tokens + ## Troubleshooting If you hit 400 responses (bad request) from either endpoint, there are a number of things that might need configuration. @@ -67,10 +74,6 @@ Make sure you have checked the following _The library uses a default redirect URI of http://127.0.0.1:8050/redirect_. -- Check the **key field** for the **token** in the JSON response returned by the token endpoint by your OAuth provider. - -_The default is "access_token" but different OAuth providers may use a different key for this._ - ## Contributing -Contributions, issues, and ideas are all more than welcome +Contributions, issues, and ideas are all more than welcome. diff --git a/dash_auth_external/auth.py b/dash_auth_external/auth.py index 48de7ba..a28bee3 100644 --- a/dash_auth_external/auth.py +++ b/dash_auth_external/auth.py @@ -1,13 +1,24 @@ from flask import Flask -import flask -from werkzeug.routing import RoutingException, ValidationError -from .routes import make_access_token_route, make_auth_route, refresh_token +from .routes import make_access_token_route, make_auth_route, token_request from urllib.parse import urljoin import os from dash_auth_external.token import OAuth2Token -from dash_auth_external.config import FLASK_HEADER_TOKEN_KEY +from dash_auth_external.config import FLASK_SESSION_TOKEN_KEY from dash_auth_external.exceptions import TokenExpiredError -import json + +from flask import session + + +def _get_token_data_from_session() -> dict: + """Gets the token data from the session. + + Returns: + dict: The token data from the session. + """ + token_data = session.get(FLASK_SESSION_TOKEN_KEY) + if token_data is None: + raise ValueError("No token found in request session.") + return token_data class DashAuthExternal: @@ -41,10 +52,8 @@ def get_token(self) -> str: ) return self.token_data.access_token - token_data = flask.request.headers.get(FLASK_HEADER_TOKEN_KEY) - token_data = json.loads(token_data) - if token_data is None: - raise ValueError("No token found in request headers.") + token_data = _get_token_data_from_session() + self.token_data = OAuth2Token( access_token=token_data["access_token"], refresh_token=token_data.get("refresh_token"), @@ -64,14 +73,13 @@ def __init__( redirect_suffix: str = "/redirect", auth_suffix: str = "/", home_suffix="/home", + _flask_server: Flask = None, _token_field_name: str = "access_token", _secret_key: str = None, auth_request_headers: dict = None, token_request_headers: dict = None, scope: str = None, _server_name: str = __name__, - _static_folder: str = "./assets/", - **kwargs: dict, ): """The interface for obtaining access tokens from 3rd party OAuth2 Providers. @@ -84,27 +92,32 @@ def __init__( redirect_suffix (str, optional): The route that OAuth2 provider will redirect back to. Defaults to "/redirect". auth_suffix (str, optional): The route that will trigger the initial redirect to the external OAuth provider. Defaults to "/". home_suffix (str, optional): The route your dash application will sit, relative to your url. Defaults to "/home". + _flask_server (Flask, optional): Flask server to use if additional config required. Defaults to None. _token_field_name (str, optional): The key for the token returned in JSON from the token endpoint. Defaults to "access_token". _secret_key (str, optional): Secret key for flask app, normally generated at runtime. Defaults to None. auth_request_headers (dict, optional): Additional params to send to the authorization endpoint. Defaults to None. token_request_headers (dict, optional): Additional headers to send to the access token endpoint. Defaults to None. scope (str, optional): Header required by most Oauth2 Providers. Defaults to None. - _server_name (str, optional): The name of the Flask Server. Defaults to __name__, so the name of this library. - _static_folder (str, optional): The folder with static assets. Defaults to "./assets/". - **kwargs: Additional keyword arguments to pass to the Flask server. See Flask documentation for more information. + _server_name (str, optional): The name of the Flask Server. Defaults to __name__, ignored if _flask_server is not None. + Returns: DashAuthExternal: Main package class """ self.token_data: OAuth2Token = None + if auth_request_headers is None: + auth_request_headers = {} + if token_request_headers is None: + token_request_headers = {} - app = Flask( - _server_name, - instance_relative_config=False, - static_folder=_static_folder, - **kwargs, - ) + if _flask_server is None: + + app = Flask( + _server_name, instance_relative_config=False, static_folder="./assets" + ) + else: + app = _flask_server if _secret_key is None: app.secret_key = self.generate_secret_key() @@ -143,3 +156,24 @@ def __init__( self.client_id = client_id self.external_token_url = external_token_url self.token_request_headers = token_request_headers + self.scope = scope + + +def refresh_token(url: str, token_data: OAuth2Token, headers: dict) -> OAuth2Token: + + body = { + "grant_type": "refresh_token", + "refresh_token": token_data.refresh_token, + } + data = token_request(url, body, headers) + + token_data.access_token = data["access_token"] + + # If the provider does not return a new refresh token, use the old one. + if "refresh_token" in data: + token_data.refresh_token = data["refresh_token"] + + if "expires_in" in data: + token_data.expires_in = data["expires_in"] + + return token_data diff --git a/dash_auth_external/config.py b/dash_auth_external/config.py index 9869d88..30cf568 100644 --- a/dash_auth_external/config.py +++ b/dash_auth_external/config.py @@ -1 +1 @@ -FLASK_HEADER_TOKEN_KEY = "TokenDataDashAuthExternal" +FLASK_SESSION_TOKEN_KEY = "TokenDataDashAuthExternal" diff --git a/dash_auth_external/routes.py b/dash_auth_external/routes.py index abfba3a..c103e76 100644 --- a/dash_auth_external/routes.py +++ b/dash_auth_external/routes.py @@ -7,9 +7,8 @@ import requests import hashlib from requests_oauthlib import OAuth2Session +from dash_auth_external.config import FLASK_SESSION_TOKEN_KEY -from dash_auth_external.config import FLASK_HEADER_TOKEN_KEY -from dash_auth_external.token import OAuth2Token os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1" @@ -46,8 +45,6 @@ def get_auth_code(): scope=scope, ) - print("with_pkce", with_pkce) - if with_pkce: code_challenge, code_verifier = make_code_challenge() session["cv"] = code_verifier @@ -101,7 +98,7 @@ def make_access_token_route( token_request_headers: dict, ): @app.route(redirect_suffix, methods=["GET", "POST"]) - def get_token(): + def get_token_route(): url = request.url body = build_token_body( url=url, @@ -114,11 +111,10 @@ def get_token(): url=external_token_url, body=body, headers=token_request_headers, - ).json() - token = response_data[_token_field_name] + ) response = redirect(_home_suffix) - response.headers.add(FLASK_HEADER_TOKEN_KEY, token) + session[FLASK_SESSION_TOKEN_KEY] = response_data return response return app @@ -131,25 +127,4 @@ def token_request(url: str, body: dict, headers: dict): raise requests.RequestException( f"{r.status_code} {r.reason}:The request to the access token endpoint failed." ) - return r - - -def refresh_token(url: str, token_data: OAuth2Token, headers: dict) -> OAuth2Token: - - body = { - "grant_type": "refresh_token", - "refresh_token": token_data.refresh_token, - } - r = token_request(url, body, headers) - r.raise_for_status() - data = r.json() - token_data.access_token = data["access_token"] - - # If the provider does not return a new refresh token, use the old one. - if "refresh_token" in data: - token_data.refresh_token = data["refresh_token"] - - if "expires_in" in data: - token_data.expires_in = data["expires_in"] - - return token_data + return r.json() diff --git a/dash_auth_external/token.py b/dash_auth_external/token.py index d3e30bc..9968882 100644 --- a/dash_auth_external/token.py +++ b/dash_auth_external/token.py @@ -5,9 +5,9 @@ @dataclass class OAuth2Token: access_token: str - refresh_token: str token_type: str expires_in: int + refresh_token: str = None def __post_init__(self): self.expires_at = time.time() + self.expires_in diff --git a/examples/usage.py b/examples/usage.py index f91ee37..1dd7e73 100644 --- a/examples/usage.py +++ b/examples/usage.py @@ -15,7 +15,7 @@ app = Dash(__name__, server=server) # instantiating our app using this server -##Below we can define our dash app like normal +# Below we can define our dash app like normal app.layout = html.Div([html.Div(id="example-output"), dcc.Input(id="example-input")]) @@ -23,7 +23,7 @@ def example_callback(value): token = ( auth.get_token() - ) ##The token can only be retrieved in the context of a dash callback + ) # The token can only be retrieved in the context of a dash callback return token diff --git a/tests/test_app.py b/tests/test_app.py index 5c11eb7..3e85475 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,25 +1,118 @@ -import requests +"""Module for integation tests, testing full OAuth2 flow, excluding the get_token method +""" + from dash_auth_external import DashAuthExternal -import pytest -from unittest.mock import Mock, patch -import unittest -from flask import request +from unittest.mock import Mock +from dash_auth_external.config import FLASK_HEADER_TOKEN_KEY from .test_config import EXERNAL_TOKEN_URL, EXTERNAL_AUTH_URL, CLIENT_ID from pytest_mock import mocker - -"""Module for integation tests -""" +from requests_oauthlib import OAuth2Session def test_pkce_true(mocker): auth = DashAuthExternal( - EXTERNAL_AUTH_URL, EXERNAL_TOKEN_URL, CLIENT_ID, with_pkce=True) + EXTERNAL_AUTH_URL, EXERNAL_TOKEN_URL, CLIENT_ID, with_pkce=True + ) app = auth.server + redirect_uri = "http://localhost:8050" + session_mock = Mock( + OAuth2Session(CLIENT_ID, redirect_uri=redirect_uri, scope=auth.scope) + ) + + session_mock.authorization_url.return_value = ("https://example.com", "state") + + mocker.patch("dash_auth_external.routes.OAuth2Session", return_value=session_mock) + + mocker.patch( + "dash_auth_external.routes.make_code_challenge", + return_value=("code_challenge", "code_verifier"), + ) + # mocking the two helper functions called within the view function for the redirect to home suffix. with app.test_client() as client: - response = client.get(auth.redirect_suffix) + + response = client.get(auth.auth_suffix) assert response.status_code == 302 - assert auth._token_field_name in response.headers + # assert that the authorization_url method was called with the correct arguments + session_mock.authorization_url.assert_called_with( + EXTERNAL_AUTH_URL, + code_challenge="code_challenge", + code_challenge_method="S256", + ) + + # user logs in and is redirected to the redirect_uri, with a code and state + + # mocking the token request in the token route + mocker.patch( + "dash_auth_external.routes.token_request", + return_value={ + "access_token": "access_token", + "refresh_token": "refresh_token", + "token_type": "Bearer", + "expires_in": "3599", + }, + ) + + # now we call the token route with the code and state returned from the authorization_url method + response = client.get( + auth.redirect_suffix, query_string={"code": "code", "state": "state"} + ) + assert response.status_code == 302 + with client.session_transaction() as session: + token_data = session[FLASK_HEADER_TOKEN_KEY] + assert token_data["access_token"] == "access_token" + assert token_data["refresh_token"] == "refresh_token" + assert token_data["token_type"] == "Bearer" + assert token_data["expires_in"] == "3599" + + +def test_pkce_false(mocker): + auth = DashAuthExternal( + EXTERNAL_AUTH_URL, EXERNAL_TOKEN_URL, CLIENT_ID, with_pkce=False + ) + app = auth.server + + redirect_uri = "http://localhost:8050" + session_mock = Mock( + OAuth2Session(CLIENT_ID, redirect_uri=redirect_uri, scope=auth.scope) + ) + + session_mock.authorization_url.return_value = ("https://example.com", "state") + + mocker.patch("dash_auth_external.routes.OAuth2Session", return_value=session_mock) + + with app.test_client() as client: + + response = client.get(auth.auth_suffix) + assert response.status_code == 302 + + # assert that the authorization_url method was called with the correct arguments + session_mock.authorization_url.assert_called_with(EXTERNAL_AUTH_URL) + + # user logs in and is redirected to the redirect_uri, with a code and state + + # mocking the token request in the token route + mocker.patch( + "dash_auth_external.routes.token_request", + return_value={ + "access_token": "access_token", + "refresh_token": "refresh_token", + "token_type": "Bearer", + "expires_in": "3599", + }, + ) + + # now we call the token route with the code and state returned from the authorization_url method + response = client.get( + auth.redirect_suffix, query_string={"code": "code", "state": "state"} + ) + assert response.status_code == 302 + with client.session_transaction() as session: + token_data = session[FLASK_HEADER_TOKEN_KEY] + assert token_data["access_token"] == "access_token" + assert token_data["refresh_token"] == "refresh_token" + assert token_data["token_type"] == "Bearer" + assert token_data["expires_in"] == "3599" diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..5711d1a --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,136 @@ +from dash_auth_external import DashAuthExternal +from unittest.mock import Mock +from dash_auth_external.config import FLASK_HEADER_TOKEN_KEY +from dash_auth_external.exceptions import TokenExpiredError +from dash_auth_external.token import OAuth2Token +from .test_config import EXERNAL_TOKEN_URL, EXTERNAL_AUTH_URL, CLIENT_ID +from pytest_mock import mocker +from requests_oauthlib import OAuth2Session +from dash import Dash, Input, Output, html, dcc +import pytest + + +@pytest.fixture() +def dash_app_and_auth(): + auth = DashAuthExternal(EXTERNAL_AUTH_URL, EXERNAL_TOKEN_URL, CLIENT_ID) + + app = Dash(__name__, server=auth.server) + + app.layout = html.Div([html.Div(id="test-output"), dcc.Input(id="test-input")]) + return app, auth + + +@pytest.fixture() +def access_token_data_with_refresh(): + return { + "access_token": "access_token", + "refresh_token": "refresh_token", + "token_type": "Bearer", + "expires_in": 3599, + } + + +@pytest.fixture() +def expired_access_token_data_without_refresh(): + return { + "access_token": "access_token", + "token_type": "Bearer", + "expires_in": -1, + } + + +@pytest.fixture() +def expired_access_token_data_with_refresh(expired_access_token_data_without_refresh): + return { + **expired_access_token_data_without_refresh, + "refresh_token": "refresh_token", + } + + +def test_get_token_first_call( + dash_app_and_auth, mocker, access_token_data_with_refresh +): + dash_app, auth = dash_app_and_auth + + mocker.patch( + "dash_auth_external.auth._get_token_data_from_session", + return_value=access_token_data_with_refresh, + ) + + @dash_app.callback(Output("test-output", "children"), Input("test-input", "value")) + def test_callback(value): + token = auth.get_token() + return token + + assert auth.token_data is None + test_callback("test") + assert isinstance(auth.token_data, OAuth2Token) + + +def test_get_token_second_call( + dash_app_and_auth, mocker, access_token_data_with_refresh +): + dash_app, auth = dash_app_and_auth + + mocker.patch( + "dash_auth_external.auth._get_token_data_from_session", + return_value=access_token_data_with_refresh, + ) + + @dash_app.callback(Output("test-output", "children"), Input("test-input", "value")) + def test_callback(value): + token = auth.get_token() + return token + + assert auth.token_data is None + test_callback("test") + assert isinstance(auth.token_data, OAuth2Token) + test_callback("test") + assert isinstance(auth.token_data, OAuth2Token) + + +def test_get_token_with_refresh( + dash_app_and_auth, + mocker, + expired_access_token_data_with_refresh, + access_token_data_with_refresh, +): + dash_app, auth = dash_app_and_auth + + refresh_mock = mocker.patch( + "dash_auth_external.auth.refresh_token", + return_value=OAuth2Token(**access_token_data_with_refresh), + ) + + @dash_app.callback(Output("test-output", "children"), Input("test-input", "value")) + def test_callback(value): + token = auth.get_token() + return token + + auth.token_data = OAuth2Token(**expired_access_token_data_with_refresh) + + test_callback("test") + assert isinstance(auth.token_data, OAuth2Token) + assert auth.token_data.expires_in > 0 + refresh_mock.assert_called_once() + + +def test_expired_token_raises_exception( + dash_app_and_auth, mocker, expired_access_token_data_without_refresh +): + dash_app, auth = dash_app_and_auth + + mocker.patch( + "dash_auth_external.auth._get_token_data_from_session", + return_value=expired_access_token_data_without_refresh, + ) + + @dash_app.callback(Output("test-output", "children"), Input("test-input", "value")) + def test_callback(value): + token = auth.get_token() + return token + + auth.token_data = OAuth2Token(**expired_access_token_data_without_refresh) + + with pytest.raises(TokenExpiredError): + test_callback("test") diff --git a/tests/test_routes.py b/tests/test_routes.py deleted file mode 100644 index 910d74d..0000000 --- a/tests/test_routes.py +++ /dev/null @@ -1,39 +0,0 @@ -from .test_context import dash_auth_external -from unittest import mock -import pytest -import requests -from werkzeug.wrappers import request -from dash_auth_external.auth import DashAuthExternal -from dash_auth_external.routes import token_request -import flask - -"""Module for unit tests -""" - - -@mock.patch("dash_auth_external.routes.requests.post") -def test_token_route_ok(mock_post): - - mock_post.return_value.status_code = 200 - response = token_request("Fakeurl", dict(), dict()) - assert response.status_code == 200 - - -@mock.patch("dash_auth_external.routes.requests.post") -def test_token_route_raises(mock_post): - mock_post.return_value.status_code = 400 - - with pytest.raises(requests.RequestException): - response = token_request("Fakeurl", dict(), dict()) - - -def test_auth_route_ok(): - pass - - -def test_auth_route_raises(): - pass - - -def test_make_token_body(): - pass From 3646ac6af552c2aa6531c81112e3a61563f70883 Mon Sep 17 00:00:00 2001 From: James Holcombe Date: Sat, 4 Feb 2023 18:49:22 +0000 Subject: [PATCH 7/8] fix merge conflicts --- dash_auth_external/routes.py | 4 ---- tests/test_auth.py | 4 ++-- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/dash_auth_external/routes.py b/dash_auth_external/routes.py index b27ad89..533c6fb 100644 --- a/dash_auth_external/routes.py +++ b/dash_auth_external/routes.py @@ -9,10 +9,6 @@ from requests_oauthlib import OAuth2Session from dash_auth_external.config import FLASK_SESSION_TOKEN_KEY - -from dash_auth_external.config import FLASK_HEADER_TOKEN_KEY -from dash_auth_external.token import OAuth2Token - os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1" diff --git a/tests/test_auth.py b/tests/test_auth.py index 5711d1a..c903421 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,6 +1,5 @@ from dash_auth_external import DashAuthExternal from unittest.mock import Mock -from dash_auth_external.config import FLASK_HEADER_TOKEN_KEY from dash_auth_external.exceptions import TokenExpiredError from dash_auth_external.token import OAuth2Token from .test_config import EXERNAL_TOKEN_URL, EXTERNAL_AUTH_URL, CLIENT_ID @@ -16,7 +15,8 @@ def dash_app_and_auth(): app = Dash(__name__, server=auth.server) - app.layout = html.Div([html.Div(id="test-output"), dcc.Input(id="test-input")]) + app.layout = html.Div( + [html.Div(id="test-output"), dcc.Input(id="test-input")]) return app, auth From 1670a7b05d6f62a9a6a9510035f5ec2d2d2c8804 Mon Sep 17 00:00:00 2001 From: James Holcombe Date: Sat, 4 Feb 2023 18:49:39 +0000 Subject: [PATCH 8/8] black formatter --- dash_auth_external/routes.py | 3 +-- tests/test_app.py | 18 ++++++------------ tests/test_auth.py | 3 +-- 3 files changed, 8 insertions(+), 16 deletions(-) diff --git a/dash_auth_external/routes.py b/dash_auth_external/routes.py index 533c6fb..e706500 100644 --- a/dash_auth_external/routes.py +++ b/dash_auth_external/routes.py @@ -13,8 +13,7 @@ def make_code_challenge(length: int = 40): - code_verifier = base64.urlsafe_b64encode( - os.urandom(length)).decode("utf-8") + code_verifier = base64.urlsafe_b64encode(os.urandom(length)).decode("utf-8") code_verifier = re.sub("[^a-zA-Z0-9]+", "", code_verifier) code_challenge = hashlib.sha256(code_verifier.encode("utf-8")).digest() code_challenge = base64.urlsafe_b64encode(code_challenge).decode("utf-8") diff --git a/tests/test_app.py b/tests/test_app.py index 4513d67..ad84a91 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -20,11 +20,9 @@ def test_pkce_true(mocker): OAuth2Session(CLIENT_ID, redirect_uri=redirect_uri, scope=auth.scope) ) - session_mock.authorization_url.return_value = ( - "https://example.com", "state") + session_mock.authorization_url.return_value = ("https://example.com", "state") - mocker.patch("dash_auth_external.routes.OAuth2Session", - return_value=session_mock) + mocker.patch("dash_auth_external.routes.OAuth2Session", return_value=session_mock) mocker.patch( "dash_auth_external.routes.make_code_challenge", @@ -60,8 +58,7 @@ def test_pkce_true(mocker): # now we call the token route with the code and state returned from the authorization_url method response = client.get( - auth.redirect_suffix, query_string={ - "code": "code", "state": "state"} + auth.redirect_suffix, query_string={"code": "code", "state": "state"} ) assert response.status_code == 302 with client.session_transaction() as session: @@ -83,11 +80,9 @@ def test_pkce_false(mocker): OAuth2Session(CLIENT_ID, redirect_uri=redirect_uri, scope=auth.scope) ) - session_mock.authorization_url.return_value = ( - "https://example.com", "state") + session_mock.authorization_url.return_value = ("https://example.com", "state") - mocker.patch("dash_auth_external.routes.OAuth2Session", - return_value=session_mock) + mocker.patch("dash_auth_external.routes.OAuth2Session", return_value=session_mock) with app.test_client() as client: @@ -112,8 +107,7 @@ def test_pkce_false(mocker): # now we call the token route with the code and state returned from the authorization_url method response = client.get( - auth.redirect_suffix, query_string={ - "code": "code", "state": "state"} + auth.redirect_suffix, query_string={"code": "code", "state": "state"} ) assert response.status_code == 302 with client.session_transaction() as session: diff --git a/tests/test_auth.py b/tests/test_auth.py index c903421..cf752e9 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -15,8 +15,7 @@ def dash_app_and_auth(): app = Dash(__name__, server=auth.server) - app.layout = html.Div( - [html.Div(id="test-output"), dcc.Input(id="test-input")]) + app.layout = html.Div([html.Div(id="test-output"), dcc.Input(id="test-input")]) return app, auth