Skip to content

Commit

Permalink
🐛 Remove unintended sqlalchemy import (#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy authored and migduroli committed Sep 3, 2024
1 parent 036264e commit 2664aba
Show file tree
Hide file tree
Showing 35 changed files with 1,162 additions and 1,150 deletions.
18 changes: 15 additions & 3 deletions flama/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import typing as t

from flama import asgi, exceptions, http, injection, types, url, validation, websockets
from flama.ddd.components import WorkerComponent
from flama.events import Events
from flama.middleware import MiddlewareStack
from flama.models.modules import ModelsModule
Expand All @@ -14,6 +13,13 @@
from flama.routing import BaseRoute, Router
from flama.schemas.modules import SchemaModule

try:
from flama.ddd.components import WorkerComponent
from flama.resources.workers import FlamaWorker
except AssertionError:
WorkerComponent = None
FlamaWorker = None

if t.TYPE_CHECKING:
from flama.middleware import Middleware
from flama.modules import Module
Expand Down Expand Up @@ -83,16 +89,22 @@ def __init__(
}
)

# Create worker
worker = FlamaWorker() if FlamaWorker else None

# Initialize Modules
default_modules = [
ResourcesModule(),
ResourcesModule(worker=worker),
SchemaModule(title, version, description, schema=schema, docs=docs),
ModelsModule(),
]
self.modules = Modules(app=self, modules={*default_modules, *(modules or [])})

# Initialize router
default_components = [WorkerComponent(worker=default_modules[0].worker)]
default_components = []
if worker and WorkerComponent:
default_components.append(WorkerComponent(worker=worker))

self.app = self.router = Router(
routes=routes, components=[*default_components, *(components or [])], lifespan=lifespan
)
Expand Down
18 changes: 5 additions & 13 deletions flama/codecs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,21 @@

