Skip to content

Commit

Permalink
✨ Access and refresh token components
Browse files Browse the repository at this point in the history
  • Loading branch information
migduroli committed Oct 24, 2024
1 parent e0db83b commit 0b40328
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 50 deletions.
50 changes: 38 additions & 12 deletions flama/authentication/components.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,27 @@
import http
import logging
import typing as t

from flama import Component
from flama.authentication import exceptions, jwt
from flama.authentication import exceptions, jwt, types
from flama.exceptions import HTTPException
from flama.types import Headers
from flama.types.http import Cookies

logger = logging.getLogger(__name__)

__all__ = ["JWTComponent"]
__all__ = ["AccessTokenComponent", "RefreshTokenComponent"]


class JWTComponent(Component):
def __init__(
self,
secret: bytes,
*,
header_key: str = "Authorization",
header_prefix: str = "Bearer",
cookie_key: str = "flama_authentication",
):
class BaseTokenComponent(Component):
def __init__(self, secret: bytes, *, header_key: str, header_prefix: str, cookie_key: str):
self.secret = secret
self.header_key = header_key
self.header_prefix = header_prefix
self.cookie_key = cookie_key

def _token_from_cookies(self, cookies: Cookies) -> bytes:
print(f"ERROR: {cookies}")
try:
token = cookies[self.cookie_key]["value"]
except KeyError:
Expand All @@ -36,6 +31,7 @@ def _token_from_cookies(self, cookies: Cookies) -> bytes:
return token.encode()

def _token_from_header(self, headers: Headers) -> bytes:
print(f"ERROR: {headers}")
try:
header_prefix, token = headers[self.header_key].split()
except KeyError:
Expand All @@ -55,7 +51,7 @@ def _token_from_header(self, headers: Headers) -> bytes:

return token.encode()

def resolve(self, headers: Headers, cookies: Cookies) -> jwt.JWT:
def _resolve_token(self, headers: Headers, cookies: Cookies) -> jwt.JWT:
try:
try:
encoded_token = self._token_from_header(headers)
Expand All @@ -76,3 +72,33 @@ def resolve(self, headers: Headers, cookies: Cookies) -> jwt.JWT:
)

return token


class AccessTokenComponent(BaseTokenComponent):
def __init__(
self,
secret: bytes,
*,
header_prefix: str = "Bearer",
header_key: str = "access_token",
cookie_key: str = "access_token",
):
super().__init__(secret, header_prefix=header_prefix, header_key=header_key, cookie_key=cookie_key)

def resolve(self, headers: Headers, cookies: Cookies) -> types.AccessToken:
return t.cast(types.AccessToken, self._resolve_token(headers, cookies))


class RefreshTokenComponent(BaseTokenComponent):
def __init__(
self,
secret: bytes,
*,
header_prefix: str = "Bearer",
header_key: str = "refresh_token",
cookie_key: str = "refresh_token",
):
super().__init__(secret, header_prefix=header_prefix, header_key=header_key, cookie_key=cookie_key)

def resolve(self, headers: Headers, cookies: Cookies) -> types.RefreshToken:
return t.cast(types.RefreshToken, self._resolve_token(headers, cookies))
6 changes: 3 additions & 3 deletions flama/authentication/jwt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from flama.authentication.jwt.jwt import JWT

__all__ = ["JWT"]
from flama.authentication.components import * # noqa
from flama.authentication.jwt.jwt import JWT # noqa
from flama.authentication.types import * # noqa
6 changes: 4 additions & 2 deletions flama/authentication/middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import typing as t

from flama.authentication.jwt.jwt import JWT
from flama import authentication
from flama.exceptions import HTTPException
from flama.http import APIErrorResponse, Request

Expand Down Expand Up @@ -44,7 +44,9 @@ async def _get_response(self, scope: "types.Scope", receive: "types.Receive") ->
return self.app

try:
token: JWT = await app.injector.resolve(JWT).value({"request": Request(scope, receive=receive)})
token: authentication.AccessToken = await app.injector.resolve(authentication.AccessToken).value(
{"request": Request(scope, receive=receive)}
)
except HTTPException as e:
logger.debug("JWT error: %s", e.detail)
return APIErrorResponse(status_code=e.status_code, detail=e.detail)
Expand Down
8 changes: 8 additions & 0 deletions flama/authentication/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import typing as t

from flama.authentication.jwt import JWT

__all__ = ["AccessToken", "RefreshToken"]

AccessToken = t.NewType("AccessToken", JWT)
RefreshToken = t.NewType("RefreshToken", JWT)
136 changes: 107 additions & 29 deletions tests/authentication/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,135 @@

import pytest

from flama import Flama
from flama.authentication.components import JWTComponent
from flama.authentication.jwt.jwt import JWT
from flama import Flama, authentication

TOKEN = (
b"eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJkYXRhIjogeyJmb28iOiAiYmFyIn0sICJpYXQiOiAwfQ==.J3zdedMZSFNOimstjJat0V"
b"28rM_b1UU62XCp9dg_5kg="
)


class TestCaseJWTComponent:
@pytest.fixture(scope="function")
def secret(self):
return uuid.UUID(int=0)
@pytest.fixture(scope="function")
def secret():
return uuid.UUID(int=0)

@pytest.fixture(scope="function")
def app(self, secret):
return Flama(
schema=None,
docs=None,
components=[
JWTComponent(
secret=secret.bytes,
header_key="Authorization",
header_prefix="Bearer",
cookie_key="flama_authentication",
)
],
)

@pytest.fixture(scope="function")
def app(secret):
return Flama(
schema=None,
docs=None,
components=[
authentication.AccessTokenComponent(secret=secret.bytes),
authentication.RefreshTokenComponent(secret=secret.bytes),
],
)


