diff --git a/README.md b/README.md index d71190a..f8da722 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,8 @@ That's it! You can now define your layout and callbacks as usual. > **NOTE** This can **ONLY** be done in the context of a dash callback. ```python +... + app.layout = html.Div( [ html.Div(id="example-output"), @@ -51,12 +53,17 @@ Output("example-output", "children"), Input("example-input", "value") ) def example_callback(value): - token = ( - auth.get_token() - ) ##The token can only be retrieved in the context of a dash callback + token = auth.get_token() + ##The token can only be retrieved in the context of a dash callback return token ``` +## Refresh Tokens + +If your OAuth provider supports refresh tokens, these are automatically checked and handled in the _get_token_ method. + +> Check if your OAuth provider requires any additional scopes to support refresh tokens + ## Troubleshooting If you hit 400 responses (bad request) from either endpoint, there are a number of things that might need configuration. @@ -67,10 +74,6 @@ Make sure you have checked the following _The library uses a default redirect URI of http://127.0.0.1:8050/redirect_. -- Check the **key field** for the **token** in the JSON response returned by the token endpoint by your OAuth provider. - -_The default is "access_token" but different OAuth providers may use a different key for this._ - ## Contributing -Contributions, issues, and ideas are all more than welcome +Contributions, issues, and ideas are all more than welcome. diff --git a/dash_auth_external/auth.py b/dash_auth_external/auth.py index 4c55494..a28bee3 100644 --- a/dash_auth_external/auth.py +++ b/dash_auth_external/auth.py @@ -1,9 +1,24 @@ from flask import Flask -import flask -from werkzeug.routing import RoutingException, ValidationError -from .routes import make_access_token_route, make_auth_route +from .routes import make_access_token_route, make_auth_route, token_request from urllib.parse import urljoin import os +from dash_auth_external.token import OAuth2Token +from dash_auth_external.config import FLASK_SESSION_TOKEN_KEY +from dash_auth_external.exceptions import TokenExpiredError + +from flask import session + + +def _get_token_data_from_session() -> dict: + """Gets the token data from the session. + + Returns: + dict: The token data from the session. + """ + token_data = session.get(FLASK_SESSION_TOKEN_KEY) + if token_data is None: + raise ValueError("No token found in request session.") + return token_data class DashAuthExternal: @@ -17,17 +32,36 @@ def generate_secret_key(length: int = 24) -> str: return os.urandom(length) def get_token(self) -> str: - """Retrieves the access token from flask request headers, using the token cookie given on __init__. + """Attempts to get a valid access token. Returns: str: Bearer Access token from your OAuth2 Provider """ - token = flask.request.headers.get(self._token_field_name) - if token is None: - raise KeyError( - f"Header with name {self._token_field_name} not found in the flask request headers." + + if self.token_data is not None: + if not self.token_data.is_expired(): + return self.token_data.access_token + + if not self.token_data.refresh_token: + raise TokenExpiredError( + "Token is expired and no refresh token available to refresh token." + ) + + self.token_data = refresh_token( + self.external_token_url, self.token_data, self.token_request_headers ) - return token + return self.token_data.access_token + + token_data = _get_token_data_from_session() + + self.token_data = OAuth2Token( + access_token=token_data["access_token"], + refresh_token=token_data.get("refresh_token"), + expires_in=token_data.get("expires_in"), + token_type=token_data.get("token_type"), + ) + + return self.token_data.access_token def __init__( self, @@ -39,13 +73,13 @@ def __init__( redirect_suffix: str = "/redirect", auth_suffix: str = "/", home_suffix="/home", + _flask_server: Flask = None, _token_field_name: str = "access_token", _secret_key: str = None, auth_request_headers: dict = None, token_request_headers: dict = None, scope: str = None, _server_name: str = __name__, - _static_folder: str = './assets/' ): """The interface for obtaining access tokens from 3rd party OAuth2 Providers. @@ -58,18 +92,32 @@ def __init__( redirect_suffix (str, optional): The route that OAuth2 provider will redirect back to. Defaults to "/redirect". auth_suffix (str, optional): The route that will trigger the initial redirect to the external OAuth provider. Defaults to "/". home_suffix (str, optional): The route your dash application will sit, relative to your url. Defaults to "/home". + _flask_server (Flask, optional): Flask server to use if additional config required. Defaults to None. _token_field_name (str, optional): The key for the token returned in JSON from the token endpoint. Defaults to "access_token". _secret_key (str, optional): Secret key for flask app, normally generated at runtime. Defaults to None. - auth_request_params (dict, optional): Additional params to send to the authorization endpoint. Defaults to None. + auth_request_headers (dict, optional): Additional params to send to the authorization endpoint. Defaults to None. token_request_headers (dict, optional): Additional headers to send to the access token endpoint. Defaults to None. scope (str, optional): Header required by most Oauth2 Providers. Defaults to None. - _server_name (str, optional): The name of the Flask Server. Defaults to __name__, so the name of this library. - _static_folder (str, optional): The folder with static assets. Defaults to "./assets/". + _server_name (str, optional): The name of the Flask Server. Defaults to __name__, ignored if _flask_server is not None. + Returns: DashAuthExternal: Main package class """ - app = Flask(_server_name, instance_relative_config=False, static_folder=_static_folder) + + self.token_data: OAuth2Token = None + if auth_request_headers is None: + auth_request_headers = {} + if token_request_headers is None: + token_request_headers = {} + + if _flask_server is None: + + app = Flask( + _server_name, instance_relative_config=False, static_folder="./assets" + ) + else: + app = _flask_server if _secret_key is None: app.secret_key = self.generate_secret_key() @@ -97,6 +145,7 @@ def __init__( _home_suffix=home_suffix, token_request_headers=token_request_headers, _token_field_name=_token_field_name, + with_pkce=with_pkce, ) self.server = app @@ -104,3 +153,27 @@ def __init__( self.redirect_suffix = redirect_suffix self.auth_suffix = auth_suffix self._token_field_name = _token_field_name + self.client_id = client_id + self.external_token_url = external_token_url + self.token_request_headers = token_request_headers + self.scope = scope + + +def refresh_token(url: str, token_data: OAuth2Token, headers: dict) -> OAuth2Token: + + body = { + "grant_type": "refresh_token", + "refresh_token": token_data.refresh_token, + } + data = token_request(url, body, headers) + + token_data.access_token = data["access_token"] + + # If the provider does not return a new refresh token, use the old one. + if "refresh_token" in data: + token_data.refresh_token = data["refresh_token"] + + if "expires_in" in data: + token_data.expires_in = data["expires_in"] + + return token_data diff --git a/dash_auth_external/config.py b/dash_auth_external/config.py new file mode 100644 index 0000000..30cf568 --- /dev/null +++ b/dash_auth_external/config.py @@ -0,0 +1 @@ +FLASK_SESSION_TOKEN_KEY = "TokenDataDashAuthExternal" diff --git a/dash_auth_external/exceptions.py b/dash_auth_external/exceptions.py new file mode 100644 index 0000000..26066bb --- /dev/null +++ b/dash_auth_external/exceptions.py @@ -0,0 +1,4 @@ +class TokenExpiredError(Exception): + """Exception raised when an expired token is encountered.""" + + pass diff --git a/dash_auth_external/routes.py b/dash_auth_external/routes.py index 3187ece..e706500 100644 --- a/dash_auth_external/routes.py +++ b/dash_auth_external/routes.py @@ -7,13 +7,13 @@ import requests import hashlib from requests_oauthlib import OAuth2Session +from dash_auth_external.config import FLASK_SESSION_TOKEN_KEY os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1" def make_code_challenge(length: int = 40): - code_verifier = base64.urlsafe_b64encode( - os.urandom(length)).decode("utf-8") + code_verifier = base64.urlsafe_b64encode(os.urandom(length)).decode("utf-8") code_verifier = re.sub("[^a-zA-Z0-9]+", "", code_verifier) code_challenge = hashlib.sha256(code_verifier.encode("utf-8")).digest() code_challenge = base64.urlsafe_b64encode(code_challenge).decode("utf-8") @@ -27,9 +27,9 @@ def make_auth_route( client_id: str, auth_suffix: str, redirect_uri: str, - with_pkce: bool = True, - scope: str = None, - auth_request_params: dict = None, + with_pkce: bool, + scope: str, + auth_request_params: dict, ): @app.route(auth_suffix) def get_auth_code(): @@ -64,9 +64,7 @@ def get_auth_code(): return app -def build_token_body( - url: str, redirect_uri: str, client_id: str, with_pkce: bool, client_secret: str -): +def build_token_body(url: str, redirect_uri: str, client_id: str, with_pkce: bool): query = urllib.parse.urlparse(url).query redirect_params = urllib.parse.parse_qs(query) code = redirect_params["code"][0] @@ -94,27 +92,27 @@ def make_access_token_route( redirect_uri: str, client_id: str, _token_field_name: str, - with_pkce: bool = True, - token_request_headers: dict = None, + with_pkce: bool, + token_request_headers: dict, ): @app.route(redirect_suffix, methods=["GET", "POST"]) - def get_token(): + def get_token_route(): url = request.url body = build_token_body( url=url, redirect_uri=redirect_uri, with_pkce=with_pkce, client_id=client_id, - ) - response_data = get_token_response_data( - external_token_url, body, token_request_headers + response_data = token_request( + url=external_token_url, + body=body, + headers=token_request_headers, ) - token = response_data[_token_field_name] response = redirect(_home_suffix) - response.headers.add(_token_field_name, token) + session[FLASK_SESSION_TOKEN_KEY] = response_data return response return app @@ -127,9 +125,4 @@ def token_request(url: str, body: dict, headers: dict): raise requests.RequestException( f"{r.status_code} {r.reason}:The request to the access token endpoint failed." ) - return r - - -def get_token_response_data(*args): - r = token_request(*args) return r.json() diff --git a/dash_auth_external/token.py b/dash_auth_external/token.py new file mode 100644 index 0000000..9968882 --- /dev/null +++ b/dash_auth_external/token.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass +import time + + +@dataclass +class OAuth2Token: + access_token: str + token_type: str + expires_in: int + refresh_token: str = None + + def __post_init__(self): + self.expires_at = time.time() + self.expires_in + + def is_expired(self): + return time.time() > self.expires_at if self.expires_at else False diff --git a/examples/usage.py b/examples/usage.py index f91ee37..1dd7e73 100644 --- a/examples/usage.py +++ b/examples/usage.py @@ -15,7 +15,7 @@ app = Dash(__name__, server=server) # instantiating our app using this server -##Below we can define our dash app like normal +# Below we can define our dash app like normal app.layout = html.Div([html.Div(id="example-output"), dcc.Input(id="example-input")]) @@ -23,7 +23,7 @@ def example_callback(value): token = ( auth.get_token() - ) ##The token can only be retrieved in the context of a dash callback + ) # The token can only be retrieved in the context of a dash callback return token diff --git a/setup.py b/setup.py index f2652c0..ea2432c 100644 --- a/setup.py +++ b/setup.py @@ -5,12 +5,10 @@ NAME = "dash-auth-external" this_directory = Path(__file__).parent -long_description = "Integrate your dashboards with 3rd party APIs and external OAuth providers." -requires = [ -"dash >= 2.0.0", -"requests >= 1.0.0", -"requests-oauthlib >= 0.3.0" -] +long_description = ( + "Integrate your dashboards with 3rd party APIs and external OAuth providers." +) +requires = ["dash >= 2.0.0", "requests >= 1.0.0", "requests-oauthlib >= 0.3.0"] setup( name=NAME, diff --git a/tests/test_app.py b/tests/test_app.py index 2a4d6cf..ad84a91 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,29 +1,118 @@ -import requests +"""Module for integation tests, testing full OAuth2 flow, excluding the get_token method +""" + from dash_auth_external import DashAuthExternal -import pytest -from unittest.mock import Mock, patch -import unittest -from flask import request +from unittest.mock import Mock +from dash_auth_external.config import FLASK_SESSION_TOKEN_KEY from .test_config import EXERNAL_TOKEN_URL, EXTERNAL_AUTH_URL, CLIENT_ID - -"""Module for integation tests -""" +from pytest_mock import mocker +from requests_oauthlib import OAuth2Session -@patch("dash_auth_external.routes.build_token_body") -@patch("dash_auth_external.routes.get_token_response_data") -def test_get_token(mock_post, mock_body): - auth = DashAuthExternal(EXTERNAL_AUTH_URL, EXERNAL_TOKEN_URL, CLIENT_ID) +def test_pkce_true(mocker): + auth = DashAuthExternal( + EXTERNAL_AUTH_URL, EXERNAL_TOKEN_URL, CLIENT_ID, with_pkce=True + ) app = auth.server - mock_post.return_value = {auth._token_field_name: "ey.asdfasdfasfd"} - mock_body.return_value = dict() + redirect_uri = "http://localhost:8050" + session_mock = Mock( + OAuth2Session(CLIENT_ID, redirect_uri=redirect_uri, scope=auth.scope) + ) + + session_mock.authorization_url.return_value = ("https://example.com", "state") + + mocker.patch("dash_auth_external.routes.OAuth2Session", return_value=session_mock) + + mocker.patch( + "dash_auth_external.routes.make_code_challenge", + return_value=("code_challenge", "code_verifier"), + ) + # mocking the two helper functions called within the view function for the redirect to home suffix. with app.test_client() as client: - response = client.get(auth.redirect_suffix) + + response = client.get(auth.auth_suffix) + assert response.status_code == 302 + + # assert that the authorization_url method was called with the correct arguments + session_mock.authorization_url.assert_called_with( + EXTERNAL_AUTH_URL, + code_challenge="code_challenge", + code_challenge_method="S256", + ) + + # user logs in and is redirected to the redirect_uri, with a code and state + + # mocking the token request in the token route + mocker.patch( + "dash_auth_external.routes.token_request", + return_value={ + "access_token": "access_token", + "refresh_token": "refresh_token", + "token_type": "Bearer", + "expires_in": "3599", + }, + ) + + # now we call the token route with the code and state returned from the authorization_url method + response = client.get( + auth.redirect_suffix, query_string={"code": "code", "state": "state"} + ) + assert response.status_code == 302 + with client.session_transaction() as session: + token_data = session[FLASK_SESSION_TOKEN_KEY] + assert token_data["access_token"] == "access_token" + assert token_data["refresh_token"] == "refresh_token" + assert token_data["token_type"] == "Bearer" + assert token_data["expires_in"] == "3599" + + +def test_pkce_false(mocker): + auth = DashAuthExternal( + EXTERNAL_AUTH_URL, EXERNAL_TOKEN_URL, CLIENT_ID, with_pkce=False + ) + app = auth.server + + redirect_uri = "http://localhost:8050" + session_mock = Mock( + OAuth2Session(CLIENT_ID, redirect_uri=redirect_uri, scope=auth.scope) + ) + + session_mock.authorization_url.return_value = ("https://example.com", "state") + + mocker.patch("dash_auth_external.routes.OAuth2Session", return_value=session_mock) + + with app.test_client() as client: + + response = client.get(auth.auth_suffix) assert response.status_code == 302 - mock_post.assert_called_once() - mock_body.assert_called_once() - assert auth._token_field_name in response.headers + # assert that the authorization_url method was called with the correct arguments + session_mock.authorization_url.assert_called_with(EXTERNAL_AUTH_URL) + + # user logs in and is redirected to the redirect_uri, with a code and state + + # mocking the token request in the token route + mocker.patch( + "dash_auth_external.routes.token_request", + return_value={ + "access_token": "access_token", + "refresh_token": "refresh_token", + "token_type": "Bearer", + "expires_in": "3599", + }, + ) + + # now we call the token route with the code and state returned from the authorization_url method + response = client.get( + auth.redirect_suffix, query_string={"code": "code", "state": "state"} + ) + assert response.status_code == 302 + with client.session_transaction() as session: + token_data = session[FLASK_SESSION_TOKEN_KEY] + assert token_data["access_token"] == "access_token" + assert token_data["refresh_token"] == "refresh_token" + assert token_data["token_type"] == "Bearer" + assert token_data["expires_in"] == "3599" diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..cf752e9 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,135 @@ +from dash_auth_external import DashAuthExternal +from unittest.mock import Mock +from dash_auth_external.exceptions import TokenExpiredError +from dash_auth_external.token import OAuth2Token +from .test_config import EXERNAL_TOKEN_URL, EXTERNAL_AUTH_URL, CLIENT_ID +from pytest_mock import mocker +from requests_oauthlib import OAuth2Session +from dash import Dash, Input, Output, html, dcc +import pytest + + +@pytest.fixture() +def dash_app_and_auth(): + auth = DashAuthExternal(EXTERNAL_AUTH_URL, EXERNAL_TOKEN_URL, CLIENT_ID) + + app = Dash(__name__, server=auth.server) + + app.layout = html.Div([html.Div(id="test-output"), dcc.Input(id="test-input")]) + return app, auth + + +@pytest.fixture() +def access_token_data_with_refresh(): + return { + "access_token": "access_token", + "refresh_token": "refresh_token", + "token_type": "Bearer", + "expires_in": 3599, + } + + +@pytest.fixture() +def expired_access_token_data_without_refresh(): + return { + "access_token": "access_token", + "token_type": "Bearer", + "expires_in": -1, + } + + +@pytest.fixture() +def expired_access_token_data_with_refresh(expired_access_token_data_without_refresh): + return { + **expired_access_token_data_without_refresh, + "refresh_token": "refresh_token", + } + + +def test_get_token_first_call( + dash_app_and_auth, mocker, access_token_data_with_refresh +): + dash_app, auth = dash_app_and_auth + + mocker.patch( + "dash_auth_external.auth._get_token_data_from_session", + return_value=access_token_data_with_refresh, + ) + + @dash_app.callback(Output("test-output", "children"), Input("test-input", "value")) + def test_callback(value): + token = auth.get_token() + return token + + assert auth.token_data is None + test_callback("test") + assert isinstance(auth.token_data, OAuth2Token) + + +def test_get_token_second_call( + dash_app_and_auth, mocker, access_token_data_with_refresh +): + dash_app, auth = dash_app_and_auth + + mocker.patch( + "dash_auth_external.auth._get_token_data_from_session", + return_value=access_token_data_with_refresh, + ) + + @dash_app.callback(Output("test-output", "children"), Input("test-input", "value")) + def test_callback(value): + token = auth.get_token() + return token + + assert auth.token_data is None + test_callback("test") + assert isinstance(auth.token_data, OAuth2Token) + test_callback("test") + assert isinstance(auth.token_data, OAuth2Token) + + +def test_get_token_with_refresh( + dash_app_and_auth, + mocker, + expired_access_token_data_with_refresh, + access_token_data_with_refresh, +): + dash_app, auth = dash_app_and_auth + + refresh_mock = mocker.patch( + "dash_auth_external.auth.refresh_token", + return_value=OAuth2Token(**access_token_data_with_refresh), + ) + + @dash_app.callback(Output("test-output", "children"), Input("test-input", "value")) + def test_callback(value): + token = auth.get_token() + return token + + auth.token_data = OAuth2Token(**expired_access_token_data_with_refresh) + + test_callback("test") + assert isinstance(auth.token_data, OAuth2Token) + assert auth.token_data.expires_in > 0 + refresh_mock.assert_called_once() + + +def test_expired_token_raises_exception( + dash_app_and_auth, mocker, expired_access_token_data_without_refresh +): + dash_app, auth = dash_app_and_auth + + mocker.patch( + "dash_auth_external.auth._get_token_data_from_session", + return_value=expired_access_token_data_without_refresh, + ) + + @dash_app.callback(Output("test-output", "children"), Input("test-input", "value")) + def test_callback(value): + token = auth.get_token() + return token + + auth.token_data = OAuth2Token(**expired_access_token_data_without_refresh) + + with pytest.raises(TokenExpiredError): + test_callback("test") diff --git a/tests/test_routes.py b/tests/test_routes.py deleted file mode 100644 index 910d74d..0000000 --- a/tests/test_routes.py +++ /dev/null @@ -1,39 +0,0 @@ -from .test_context import dash_auth_external -from unittest import mock -import pytest -import requests -from werkzeug.wrappers import request -from dash_auth_external.auth import DashAuthExternal -from dash_auth_external.routes import token_request -import flask - -"""Module for unit tests -""" - - -@mock.patch("dash_auth_external.routes.requests.post") -def test_token_route_ok(mock_post): - - mock_post.return_value.status_code = 200 - response = token_request("Fakeurl", dict(), dict()) - assert response.status_code == 200 - - -@mock.patch("dash_auth_external.routes.requests.post") -def test_token_route_raises(mock_post): - mock_post.return_value.status_code = 400 - - with pytest.raises(requests.RequestException): - response = token_request("Fakeurl", dict(), dict()) - - -def test_auth_route_ok(): - pass - - -def test_auth_route_raises(): - pass - - -def test_make_token_body(): - pass