Skip to content

Commit

Permalink
🐛 DependencyNotInstalled on modules that requires sqlalchemy (#162)
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed Nov 25, 2024
1 parent 64ec041 commit e7ab2f4
Show file tree
Hide file tree
Showing 13 changed files with 70 additions and 165 deletions.
1 change: 0 additions & 1 deletion flama/ddd/repositories/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from flama.ddd.repositories.base import * # noqa
from flama.ddd.repositories.http import * # noqa
from flama.ddd.repositories.sqlalchemy import * # noqa
78 changes: 7 additions & 71 deletions flama/ddd/repositories/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,20 @@
try:
import sqlalchemy
import sqlalchemy.exc as sqlalchemy_exceptions
from sqlalchemy.ext.asyncio import AsyncConnection
except Exception: # pragma: no cover
sqlalchemy = None
sqlalchemy_exceptions = None
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy, dependant=__name__
)


if t.TYPE_CHECKING:
try:
from sqlalchemy.ext.asyncio import AsyncConnection
except Exception: # pragma: no cover
...

__all__ = ["SQLAlchemyRepository", "SQLAlchemyTableManager", "SQLAlchemyTableRepository"]


class SQLAlchemyRepository(AbstractRepository):
"""Base class for SQLAlchemy repositories. It provides a connection to the database."""

def __init__(self, connection: "AsyncConnection", *args, **kwargs):
if sqlalchemy is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
)

def __init__(self, connection: AsyncConnection, *args, **kwargs):
super().__init__(*args, **kwargs)
self._connection = connection

Expand All @@ -39,13 +29,7 @@ def __eq__(self, other):


class SQLAlchemyTableManager:
def __init__(self, table: sqlalchemy.Table, connection: "AsyncConnection"): # type: ignore
if sqlalchemy is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
)

def __init__(self, table: sqlalchemy.Table, connection: AsyncConnection): # type: ignore
self._connection = connection
self.table = table

Expand All @@ -66,12 +50,6 @@ async def create(self, *data: dict[str, t.Any]) -> list[dict[str, t.Any]]:
:return: The created elements.
:raises IntegrityError: If the element already exists or cannot be inserted.
"""
if sqlalchemy is None or sqlalchemy_exceptions is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
)

try:
result = await self._connection.execute(sqlalchemy.insert(self.table).values(data).returning(self.table))
except sqlalchemy_exceptions.IntegrityError as e:
Expand All @@ -94,12 +72,6 @@ async def retrieve(self, *clauses, **filters) -> dict[str, t.Any]:
:raises NotFoundError: If the element does not exist.
:raises MultipleRecordsError: If more than one element is found.
"""
if sqlalchemy is None or sqlalchemy_exceptions is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
)

query = self._filter_query(sqlalchemy.select(self.table), *clauses, **filters)

try:
Expand All @@ -121,12 +93,6 @@ async def update(self, data: dict[str, t.Any], *clauses, **filters) -> list[dict
:return: The updated elements.
:raises IntegrityError: If the elements cannot be updated.
"""
if sqlalchemy is None or sqlalchemy_exceptions is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
)

query = (
self._filter_query(sqlalchemy.update(self.table), *clauses, **filters).values(**data).returning(self.table)
)
Expand All @@ -152,12 +118,6 @@ async def delete(self, *clauses, **filters) -> None:
:raises NotFoundError: If the element does not exist.
:raises MultipleRecordsError: If more than one element is found.
"""
if sqlalchemy is None or sqlalchemy_exceptions is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
)

await self.retrieve(*clauses, **filters)

query = self._filter_query(sqlalchemy.delete(self.table), *clauses, **filters)
Expand Down Expand Up @@ -185,12 +145,6 @@ async def list(
:param filters: Filters to filter the elements.
:return: Async iterable of the elements.
"""
if sqlalchemy is None or sqlalchemy_exceptions is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
)

query = self._filter_query(sqlalchemy.select(self.table), *clauses, **filters)

if order_by:
Expand Down Expand Up @@ -219,12 +173,6 @@ async def drop(self, *clauses, **filters) -> int:
:param filters: Filters to filter the elements.
:return: The number of elements dropped.
"""
if sqlalchemy is None or sqlalchemy_exceptions is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
)

query = self._filter_query(sqlalchemy.delete(self.table), *clauses, **filters)