class TestCaseAccessTokenComponent:
@pytest.fixture(scope="function", autouse=True)
def add_endpoints(self, app):
@app.get("/")
def access_token(token: authentication.AccessToken):
return token.asdict()

@pytest.mark.parametrize(
["params", "status_code", "result"],
(
pytest.param(
{"headers": {"access_token": f"Bearer {TOKEN.decode()}"}},
200,
{"header": {"alg": "HS256", "typ": "JWT"}, "payload": {"data": {"foo": "bar"}, "iat": 0}},
id="headers",
),
pytest.param(
{"headers": {"access_token": "token"}},
400,
{
"detail": {
"description": "Authentication header must be 'access_token: Bearer <token>'",
"error": "JWTException",
},
"error": "HTTPException",
"status_code": 400,
},
id="header_wrong_format",
),
pytest.param(
{"headers": {"access_token": "Foo token"}},
400,
{
"detail": {
"description": "Authentication header must be 'access_token: Bearer <token>'",
"error": "JWTException",
},
"error": "HTTPException",
"status_code": 400,
},
id="header_wrong_prefix",
),
pytest.param(
{"cookies": {"access_token": TOKEN.decode()}},
200,
{"header": {"alg": "HS256", "typ": "JWT"}, "payload": {"data": {"foo": "bar"}, "iat": 0}},
id="cookies",
),
pytest.param(
{},
401,
{"detail": "Unauthorized", "error": "HTTPException", "status_code": 401},
id="unauthorized",
),
pytest.param(
{
"cookies": {
"access_token": "eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJkYXRhIjogeyJmb28iOiAiYmFyI"
"n0sICJpYXQiOiAwfQ==.0000",
}
},
401,
{
"detail": {
"description": "Signature verification failed for token 'eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVC"
"J9.eyJkYXRhIjogeyJmb28iOiAiYmFyIn0sICJpYXQiOiAwfQ==.0000'",
"error": "JWTValidateException",
},
"error": "HTTPException",
"status_code": 401,
},
id="invalid_token",
),
),
)
async def test_injection(self, client, status_code, result, params):
response = await client.request("get", "/", **params)

assert response.status_code == status_code
assert response.json() == result


class TestCaseRefreshTokenComponent:
@pytest.fixture(scope="function", autouse=True)
def add_endpoints(self, app):
@app.get("/")
def jwt(token: JWT):
def refresh_token(token: authentication.RefreshToken):
return token.asdict()

@pytest.mark.parametrize(
["params", "status_code", "result"],
(
pytest.param(
{"headers": {"Authorization": f"Bearer {TOKEN.decode()}"}},
{"headers": {"refresh_token": f"Bearer {TOKEN.decode()}"}},
200,
{"header": {"alg": "HS256", "typ": "JWT"}, "payload": {"data": {"foo": "bar"}, "iat": 0}},
id="headers",
),
pytest.param(
{"headers": {"Authorization": "token"}},
{"headers": {"refresh_token": "token"}},
400,
{
"detail": {
"description": "Authentication header must be 'Authorization: Bearer <token>'",
"description": "Authentication header must be 'refresh_token: Bearer <token>'",
"error": "JWTException",
},
"error": "HTTPException",
Expand All @@ -61,11 +139,11 @@ def jwt(token: JWT):
id="header_wrong_format",
),
pytest.param(
{"headers": {"Authorization": "Foo token"}},
{"headers": {"refresh_token": "Foo token"}},
400,
{
"detail": {
"description": "Authentication header must be 'Authorization: Bearer <token>'",
"description": "Authentication header must be 'refresh_token: Bearer <token>'",
"error": "JWTException",
},
"error": "HTTPException",
Expand All @@ -74,7 +152,7 @@ def jwt(token: JWT):
id="header_wrong_prefix",
),
pytest.param(
{"cookies": {"flama_authentication": TOKEN.decode()}},
{"cookies": {"refresh_token": TOKEN.decode()}},
200,
{"header": {"alg": "HS256", "typ": "JWT"}, "payload": {"data": {"foo": "bar"}, "iat": 0}},
id="cookies",
Expand All @@ -88,7 +166,7 @@ def jwt(token: JWT):
pytest.param(
{
"cookies": {
"flama_authentication": "eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJkYXRhIjogeyJmb28iOiAiYmFyI"
"refresh_token": "eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJkYXRhIjogeyJmb28iOiAiYmFyI"
"n0sICJpYXQiOiAwfQ==.0000",
}
},
Expand Down
8 changes: 4 additions & 4 deletions tests/authentication/test_middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from flama import Flama
from flama.authentication.components import JWTComponent
from flama.authentication.components import AccessTokenComponent
from flama.authentication.middlewares import AuthenticationMiddleware
from flama.middleware import Middleware

Expand All @@ -27,7 +27,7 @@ def app(self, secret):
return Flama(
schema=None,
docs=None,
components=[JWTComponent(secret=secret.bytes)],
components=[AccessTokenComponent(secret=secret.bytes)],
middleware=[Middleware(AuthenticationMiddleware)],
)

Expand All @@ -47,7 +47,7 @@ def headers(self, request):
return None

try:
return {"Authorization": f"Bearer {TOKENS[request.param].decode()}"}
return {"access_token": f"Bearer {TOKENS[request.param].decode()}"}
except KeyError:
raise ValueError(f"Invalid token {request.param}")

Expand All @@ -57,7 +57,7 @@ def cookies(self, request):
return None

try:
return {"flama_authentication": TOKENS[request.param].decode()}
return {"access_token": TOKENS[request.param].decode()}
except KeyError:
raise ValueError(f"Invalid token {request.param}")

Expand Down

0 comments on commit 0b40328

Please sign in to comment.