Skip to content

Commit

Permalink
Merge pull request #17 from jamesholcombe/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
jamesholcombe authored Apr 6, 2023
2 parents 6d9e09e + 98eacdb commit 16757ce
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 57 deletions.
36 changes: 16 additions & 20 deletions dash_auth_external/auth.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import asdict
from flask import Flask
from .routes import make_access_token_route, make_auth_route, token_request
from urllib.parse import urljoin
Expand All @@ -21,6 +22,10 @@ def _get_token_data_from_session() -> dict:
return token_data


def _set_token_data_in_session(token: OAuth2Token):
session[FLASK_SESSION_TOKEN_KEY] = token


class DashAuthExternal:
@staticmethod
def generate_secret_key(length: int = 24) -> str:
Expand All @@ -37,31 +42,23 @@ def get_token(self) -> str:
Returns:
str: Bearer Access token from your OAuth2 Provider
"""
token_data = _get_token_data_from_session()

if self.token_data is not None:
if not self.token_data.is_expired():
return self.token_data.access_token
token = OAuth2Token(**token_data)

if not self.token_data.refresh_token:
raise TokenExpiredError(
"Token is expired and no refresh token available to refresh token."
)
if not token.is_expired():
return token.access_token

self.token_data = refresh_token(
self.external_token_url, self.token_data, self.token_request_headers
if not token.refresh_token:
raise TokenExpiredError(
"Token is expired and no refresh token available to refresh token."
)
return self.token_data.access_token

token_data = _get_token_data_from_session()

self.token_data = OAuth2Token(
access_token=token_data["access_token"],
refresh_token=token_data.get("refresh_token"),
expires_in=token_data.get("expires_in"),
token_type=token_data.get("token_type"),
token_data = refresh_token(
self.external_token_url, token_data, self.token_request_headers
)

return self.token_data.access_token
_set_token_data_in_session(token_data)
return token_data.access_token

def __init__(
self,
Expand Down Expand Up @@ -105,7 +102,6 @@ def __init__(
DashAuthExternal: Main package class
"""

self.token_data: OAuth2Token = None
if auth_request_headers is None:
auth_request_headers = {}
if token_request_headers is None:
Expand Down
6 changes: 5 additions & 1 deletion dash_auth_external/routes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import asdict
from flask import session, redirect, request
import os
import base64
Expand All @@ -8,6 +9,7 @@
import hashlib
from requests_oauthlib import OAuth2Session
from dash_auth_external.config import FLASK_SESSION_TOKEN_KEY
from dash_auth_external.token import OAuth2Token

os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"

Expand Down Expand Up @@ -112,7 +114,9 @@ def get_token_route():
)

response = redirect(_home_suffix)
session[FLASK_SESSION_TOKEN_KEY] = response_data

session[FLASK_SESSION_TOKEN_KEY] = asdict(OAuth2Token(**response_data))

return response

return app
Expand Down
8 changes: 5 additions & 3 deletions dash_auth_external/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
@dataclass
class OAuth2Token:
access_token: str
token_type: str
expires_in: int
token_type: str = None
expires_in: int = None
refresh_token: str = None
expires_at: float = None

def __post_init__(self):
self.expires_at = time.time() + self.expires_in
if self.expires_at is None and self.expires_in is not None:
self.expires_at = time.time() + float(self.expires_in)

def is_expired(self):
return time.time() > self.expires_at if self.expires_at else False
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
name=NAME,
version="0.3.0",
description="Integrate Dash with 3rd Parties and external providers",
python_requires=">3.7",
author_email="[email protected]",
url="https://github.com/jamesholcombe/dash-auth-external",
keywords=["Dash", "Plotly", "Authentication", "Auth", "External"],
Expand Down
47 changes: 14 additions & 33 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,56 +61,39 @@ def test_callback(value):
token = auth.get_token()
return token

assert auth.token_data is None
test_callback("test")
assert isinstance(auth.token_data, OAuth2Token)


def test_get_token_second_call(
dash_app_and_auth, mocker, access_token_data_with_refresh
):
dash_app, auth = dash_app_and_auth

mocker.patch(
"dash_auth_external.auth._get_token_data_from_session",
return_value=access_token_data_with_refresh,
)

@dash_app.callback(Output("test-output", "children"), Input("test-input", "value"))
def test_callback(value):
token = auth.get_token()
return token

assert auth.token_data is None
test_callback("test")
assert isinstance(auth.token_data, OAuth2Token)
test_callback("test")
assert isinstance(auth.token_data, OAuth2Token)
assert test_callback("test") == "access_token"


def test_get_token_with_refresh(
dash_app_and_auth,
mocker,
expired_access_token_data_with_refresh,
access_token_data_with_refresh,
):
dash_app, auth = dash_app_and_auth

mocker.patch(
"dash_auth_external.auth._get_token_data_from_session",
return_value=expired_access_token_data_with_refresh,
)

refresh_mock = mocker.patch(
"dash_auth_external.auth.refresh_token",
return_value=OAuth2Token(**access_token_data_with_refresh),
return_value=OAuth2Token(**expired_access_token_data_with_refresh),
)

@dash_app.callback(Output("test-output", "children"), Input("test-input", "value"))
def test_callback(value):
token = auth.get_token()
return token

auth.token_data = OAuth2Token(**expired_access_token_data_with_refresh)
# mocking as working out of runtime context
mocker.patch(
"dash_auth_external.auth._set_token_data_in_session",
return_value=expired_access_token_data_with_refresh,
)

assert test_callback("test") == "access_token"

test_callback("test")
assert isinstance(auth.token_data, OAuth2Token)
assert auth.token_data.expires_in > 0
refresh_mock.assert_called_once()


Expand All @@ -129,7 +112,5 @@ def test_callback(value):
token = auth.get_token()
return token

auth.token_data = OAuth2Token(**expired_access_token_data_without_refresh)

with pytest.raises(TokenExpiredError):
test_callback("test")

0 comments on commit 16757ce

Please sign in to comment.