-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
17 changed files
with
1,165 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from flama.authentication.components import * # noqa | ||
from flama.authentication.jwt import * # noqa | ||
from flama.authentication.middlewares import * # noqa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import http | ||
import logging | ||
|
||
from flama import Component | ||
from flama.authentication import exceptions, jwt | ||
from flama.exceptions import HTTPException | ||
from flama.types import Headers | ||
from flama.types.http import Cookies | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
__all__ = ["JWTComponent"] | ||
|
||
|
||
class JWTComponent(Component): | ||
def __init__( | ||
self, | ||
secret: bytes, | ||
*, | ||
header_key: str = "Authorization", | ||
header_prefix: str = "Bearer", | ||
cookie_key: str = "flama_authentication", | ||
): | ||
self.secret = secret | ||
self.header_key = header_key | ||
self.header_prefix = header_prefix | ||
self.cookie_key = cookie_key | ||
|
||
def _token_from_cookies(self, cookies: Cookies) -> bytes: | ||
try: | ||
token = cookies[self.cookie_key]["value"] | ||
except KeyError: | ||
print(cookies) | ||
logger.debug("'%s' not found in cookies", self.cookie_key) | ||
raise exceptions.Unauthorized() | ||
|
||
return token.encode() | ||
|
||
def _token_from_header(self, headers: Headers) -> bytes: | ||
try: | ||
header_prefix, token = headers[self.header_key].split() | ||
except KeyError: | ||
logger.debug("'%s' not found in headers", self.header_key) | ||
raise exceptions.Unauthorized() | ||
except ValueError: | ||
logger.debug("Wrong format for authorization header value") | ||
raise exceptions.JWTException( | ||
f"Authentication header must be '{self.header_key}: {self.header_prefix} <token>'" | ||
) | ||
|
||
if header_prefix != self.header_prefix: | ||
logger.debug("Wrong prefix '%s' for authorization header, expected '%s'", header_prefix, self.header_prefix) | ||
raise exceptions.JWTException( | ||
f"Authentication header must be '{self.header_key}: {self.header_prefix} <token>'" | ||
) | ||
|
||
return token.encode() | ||
|
||
def resolve(self, headers: Headers, cookies: Cookies) -> jwt.JWT: | ||
try: | ||
try: | ||
encoded_token = self._token_from_header(headers) | ||
except exceptions.Unauthorized: | ||
encoded_token = self._token_from_cookies(cookies) | ||
except exceptions.Unauthorized: | ||
raise HTTPException(status_code=http.HTTPStatus.UNAUTHORIZED) | ||
except exceptions.JWTException as e: | ||
raise HTTPException( | ||
status_code=http.HTTPStatus.BAD_REQUEST, detail={"error": e.__class__, "description": str(e)} | ||
) | ||
|
||
try: | ||
token = jwt.JWT.decode(encoded_token, self.secret) | ||
except (exceptions.JWTDecodeException, exceptions.JWTValidateException) as e: | ||
raise HTTPException( | ||
status_code=http.HTTPStatus.BAD_REQUEST, detail={"error": e.__class__, "description": str(e)} | ||
) | ||
|
||
return token |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
__all__ = [ | ||
"AuthenticationException", | ||
"Unauthorized", | ||
"JWTException", | ||
"JWTDecodeException", | ||
"JWTValidateException", | ||
"JWTClaimValidateException", | ||
] | ||
|
||
|
||
class AuthenticationException(Exception): | ||
... | ||
|
||
|
||
class Unauthorized(AuthenticationException): | ||
... | ||
|
||
|
||
class Forbidden(AuthenticationException): | ||
... | ||
|
||
|
||
class JWTException(AuthenticationException): | ||
... | ||
|
||
|
||
class JWTDecodeException(JWTException): | ||
... | ||
|
||
|
||
class JWTValidateException(JWTException): | ||
... | ||
|
||
|
||
class JWTClaimValidateException(JWTValidateException): | ||
def __init__(self, claim: str) -> None: | ||
self.claim = claim | ||
super().__init__(f"Claim '{self.claim}' is not valid") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from flama.authentication.jwt.jwt import JWT | ||
|
||
__all__ = ["JWT"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import abc | ||
import hmac | ||
|
||
__all__ = ["SignAlgorithm", "HMACAlgorithm"] | ||
|
||
|
||
class SignAlgorithm(abc.ABC): | ||
"""Abstract class for signature algorithms.""" | ||
|
||
@abc.abstractmethod | ||
def sign(self, message: bytes, key: bytes) -> bytes: | ||
"""Sign a message using the given key. | ||
:param message: Message to sign. | ||
:param key: Key used to sign the message. | ||
:return: Signature. | ||
""" | ||
... | ||
|
||
@abc.abstractmethod | ||
def verify(self, message: bytes, signature: bytes, key) -> bool: | ||
"""Verify the signature of a message. | ||
:param message: Message to verify. | ||
:param signature: Signed message. | ||
:param key: Key used to sign the message. | ||
:return: True if the signature is valid, False otherwise. | ||
""" | ||
... | ||
|
||
|
||
class HMACAlgorithm(SignAlgorithm): | ||
"""HMAC using SHA algorithms for JWS.""" | ||
|
||
def __init__(self, sha): | ||
self.hash_algorithm = sha | ||
|
||
def sign(self, message: bytes, key: bytes) -> bytes: | ||
"""Sign a message using the given key. | ||
:param message: Message to sign. | ||
:param key: Key used to sign the message. | ||
:return: Signature. | ||
""" | ||
return hmac.new(key, message, self.hash_algorithm).digest() | ||
|
||
def verify(self, message: bytes, signature: bytes, key) -> bool: | ||
"""Verify the signature of a message. | ||
:param message: Message to verify. | ||
:param signature: Signed message. | ||
:param key: Key used to sign the message. | ||
:return: True if the signature is valid, False otherwise. | ||
""" | ||
return hmac.compare_digest(signature, hmac.new(key, message, self.hash_algorithm).digest()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import abc | ||
import time | ||
import typing as t | ||
|
||
from flama.authentication import exceptions | ||
|
||
if t.TYPE_CHECKING: | ||
from flama.authentication.jwt.jwt import Payload | ||
|
||
__all__ = [ | ||
"ClaimValidator", | ||
"IssValidator", | ||
"SubValidator", | ||
"AudValidator", | ||
"ExpValidator", | ||
"NbfValidator", | ||
"IatValidator", | ||
"JtiValidator", | ||
] | ||
|
||
|
||
class ClaimValidator(abc.ABC): | ||
claim: t.ClassVar[str] | ||
|
||
def __init__(self, payload: "Payload", claims: t.Dict[str, t.Any]) -> None: | ||
self.value = claims.get(self.claim) | ||
self.payload = payload | ||
|
||
@abc.abstractmethod | ||
def validate(self): | ||
"""Validate the claim. | ||
:raises JWTClaimValidateException: if the claim is not valid. | ||
""" | ||
... | ||
|
||
|
||
class IssValidator(ClaimValidator): | ||
"""Issuer claim validator.""" | ||
|
||
claim = "iss" | ||
|
||
def validate(self): | ||
"""Validate the claim. | ||
:raises JWTClaimValidateException: if the claim is not valid. | ||
""" | ||
... | ||
|
||
|
||
class SubValidator(ClaimValidator): | ||
"""Subject claim validator.""" | ||
|
||
claim = "sub" | ||
|
||
def validate(self): | ||
"""Validate the claim. | ||
:raises JWTClaimValidateException: if the claim is not valid. | ||
""" | ||
... | ||
|
||
|
||
class AudValidator(ClaimValidator): | ||
"""Audience claim validator.""" | ||
|
||
claim = "aud" | ||
|
||
def validate(self): | ||
"""Validate the claim. | ||
:raises JWTClaimValidateException: if the claim is not valid. | ||
""" | ||
... | ||
|
||
|
||
class ExpValidator(ClaimValidator): | ||
"""Expiration time claim validator. | ||
The value of the claim must be a number representing the expiration time of the token in seconds since the epoch | ||
(UTC). The expiration time must be after the current time. | ||
""" | ||
|
||
claim = "exp" | ||
|
||
def validate(self): | ||
"""Validate the claim. | ||
The value of the claim must be a number representing the expiration time of the token in seconds since the | ||
epoch (UTC). The expiration time must be after the current time. | ||
:raises JWTClaimValidateException: if the claim is not valid. | ||
""" | ||
if self.payload.exp is not None and self.payload.exp < int(time.time()): | ||
raise exceptions.JWTClaimValidateException("exp") | ||
|
||
|
||
class NbfValidator(ClaimValidator): | ||
"""Not before claim validator. | ||
The value of the claim must be a number representing the time before which the token must not be accepted for | ||
processing in seconds since the epoch (UTC). The time must be before the current time. | ||
""" | ||
|
||
claim = "nbf" | ||
|
||
def validate(self): | ||
"""Validate the claim. | ||
The value of the claim must be a number representing the time before which the token must not be accepted for | ||
processing in seconds since the epoch (UTC). The time must be before the current time. | ||
:raises JWTClaimValidateException: if the claim is not valid. | ||
""" | ||
if self.payload.nbf is not None and self.payload.nbf > int(time.time()): | ||
raise exceptions.JWTClaimValidateException("nbf") | ||
|
||
|
||
class IatValidator(ClaimValidator): | ||
"""Issued at claim validator. | ||
The value of the claim must be a number representing the time at which the JWT was issued in seconds since the | ||
epoch (UTC). The time must be before the current time. | ||
""" | ||
|
||
claim = "iat" | ||
|
||
def validate(self): | ||
"""Validate the claim. | ||
The value of the claim must be a number representing the time at which the JWT was issued in seconds since the | ||
epoch (UTC). The time must be before the current time. | ||
:raises JWTClaimValidateException: if the claim is not valid. | ||
""" | ||
if self.payload.iat is not None and self.payload.iat > int(time.time()): | ||
raise exceptions.JWTClaimValidateException("iat") | ||
|
||
|
||
class JtiValidator(ClaimValidator): | ||
"""JWT ID claim validator.""" | ||
|
||
claim = "jti" | ||
|
||
def validate(self): | ||
"""Validate the claim. | ||
:raises JWTClaimValidateException: if the claim is not valid. | ||
""" | ||
... |
Oops, something went wrong.