Skip to content

Commit

Permalink
Merge pull request #14 from jamesholcombe/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
jamesholcombe authored Feb 4, 2023
2 parents 1219c44 + 1670a7b commit 6f780b3
Show file tree
Hide file tree
Showing 11 changed files with 381 additions and 108 deletions.
19 changes: 11 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ That's it! You can now define your layout and callbacks as usual.
> **NOTE** This can **ONLY** be done in the context of a dash callback.
```python
...

app.layout = html.Div(
[
html.Div(id="example-output"),
Expand All @@ -51,12 +53,17 @@ Output("example-output", "children"),
Input("example-input", "value")
)
def example_callback(value):
token = (
auth.get_token()
) ##The token can only be retrieved in the context of a dash callback
token = auth.get_token()
##The token can only be retrieved in the context of a dash callback
return token
```

## Refresh Tokens

If your OAuth provider supports refresh tokens, these are automatically checked and handled in the _get_token_ method.

> Check if your OAuth provider requires any additional scopes to support refresh tokens
## Troubleshooting

If you hit 400 responses (bad request) from either endpoint, there are a number of things that might need configuration.
Expand All @@ -67,10 +74,6 @@ Make sure you have checked the following

_The library uses a default redirect URI of http://127.0.0.1:8050/redirect_.

- Check the **key field** for the **token** in the JSON response returned by the token endpoint by your OAuth provider.

_The default is "access_token" but different OAuth providers may use a different key for this._

## Contributing

Contributions, issues, and ideas are all more than welcome
Contributions, issues, and ideas are all more than welcome.
101 changes: 87 additions & 14 deletions dash_auth_external/auth.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,24 @@
from flask import Flask
import flask
from werkzeug.routing import RoutingException, ValidationError
from .routes import make_access_token_route, make_auth_route
from .routes import make_access_token_route, make_auth_route, token_request
from urllib.parse import urljoin
import os
from dash_auth_external.token import OAuth2Token
from dash_auth_external.config import FLASK_SESSION_TOKEN_KEY
from dash_auth_external.exceptions import TokenExpiredError

from flask import session


def _get_token_data_from_session() -> dict:
"""Gets the token data from the session.
Returns:
dict: The token data from the session.
"""
token_data = session.get(FLASK_SESSION_TOKEN_KEY)
if token_data is None:
raise ValueError("No token found in request session.")
return token_data


class DashAuthExternal:
Expand All @@ -17,17 +32,36 @@ def generate_secret_key(length: int = 24) -> str:
return os.urandom(length)

