Skip to content

Commit

Permalink
✨ Authentication mechanism (#137)
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy authored and migduroli committed Sep 3, 2024
1 parent d877f19 commit 4899473
Show file tree
Hide file tree
Showing 17 changed files with 1,165 additions and 26 deletions.
5 changes: 4 additions & 1 deletion flama/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@ def resolve(self, headers: types.Headers) -> types.Cookies:
cookie = SimpleCookie()
cookie.load(headers.get("cookie", ""))
return types.Cookies(
{str(name): {str(k): str(v) for k, v in morsel.items()} for name, morsel in cookie.items()}
{
str(name): {**{str(k): str(v) for k, v in morsel.items()}, "value": morsel.value}
for name, morsel in cookie.items()
}
)


Expand Down
21 changes: 0 additions & 21 deletions flama/authentication.py

This file was deleted.

3 changes: 3 additions & 0 deletions flama/authentication/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from flama.authentication.components import * # noqa
from flama.authentication.jwt import * # noqa
from flama.authentication.middlewares import * # noqa
79 changes: 79 additions & 0 deletions flama/authentication/components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import http
import logging

from flama import Component
from flama.authentication import exceptions, jwt
from flama.exceptions import HTTPException
from flama.types import Headers
from flama.types.http import Cookies

logger = logging.getLogger(__name__)

__all__ = ["JWTComponent"]


class JWTComponent(Component):
def __init__(
self,
secret: bytes,
*,
header_key: str = "Authorization",
header_prefix: str = "Bearer",
cookie_key: str = "flama_authentication",
):
self.secret = secret
self.header_key = header_key
self.header_prefix = header_prefix
self.cookie_key = cookie_key

def _token_from_cookies(self, cookies: Cookies) -> bytes:
try:
token = cookies[self.cookie_key]["value"]
except KeyError:
print(cookies)
logger.debug("'%s' not found in cookies", self.cookie_key)
raise exceptions.Unauthorized()

return token.encode()

def _token_from_header(self, headers: Headers) -> bytes:
try:
header_prefix, token = headers[self.header_key].split()
except KeyError:
logger.debug("'%s' not found in headers", self.header_key)
raise exceptions.Unauthorized()
except ValueError:
logger.debug("Wrong format for authorization header value")
raise exceptions.JWTException(
f"Authentication header must be '{self.header_key}: {self.header_prefix} <token>'"
)

if header_prefix != self.header_prefix:
logger.debug("Wrong prefix '%s' for authorization header, expected '%s'", header_prefix, self.header_prefix)
raise exceptions.JWTException(
f"Authentication header must be '{self.header_key}: {self.header_prefix} <token>'"
)

return token.encode()

def resolve(self, headers: Headers, cookies: Cookies) -> jwt.JWT:
try:
try:
encoded_token = self._token_from_header(headers)
except exceptions.Unauthorized:
encoded_token = self._token_from_cookies(cookies)
except exceptions.Unauthorized:
raise HTTPException(status_code=http.HTTPStatus.UNAUTHORIZED)
except exceptions.JWTException as e:
raise HTTPException(
status_code=http.HTTPStatus.BAD_REQUEST, detail={"error": e.__class__, "description": str(e)}
)

try:
token = jwt.JWT.decode(encoded_token, self.secret)
except (exceptions.JWTDecodeException, exceptions.JWTValidateException) as e:
raise HTTPException(
status_code=http.HTTPStatus.BAD_REQUEST, detail={"error": e.__class__, "description": str(e)}
)

return token
38 changes: 38 additions & 0 deletions flama/authentication/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
__all__ = [
"AuthenticationException",
"Unauthorized",
"JWTException",
"JWTDecodeException",
"JWTValidateException",
"JWTClaimValidateException",
]


class AuthenticationException(Exception):
...


class Unauthorized(AuthenticationException):
...


class Forbidden(AuthenticationException):
...


class JWTException(AuthenticationException):
...


class JWTDecodeException(JWTException):
...


class JWTValidateException(JWTException):
...


class JWTClaimValidateException(JWTValidateException):
def __init__(self, claim: str) -> None:
self.claim = claim
super().__init__(f"Claim '{self.claim}' is not valid")
3 changes: 3 additions & 0 deletions flama/authentication/jwt/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from flama.authentication.jwt.jwt import JWT

__all__ = ["JWT"]
55 changes: 55 additions & 0 deletions flama/authentication/jwt/algorithms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import abc
import hmac

__all__ = ["SignAlgorithm", "HMACAlgorithm"]


class SignAlgorithm(abc.ABC):
"""Abstract class for signature algorithms."""

@abc.abstractmethod
def sign(self, message: bytes, key: bytes) -> bytes:
"""Sign a message using the given key.
:param message: Message to sign.
:param key: Key used to sign the message.
:return: Signature.
"""
...

@abc.abstractmethod
def verify(self, message: bytes, signature: bytes, key) -> bool:
"""Verify the signature of a message.
:param message: Message to verify.
:param signature: Signed message.
:param key: Key used to sign the message.
:return: True if the signature is valid, False otherwise.
"""
...


class HMACAlgorithm(SignAlgorithm):
"""HMAC using SHA algorithms for JWS."""

def __init__(self, sha):
self.hash_algorithm = sha

def sign(self, message: bytes, key: bytes) -> bytes:
"""Sign a message using the given key.
:param message: Message to sign.
:param key: Key used to sign the message.
:return: Signature.
"""
return hmac.new(key, message, self.hash_algorithm).digest()

def verify(self, message: bytes, signature: bytes, key) -> bool:
"""Verify the signature of a message.
:param message: Message to verify.
:param signature: Signed message.
:param key: Key used to sign the message.
:return: True if the signature is valid, False otherwise.
"""
return hmac.compare_digest(signature, hmac.new(key, message, self.hash_algorithm).digest())
150 changes: 150 additions & 0 deletions flama/authentication/jwt/claims.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import abc
import time
import typing as t

from flama.authentication import exceptions

if t.TYPE_CHECKING:
from flama.authentication.jwt.jwt import Payload

__all__ = [
"ClaimValidator",
"IssValidator",
"SubValidator",
"AudValidator",
"ExpValidator",
"NbfValidator",
"IatValidator",
"JtiValidator",
]


class ClaimValidator(abc.ABC):
claim: t.ClassVar[str]

def __init__(self, payload: "Payload", claims: t.Dict[str, t.Any]) -> None:
self.value = claims.get(self.claim)
self.payload = payload

@abc.abstractmethod
def validate(self):
"""Validate the claim.
:raises JWTClaimValidateException: if the claim is not valid.
"""
...


class IssValidator(ClaimValidator):
"""Issuer claim validator."""

claim = "iss"

def validate(self):
"""Validate the claim.
:raises JWTClaimValidateException: if the claim is not valid.
"""
...


class SubValidator(ClaimValidator):
"""Subject claim validator."""

claim = "sub"

def validate(self):
"""Validate the claim.
:raises JWTClaimValidateException: if the claim is not valid.
"""
...


class AudValidator(ClaimValidator):
"""Audience claim validator."""

claim = "aud"

def validate(self):
"""Validate the claim.
:raises JWTClaimValidateException: if the claim is not valid.
"""
...


class ExpValidator(ClaimValidator):
"""Expiration time claim validator.
The value of the claim must be a number representing the expiration time of the token in seconds since the epoch
(UTC). The expiration time must be after the current time.
"""

claim = "exp"

def validate(self):
"""Validate the claim.
The value of the claim must be a number representing the expiration time of the token in seconds since the
epoch (UTC). The expiration time must be after the current time.
:raises JWTClaimValidateException: if the claim is not valid.
"""
if self.payload.exp is not None and self.payload.exp < int(time.time()):
raise exceptions.JWTClaimValidateException("exp")


class NbfValidator(ClaimValidator):
"""Not before claim validator.
The value of the claim must be a number representing the time before which the token must not be accepted for
processing in seconds since the epoch (UTC). The time must be before the current time.
"""

claim = "nbf"

def validate(self):
"""Validate the claim.
The value of the claim must be a number representing the time before which the token must not be accepted for
processing in seconds since the epoch (UTC). The time must be before the current time.
:raises JWTClaimValidateException: if the claim is not valid.
"""
if self.payload.nbf is not None and self.payload.nbf > int(time.time()):
raise exceptions.JWTClaimValidateException("nbf")


class IatValidator(ClaimValidator):
"""Issued at claim validator.
The value of the claim must be a number representing the time at which the JWT was issued in seconds since the
epoch (UTC). The time must be before the current time.
"""

claim = "iat"

def validate(self):
"""Validate the claim.
The value of the claim must be a number representing the time at which the JWT was issued in seconds since the
epoch (UTC). The time must be before the current time.
:raises JWTClaimValidateException: if the claim is not valid.
"""
if self.payload.iat is not None and self.payload.iat > int(time.time()):
raise exceptions.JWTClaimValidateException("iat")


class JtiValidator(ClaimValidator):
"""JWT ID claim validator."""

claim = "jti"

def validate(self):
"""Validate the claim.
:raises JWTClaimValidateException: if the claim is not valid.
"""
...
Loading

0 comments on commit 4899473

Please sign in to comment.