class Codec(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def decode(self, item: t.Any, **options):
...

@abc.abstractmethod
async def encode(self, item: t.Any, **options):
async def decode(self, item: t.Any, **options) -> t.Any:
...


class HTTPCodec(Codec):
media_type: t.Optional[str] = None

async def decode(self, item: "http.Request", **options):
...

async def encode(self, item: t.Any, **options):
@abc.abstractmethod
async def decode(self, item: "http.Request", **options) -> t.Any:
...


class WebsocketsCodec(Codec):
encoding: t.Optional[str] = None

async def decode(self, item: "types.Message", **options):
...

async def encode(self, item: t.Any, **options):
@abc.abstractmethod
async def decode(self, item: "types.Message", **options) -> t.Any:
...
9 changes: 7 additions & 2 deletions flama/codecs/http/jsondata.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from flama import exceptions, http
import typing as t

from flama import exceptions
from flama.codecs.base import HTTPCodec

if t.TYPE_CHECKING:
from flama import http

__all__ = ["JSONDataCodec"]


class JSONDataCodec(HTTPCodec):
media_type = "application/json"
format = "json"

async def decode(self, item: http.Request, **options):
async def decode(self, item: "http.Request", **options):
try:
if await item.body() == b"":
return None
Expand Down
8 changes: 6 additions & 2 deletions flama/codecs/http/multipart.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from flama import http
import typing as t

from flama.codecs.base import HTTPCodec

if t.TYPE_CHECKING:
from flama import http

__all__ = ["MultiPartCodec"]


class MultiPartCodec(HTTPCodec):
media_type = "multipart/form-data"

async def decode(self, item: http.Request, **options):
async def decode(self, item: "http.Request", **options) -> t.Any:
return await item.form()
8 changes: 6 additions & 2 deletions flama/codecs/http/urlencoded.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from flama import http
import typing as t

from flama.codecs.base import HTTPCodec

if t.TYPE_CHECKING:
from flama import http

__all__ = ["URLEncodedCodec"]


class URLEncodedCodec(HTTPCodec):
media_type = "application/x-www-form-urlencoded"

async def decode(self, item: http.Request, **options):
async def decode(self, item: "http.Request", **options):
return await item.form() or None
4 changes: 2 additions & 2 deletions flama/concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ class AsyncProcess(multiprocessing.Process):
_args: t.List[t.Any]
_kwargs: t.Dict[str, t.Any]

def run(self):
def run(self) -> None:
if self._target:
result_or_task = self._target(*self._args, **self._kwargs)

return asyncio.run(result_or_task) if is_async(self._target) else result_or_task
asyncio.run(result_or_task) if is_async(self._target) else result_or_task
18 changes: 11 additions & 7 deletions flama/ddd/repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,30 @@
import sqlalchemy
import sqlalchemy.exc
except Exception: # pragma: no cover
raise AssertionError("`sqlalchemy[asyncio]` must be installed to use crud resources") from None
raise AssertionError("`sqlalchemy[asyncio]` must be installed to use ddd") from None


if t.TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection

try:
from sqlalchemy.ext.asyncio import AsyncConnection
except Exception: # pragma: no cover
...

__all__ = ["AbstractRepository", "SQLAlchemyRepository", "SQLAlchemyTableRepository", "SQLAlchemyTableManager"]


class AbstractRepository(abc.ABC):
"""Base class for repositories."""

...
def __init__(self, *args, **kwargs):
...


class SQLAlchemyRepository(AbstractRepository):
"""Base class for SQLAlchemy repositories. It provides a connection to the database."""

def __init__(self, connection: "AsyncConnection"):
def __init__(self, connection: "AsyncConnection", *args, **kwargs):
super().__init__(*args, **kwargs)
self._connection = connection

def __eq__(self, other):
Expand Down Expand Up @@ -220,8 +224,8 @@ def _filter_query(self, query, *clauses, **filters):
class SQLAlchemyTableRepository(SQLAlchemyRepository):
_table: t.ClassVar[sqlalchemy.Table]

def __init__(self, connection: "AsyncConnection"):
super().__init__(connection)
def __init__(self, connection: "AsyncConnection", *args, **kwargs):
super().__init__(connection, *args, **kwargs)
self._table_manager = SQLAlchemyTableManager(self._table, connection)

def __eq__(self, other):
Expand Down
15 changes: 10 additions & 5 deletions flama/ddd/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import inspect
import typing as t

from sqlalchemy.ext.asyncio import AsyncTransaction

from flama.ddd import types
from flama.ddd.repositories import AbstractRepository, SQLAlchemyRepository
from flama.ddd.repositories import AbstractRepository
from flama.exceptions import ApplicationError

if t.TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncConnection
try:
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncTransaction
except Exception: # pragma: no cover
...

from flama import Flama

Expand Down Expand Up @@ -121,7 +122,11 @@ async def rollback(self) -> None:


class SQLAlchemyWorker(AbstractWorker):
_repositories: t.ClassVar[t.Dict[str, t.Type[SQLAlchemyRepository]]]
"""Worker for SQLAlchemy.
It will provide a connection and a transaction to the database and create the repositories for the entities.
"""

_connection: "AsyncConnection"
_transaction: "AsyncTransaction"

Expand Down
44 changes: 22 additions & 22 deletions flama/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,22 @@ class PlainTextResponse(starlette.responses.PlainTextResponse, Response):


class EnhancedJSONEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, (Path, os.PathLike)):
return str(obj)
if isinstance(obj, (bytes, bytearray)):
return obj.decode("utf-8")
if isinstance(obj, enum.Enum):
return obj.value
if isinstance(obj, uuid.UUID):
return str(obj)
if isinstance(obj, (set, frozenset)):
return list(obj)
if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
return obj.isoformat()
if isinstance(obj, datetime.timedelta):
def default(self, o):
if isinstance(o, (Path, os.PathLike)):
return str(o)
if isinstance(o, (bytes, bytearray)):
return o.decode("utf-8")
if isinstance(o, enum.Enum):
return o.value
if isinstance(o, uuid.UUID):
return str(o)
if isinstance(o, (set, frozenset)):
return list(o)
if isinstance(o, (datetime.datetime, datetime.date, datetime.time)):
return o.isoformat()
if isinstance(o, datetime.timedelta):
# split seconds to larger units
seconds = obj.total_seconds()
seconds = o.total_seconds()
minutes, seconds = divmod(seconds, 60)
hours, minutes = divmod(minutes, 60)
days, hours = divmod(hours, 24)
Expand All @@ -96,13 +96,13 @@ def default(self, obj):
)

return "P" + "".join([formatted_value for value, formatted_value in formatted_units if value])
if inspect.isclass(obj) and issubclass(obj, BaseException):
return obj.__name__
if isinstance(obj, BaseException):
return repr(obj)
if dataclasses.is_dataclass(obj):
return dataclasses.asdict(obj)
return super().default(obj)
if inspect.isclass(o) and issubclass(o, BaseException):
return o.__name__
if isinstance(o, BaseException):
return repr(o)
if dataclasses.is_dataclass(o):
return dataclasses.asdict(o)
return super().default(o)


class JSONResponse(starlette.responses.JSONResponse, Response):
Expand Down
4 changes: 0 additions & 4 deletions flama/injection/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,6 @@ async def __call__(self, *args, **kwargs):
def __str__(self) -> str:
return str(self.__class__.__name__)