result = await self._connection.execute(query)
Expand All @@ -248,12 +196,6 @@ def _filter_query(self, query, *clauses, **filters):
:param filters: Filters to filter the elements.
:return: The filtered query.
"""
if sqlalchemy is None or sqlalchemy_exceptions is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
)

where_clauses = tuple(clauses) + tuple(self.table.c[k] == v for k, v in filters.items())

if where_clauses:
Expand All @@ -265,13 +207,7 @@ def _filter_query(self, query, *clauses, **filters):
class SQLAlchemyTableRepository(SQLAlchemyRepository):
_table: t.ClassVar[sqlalchemy.Table] # type: ignore

def __init__(self, connection: "AsyncConnection", *args, **kwargs):
if sqlalchemy is None or sqlalchemy_exceptions is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{self.__class__.__module__}.{self.__class__.__name__}",
)

def __init__(self, connection: AsyncConnection, *args, **kwargs):
super().__init__(connection, *args, **kwargs)
self._table_manager = SQLAlchemyTableManager(self._table, connection)

Expand Down
1 change: 0 additions & 1 deletion flama/ddd/workers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from flama.ddd.workers.base import * # noqa
from flama.ddd.workers.http import * # noqa
from flama.ddd.workers.sqlalchemy import * # noqa
17 changes: 11 additions & 6 deletions flama/ddd/workers/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import logging
import typing as t

from flama import exceptions
from flama.ddd.workers.base import AbstractWorker

if t.TYPE_CHECKING:
try:
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncTransaction
except Exception: # pragma: no cover
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy, dependant=__name__
)


__all__ = ["SQLAlchemyWorker"]

Expand All @@ -17,11 +22,11 @@ class SQLAlchemyWorker(AbstractWorker):
It will provide a connection and a transaction to the database and create the repositories for the entities.
"""

_connection: "AsyncConnection"
_transaction: "AsyncTransaction"
_connection: AsyncConnection
_transaction: AsyncTransaction

@property
def connection(self) -> "AsyncConnection":
def connection(self) -> AsyncConnection:
"""Connection to the database.
:return: Connection to the database.
Expand All @@ -33,7 +38,7 @@ def connection(self) -> "AsyncConnection":
raise AttributeError("Connection not initialized")

@property
def transaction(self) -> "AsyncTransaction":
def transaction(self) -> AsyncTransaction:
"""Database transaction.
:return: Database transaction.
Expand Down
20 changes: 5 additions & 15 deletions flama/resources/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import uuid

from flama import exceptions
from flama.ddd.repositories import SQLAlchemyTableRepository
from flama.ddd.repositories.sqlalchemy import SQLAlchemyTableRepository
from flama.resources import data_structures
from flama.resources.exceptions import ResourceAttributeError
from flama.resources.resource import Resource, ResourceType
Expand All @@ -12,8 +12,9 @@
import sqlalchemy
from sqlalchemy.dialects import postgresql
except Exception: # pragma: no cover
sqlalchemy = None
postgresql = None
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy, dependant=__name__
)

__all__ = ["RESTResource", "RESTResourceType"]

Expand All @@ -28,11 +29,6 @@ def __new__(mcs, name: str, bases: tuple[type], namespace: dict[str, t.Any]):
:param bases: List of superclasses.
:param namespace: Variables namespace used to create the class.
"""
if sqlalchemy is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy, dependant="RESTResourceType"
)

if not mcs._is_abstract(namespace):
try:
# Get model
Expand Down Expand Up @@ -68,12 +64,6 @@ def _get_model(cls, bases: t.Sequence[t.Any], namespace: dict[str, t.Any]) -> da
:param namespace: Variables namespace used to create the class.
:return: Resource model.
"""
if sqlalchemy is None or postgresql is None:
raise exceptions.DependencyNotInstalled(
dependency=exceptions.DependencyNotInstalled.Dependency.sqlalchemy,
dependant=f"{cls.__module__}.{cls.__name__}",
)

model = cls._get_attribute("model", bases, namespace, metadata_namespace="rest")

# Already defined model probably because resource inheritance, so no need to create it
Expand Down Expand Up @@ -154,7 +144,7 @@ def _get_schemas(cls, name: str, bases: t.Sequence[t.Any], namespace: dict[str,


class RESTResource(Resource, metaclass=RESTResourceType):
model: sqlalchemy.Table # type: ignore
model: sqlalchemy.Table
schema: t.Any
input_schema: t.Any
output_schema: t.Any
8 changes: 4 additions & 4 deletions flama/resources/workers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import typing as t

from flama.ddd import SQLAlchemyWorker
from flama.ddd.workers.sqlalchemy import SQLAlchemyWorker
from flama.exceptions import ApplicationError

if t.TYPE_CHECKING:
from flama import Flama
from flama.ddd.repositories import SQLAlchemyTableRepository
from flama.ddd.repositories.sqlalchemy import SQLAlchemyTableRepository


class FlamaWorker(SQLAlchemyWorker):
Expand All @@ -20,8 +20,8 @@ def __init__(self, app: t.Optional["Flama"] = None):
"""

super().__init__(app)
self._repositories: dict[str, type["SQLAlchemyTableRepository"]] = {} # type: ignore
self._init_repositories: t.Optional[dict[str, "SQLAlchemyTableRepository"]] = None
self._repositories: dict[str, type[SQLAlchemyTableRepository]] = {} # type: ignore
self._init_repositories: t.Optional[dict[str, SQLAlchemyTableRepository]] = None

@property
def repositories(self) -> dict[str, "SQLAlchemyTableRepository"]:
Expand Down
Loading

0 comments on commit e7ab2f4

Please sign in to comment.