def get_token(self) -> str:
"""Retrieves the access token from flask request headers, using the token cookie given on __init__.
"""Attempts to get a valid access token.
Returns:
str: Bearer Access token from your OAuth2 Provider
"""
token = flask.request.headers.get(self._token_field_name)
if token is None:
raise KeyError(
f"Header with name {self._token_field_name} not found in the flask request headers."

if self.token_data is not None:
if not self.token_data.is_expired():
return self.token_data.access_token

if not self.token_data.refresh_token:
raise TokenExpiredError(
"Token is expired and no refresh token available to refresh token."
)

self.token_data = refresh_token(
self.external_token_url, self.token_data, self.token_request_headers
)
return token
return self.token_data.access_token

token_data = _get_token_data_from_session()

self.token_data = OAuth2Token(
access_token=token_data["access_token"],
refresh_token=token_data.get("refresh_token"),
expires_in=token_data.get("expires_in"),
token_type=token_data.get("token_type"),
)

return self.token_data.access_token

def __init__(
self,
Expand All @@ -39,13 +73,13 @@ def __init__(
redirect_suffix: str = "/redirect",
auth_suffix: str = "/",
home_suffix="/home",
_flask_server: Flask = None,
_token_field_name: str = "access_token",
_secret_key: str = None,
auth_request_headers: dict = None,
token_request_headers: dict = None,
scope: str = None,
_server_name: str = __name__,
_static_folder: str = './assets/'
):
"""The interface for obtaining access tokens from 3rd party OAuth2 Providers.
Expand All @@ -58,18 +92,32 @@ def __init__(
redirect_suffix (str, optional): The route that OAuth2 provider will redirect back to. Defaults to "/redirect".
auth_suffix (str, optional): The route that will trigger the initial redirect to the external OAuth provider. Defaults to "/".
home_suffix (str, optional): The route your dash application will sit, relative to your url. Defaults to "/home".
_flask_server (Flask, optional): Flask server to use if additional config required. Defaults to None.
_token_field_name (str, optional): The key for the token returned in JSON from the token endpoint. Defaults to "access_token".
_secret_key (str, optional): Secret key for flask app, normally generated at runtime. Defaults to None.
auth_request_params (dict, optional): Additional params to send to the authorization endpoint. Defaults to None.
auth_request_headers (dict, optional): Additional params to send to the authorization endpoint. Defaults to None.
token_request_headers (dict, optional): Additional headers to send to the access token endpoint. Defaults to None.
scope (str, optional): Header required by most Oauth2 Providers. Defaults to None.
_server_name (str, optional): The name of the Flask Server. Defaults to __name__, so the name of this library.
_static_folder (str, optional): The folder with static assets. Defaults to "./assets/".
_server_name (str, optional): The name of the Flask Server. Defaults to __name__, ignored if _flask_server is not None.
Returns:
DashAuthExternal: Main package class
"""
app = Flask(_server_name, instance_relative_config=False, static_folder=_static_folder)

self.token_data: OAuth2Token = None
if auth_request_headers is None:
auth_request_headers = {}
if token_request_headers is None:
token_request_headers = {}

if _flask_server is None:

app = Flask(
_server_name, instance_relative_config=False, static_folder="./assets"
)
else:
app = _flask_server

if _secret_key is None:
app.secret_key = self.generate_secret_key()
Expand Down Expand Up @@ -97,10 +145,35 @@ def __init__(
_home_suffix=home_suffix,
token_request_headers=token_request_headers,
_token_field_name=_token_field_name,
with_pkce=with_pkce,
)

self.server = app
self.home_suffix = home_suffix
self.redirect_suffix = redirect_suffix
self.auth_suffix = auth_suffix
self._token_field_name = _token_field_name
self.client_id = client_id
self.external_token_url = external_token_url
self.token_request_headers = token_request_headers
self.scope = scope


def refresh_token(url: str, token_data: OAuth2Token, headers: dict) -> OAuth2Token:

body = {
"grant_type": "refresh_token",
"refresh_token": token_data.refresh_token,
}
data = token_request(url, body, headers)

token_data.access_token = data["access_token"]

# If the provider does not return a new refresh token, use the old one.
if "refresh_token" in data:
token_data.refresh_token = data["refresh_token"]

if "expires_in" in data:
token_data.expires_in = data["expires_in"]

return token_data
1 change: 1 addition & 0 deletions dash_auth_external/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
FLASK_SESSION_TOKEN_KEY = "TokenDataDashAuthExternal"
4 changes: 4 additions & 0 deletions dash_auth_external/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class TokenExpiredError(Exception):
"""Exception raised when an expired token is encountered."""

pass
35 changes: 14 additions & 21 deletions dash_auth_external/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
import requests
import hashlib
from requests_oauthlib import OAuth2Session
from dash_auth_external.config import FLASK_SESSION_TOKEN_KEY

os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"


def make_code_challenge(length: int = 40):
code_verifier = base64.urlsafe_b64encode(
os.urandom(length)).decode("utf-8")
code_verifier = base64.urlsafe_b64encode(os.urandom(length)).decode("utf-8")
code_verifier = re.sub("[^a-zA-Z0-9]+", "", code_verifier)
code_challenge = hashlib.sha256(code_verifier.encode("utf-8")).digest()
code_challenge = base64.urlsafe_b64encode(code_challenge).decode("utf-8")
Expand All @@ -27,9 +27,9 @@ def make_auth_route(
client_id: str,
auth_suffix: str,
redirect_uri: str,
with_pkce: bool = True,
scope: str = None,
auth_request_params: dict = None,
with_pkce: bool,
scope: str,
auth_request_params: dict,
):
@app.route(auth_suffix)
def get_auth_code():
Expand Down Expand Up @@ -64,9 +64,7 @@ def get_auth_code():
return app


def build_token_body(
url: str, redirect_uri: str, client_id: str, with_pkce: bool, client_secret: str
):
def build_token_body(url: str, redirect_uri: str, client_id: str, with_pkce: bool):
query = urllib.parse.urlparse(url).query
redirect_params = urllib.parse.parse_qs(query)
code = redirect_params["code"][0]
Expand Down Expand Up @@ -94,27 +92,27 @@ def make_access_token_route(
redirect_uri: str,
client_id: str,
_token_field_name: str,
with_pkce: bool = True,
token_request_headers: dict = None,
with_pkce: bool,
token_request_headers: dict,
):
@app.route(redirect_suffix, methods=["GET", "POST"])
def get_token():
def get_token_route():
url = request.url
body = build_token_body(
url=url,
redirect_uri=redirect_uri,
with_pkce=with_pkce,
client_id=client_id,

)

response_data = get_token_response_data(
external_token_url, body, token_request_headers
response_data = token_request(
url=external_token_url,
body=body,
headers=token_request_headers,
)
token = response_data[_token_field_name]

response = redirect(_home_suffix)
response.headers.add(_token_field_name, token)
session[FLASK_SESSION_TOKEN_KEY] = response_data
return response

return app
Expand All @@ -127,9 +125,4 @@ def token_request(url: str, body: dict, headers: dict):
raise requests.RequestException(
f"{r.status_code} {r.reason}:The request to the access token endpoint failed."
)
return r


def get_token_response_data(*args):
r = token_request(*args)
return r.json()
16 changes: 16 additions & 0 deletions dash_auth_external/token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from dataclasses import dataclass
import time


@dataclass
class OAuth2Token:
access_token: str
token_type: str
expires_in: int
refresh_token: str = None

def __post_init__(self):
self.expires_at = time.time() + self.expires_in

def is_expired(self):
return time.time() > self.expires_at if self.expires_at else False
4 changes: 2 additions & 2 deletions examples/usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@

app = Dash(__name__, server=server) # instantiating our app using this server

##Below we can define our dash app like normal
# Below we can define our dash app like normal
app.layout = html.Div([html.Div(id="example-output"), dcc.Input(id="example-input")])


@app.callback(Output("example-output", "children"), Input("example-input", "value"))
def example_callback(value):
token = (
auth.get_token()
) ##The token can only be retrieved in the context of a dash callback
) # The token can only be retrieved in the context of a dash callback
return token


Expand Down
10 changes: 4 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@
NAME = "dash-auth-external"

this_directory = Path(__file__).parent
long_description = "Integrate your dashboards with 3rd party APIs and external OAuth providers."
requires = [
"dash >= 2.0.0",
"requests >= 1.0.0",
"requests-oauthlib >= 0.3.0"
]
long_description = (
"Integrate your dashboards with 3rd party APIs and external OAuth providers."
)
requires = ["dash >= 2.0.0", "requests >= 1.0.0", "requests-oauthlib >= 0.3.0"]

setup(
name=NAME,
Expand Down
Loading

0 comments on commit 6f780b3

Please sign in to comment.