@abc.abstractmethod
def resolve(self, *args, **kwargs) -> t.Any:
...


class Components(t.Tuple[Component, ...]):
def __new__(cls, components: t.Optional[t.Union[t.Sequence[Component], t.Set[Component]]] = None):
Expand Down
2 changes: 1 addition & 1 deletion flama/injection/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ async def value(self, context: t.Dict[str, t.Any]) -> t.Any:
return await self.root.value(context)


class ResolutionCache(t.Mapping[str, ResolutionTree]):
class ResolutionCache(t.Mapping[int, ResolutionTree]):
"""A cache for resolution trees."""

def __init__(self):
Expand Down
2 changes: 1 addition & 1 deletion flama/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __call__(self, app: "types.App") -> t.Union["types.MiddlewareClass", "types.
def __repr__(self) -> str:
name = self.__class__.__name__
middleware_name = (
self.middleware.__name__ if inspect.isfunction(self.middleware) else self.middleware.__class__.__name__
self.middleware.__class__.__name__ if inspect.isclass(self.middleware) else self.middleware.__name__
)
args = ", ".join([middleware_name] + [f"{key}={value!r}" for key, value in self.kwargs.items()])
return f"{name}({args})"
Expand Down
8 changes: 4 additions & 4 deletions flama/pagination/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import inspect
import typing as t

from flama import schemas, types
from flama import types


class PaginationDecoratorFactory:
PARAMETERS: t.List[inspect.Parameter]

@classmethod
def decorate(cls, func: t.Callable, schema: schemas.Schema) -> t.Callable:
def decorate(cls, func: t.Callable, schema: t.Type[types.Schema]) -> t.Callable:
func_signature = inspect.signature(func)
if "kwargs" not in func_signature.parameters:
raise TypeError("Paginated views must define **kwargs param")
Expand All @@ -31,10 +31,10 @@ def decorate(cls, func: t.Callable, schema: schemas.Schema) -> t.Callable:

@classmethod
@abc.abstractmethod
def _decorate_async(cls, func: t.Callable, schema: schemas.Schema) -> t.Callable:
def _decorate_async(cls, func: t.Callable, schema: t.Type[types.Schema]) -> t.Callable:
...

@classmethod
@abc.abstractmethod
def _decorate_sync(cls, func: t.Callable, schema: schemas.Schema) -> t.Callable:
def _decorate_sync(cls, func: t.Callable, schema: t.Type[types.Schema]) -> t.Callable:
...
14 changes: 7 additions & 7 deletions flama/pagination/mixins/limit_offset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import inspect
import typing as t

from flama import http, schemas
from flama import http, schemas, types

__all__ = ["LimitOffsetMixin", "LimitOffsetResponse"]

Expand All @@ -23,7 +23,7 @@ class LimitOffsetResponse(http.APIResponse):

def __init__(
self,
schema: schemas.Schema,
schema: t.Type[types.Schema],
offset: t.Optional[t.Union[int, str]] = None,
limit: t.Optional[t.Union[int, str]] = None,
count: t.Optional[bool] = True,
Expand Down Expand Up @@ -59,7 +59,7 @@ class LimitOffsetDecoratorFactory(PaginationDecoratorFactory):
]

@classmethod
def _decorate_async(cls, func: t.Callable, schema: schemas.Schema) -> t.Callable:
def _decorate_async(cls, func: t.Callable, schema: t.Type[types.Schema]) -> t.Callable:
@functools.wraps(func)
async def decorator(
*args,
Expand All @@ -75,7 +75,7 @@ async def decorator(
return decorator

@classmethod
def _decorate_sync(cls, func: t.Callable, schema: schemas.Schema) -> t.Callable:
def _decorate_sync(cls, func: t.Callable, schema: t.Type[types.Schema]) -> t.Callable:
@functools.wraps(func)
def decorator(
*args,
Expand Down Expand Up @@ -106,9 +106,9 @@ def _paginate_limit_offset(self, func: t.Callable) -> t.Callable:
:param schema_name: Name used for output field.
:return: Decorated view.
"""
schema = schemas.Schema.from_type(inspect.signature(func).return_annotation)
resource_schema = schema.unique_schema
schema_name = schema.name
schema_wrapped = schemas.Schema.from_type(inspect.signature(func).return_annotation)
resource_schema = schema_wrapped.unique_schema
schema_name = schema_wrapped.name

try:
schema_module, schema_class = schema_name.rsplit(".", 1)
Expand Down
Loading

0 comments on commit 2664aba

Please sign in to comment.