Skip to content

Commit

Permalink
✨ Use typing.Annotated for schema types (#158)
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed Nov 24, 2024
1 parent 7d0adf1 commit de06aec
Show file tree
Hide file tree
Showing 26 changed files with 254 additions and 207 deletions.
45 changes: 21 additions & 24 deletions flama/ddd/repositories/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import httpx

from flama import types
from flama.ddd import exceptions
from flama.ddd.repositories import AbstractRepository

Expand Down Expand Up @@ -34,7 +33,7 @@ def __eq__(self, other):
isinstance(other, HTTPResourceManager) and self._client == other._client and self.resource == other.resource
)

async def create(self, data: t.Union[dict[str, t.Any], types.Schema]) -> types.Schema:
async def create(self, data: dict[str, t.Any]) -> dict[str, t.Any]:
"""Create a new element in the collection.
:param data: The data to create the element.
Expand All @@ -49,9 +48,9 @@ async def create(self, data: t.Union[dict[str, t.Any], types.Schema]) -> types.S
raise exceptions.IntegrityError()
raise

return types.Schema(response.json())
return response.json()

async def retrieve(self, id: t.Union[str, uuid.UUID]) -> types.Schema:
async def retrieve(self, id: t.Union[str, uuid.UUID]) -> dict[str, t.Any]:
"""Retrieve an element from the collection.
:param id: The id of the element.
Expand All @@ -66,9 +65,9 @@ async def retrieve(self, id: t.Union[str, uuid.UUID]) -> types.Schema:
raise exceptions.NotFoundError()
raise

return types.Schema(response.json())
return response.json()

async def update(self, id: t.Union[str, uuid.UUID], data: t.Union[dict[str, t.Any], types.Schema]) -> types.Schema:
async def update(self, id: t.Union[str, uuid.UUID], data: dict[str, t.Any]) -> dict[str, t.Any]:
"""Update an element in the collection.
:param id: The id of the element.
Expand All @@ -86,11 +85,9 @@ async def update(self, id: t.Union[str, uuid.UUID], data: t.Union[dict[str, t.An
if e.response.status_code == http.HTTPStatus.BAD_REQUEST:
raise exceptions.IntegrityError()
raise
return types.Schema(response.json())
return response.json()

async def partial_update(
self, id: t.Union[str, uuid.UUID], data: t.Union[dict[str, t.Any], types.Schema]
) -> types.Schema:
async def partial_update(self, id: t.Union[str, uuid.UUID], data: dict[str, t.Any]) -> dict[str, t.Any]:
"""Partially update an element in the collection.
:param id: The id of the element.
Expand All @@ -108,7 +105,7 @@ async def partial_update(
if e.response.status_code == http.HTTPStatus.BAD_REQUEST:
raise exceptions.IntegrityError()
raise
return types.Schema(response.json())
return response.json()

async def delete(self, id: t.Union[str, uuid.UUID]) -> None:
"""Delete an element from the collection.
Expand Down Expand Up @@ -168,7 +165,7 @@ async def _limit_offset_paginated(self) -> t.AsyncIterable[dict[str, t.Any]]:
except exceptions.Empty:
break

async def list(self, *, pagination: str = "page_number") -> t.AsyncIterable[types.Schema]:
async def list(self, *, pagination: str = "page_number") -> t.AsyncIterable[dict[str, t.Any]]:
"""List all the elements in the collection.
:param pagination: The pagination technique.
Expand All @@ -178,9 +175,9 @@ async def list(self, *, pagination: str = "page_number") -> t.AsyncIterable[type
iterator = self._page_number_paginated() if pagination == "page_number" else self._limit_offset_paginated()

async for element in iterator:
yield types.Schema(element)
yield element

async def replace(self, data: builtins.list[dict[str, t.Any]]) -> builtins.list[types.Schema]:
async def replace(self, data: builtins.list[dict[str, t.Any]]) -> builtins.list[dict[str, t.Any]]:
"""Replace elements in the collection.
:param data: The data to replace the elements.
Expand All @@ -194,9 +191,9 @@ async def replace(self, data: builtins.list[dict[str, t.Any]]) -> builtins.list[
raise exceptions.IntegrityError()
raise

return [types.Schema(element) for element in response.json()]
return [element for element in response.json()]

async def partial_replace(self, data: builtins.list[dict[str, t.Any]]) -> builtins.list[types.Schema]:
async def partial_replace(self, data: builtins.list[dict[str, t.Any]]) -> builtins.list[dict[str, t.Any]]:
"""Partially replace elements in the collection.
:param data: The data to replace the elements.
Expand All @@ -210,7 +207,7 @@ async def partial_replace(self, data: builtins.list[dict[str, t.Any]]) -> builti
raise exceptions.IntegrityError()
raise

return [types.Schema(element) for element in response.json()]
return [element for element in response.json()]

async def drop(self) -> int:
"""Drop the collection.
Expand Down Expand Up @@ -238,23 +235,23 @@ def __init__(self, client: "Client"):
super().__init__(client)
self._resource_manager = HTTPResourceManager(self._resource, client)

async def create(self, data: dict[str, t.Any]) -> types.Schema:
async def create(self, data: dict[str, t.Any]) -> dict[str, t.Any]:
"""Create a new element in the collection.
:param data: The data to create the element.
:return: The element created.
"""
return await self._resource_manager.create(data)

async def retrieve(self, id: uuid.UUID) -> types.Schema:
async def retrieve(self, id: uuid.UUID) -> dict[str, t.Any]:
"""Retrieve an element from the collection.
:param id: The id of the element.
:return: The element retrieved.
"""
return await self._resource_manager.retrieve(id)

async def update(self, id: uuid.UUID, data: dict[str, t.Any]) -> types.Schema:
async def update(self, id: uuid.UUID, data: dict[str, t.Any]) -> dict[str, t.Any]:
"""Update an element in the collection.
:param id: The id of the element.
Expand All @@ -263,7 +260,7 @@ async def update(self, id: uuid.UUID, data: dict[str, t.Any]) -> types.Schema:
"""
return await self._resource_manager.update(id, data)

async def partial_update(self, id: uuid.UUID, data: dict[str, t.Any]) -> types.Schema:
async def partial_update(self, id: uuid.UUID, data: dict[str, t.Any]) -> dict[str, t.Any]:
"""Partially update an element in the collection.
:param id: The id of the element.
Expand All @@ -279,23 +276,23 @@ async def delete(self, id: uuid.UUID) -> None:
"""
return await self._resource_manager.delete(id)

def list(self, *, pagination: str = "page_number") -> t.AsyncIterable[types.Schema]:
def list(self, *, pagination: str = "page_number") -> t.AsyncIterable[dict[str, t.Any]]:
"""List all the elements in the collection.
:param pagination: The pagination technique.
:return: Async iterable of the elements.
"""
return self._resource_manager.list(pagination=pagination)

async def replace(self, data: builtins.list[dict[str, t.Any]]) -> builtins.list[types.Schema]:
async def replace(self, data: builtins.list[dict[str, t.Any]]) -> builtins.list[dict[str, t.Any]]:
"""Replace elements in the collection.
:param data: The data to replace the elements.
:return: The elements replaced.
"""
return await self._resource_manager.replace(data)

async def partial_replace(self, data: builtins.list[dict[str, t.Any]]) -> builtins.list[types.Schema]:
async def partial_replace(self, data: builtins.list[dict[str, t.Any]]) -> builtins.list[dict[str, t.Any]]:
"""Partially replace elements in the collection.
:param data: The data to replace the elements.
Expand Down
25 changes: 12 additions & 13 deletions flama/ddd/repositories/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import typing as t

from flama import types
from flama.ddd import exceptions
from flama.ddd.repositories import AbstractRepository

Expand Down Expand Up @@ -41,7 +40,7 @@ def __eq__(self, other):
and self.table == other.table
)

async def create(self, *data: t.Union[dict[str, t.Any], types.Schema]) -> list[types.Schema]:
async def create(self, *data: dict[str, t.Any]) -> list[dict[str, t.Any]]:
"""Creates new elements in the table.
If the element already exists, it raises an `IntegrityError`. If the element is created, it returns
Expand All @@ -55,9 +54,9 @@ async def create(self, *data: t.Union[dict[str, t.Any], types.Schema]) -> list[t
result = await self._connection.execute(sqlalchemy.insert(self.table).values(data).returning(self.table))
except sqlalchemy.exc.IntegrityError as e:
raise exceptions.IntegrityError(str(e))
return [types.Schema(element._asdict()) for element in result]
return [dict[str, t.Any](element._asdict()) for element in result]

async def retrieve(self, *clauses, **filters) -> types.Schema:
async def retrieve(self, *clauses, **filters) -> dict[str, t.Any]:
"""Retrieves an element from the table.
If the element does not exist, it raises a `NotFoundError`. If more than one element is found, it raises a
Expand All @@ -82,9 +81,9 @@ async def retrieve(self, *clauses, **filters) -> types.Schema:
except sqlalchemy.exc.MultipleResultsFound:
raise exceptions.MultipleRecordsError()

return types.Schema(element._asdict())
return dict[str, t.Any](element._asdict())

async def update(self, data: t.Union[dict[str, t.Any], types.Schema], *clauses, **filters) -> list[types.Schema]:
async def update(self, data: dict[str, t.Any], *clauses, **filters) -> list[dict[str, t.Any]]:
"""Updates elements in the table.
Using clauses and filters, it filters the elements to update. If no clauses or filters are given, it updates
Expand All @@ -103,7 +102,7 @@ async def update(self, data: t.Union[dict[str, t.Any], types.Schema], *clauses,
except sqlalchemy.exc.IntegrityError:
raise exceptions.IntegrityError

return [types.Schema(element._asdict()) for element in result]
return [dict[str, t.Any](element._asdict()) for element in result]

async def delete(self, *clauses, **filters) -> None:
"""Delete an element from the table.
Expand All @@ -127,7 +126,7 @@ async def delete(self, *clauses, **filters) -> None:

async def list(
self, *clauses, order_by: t.Optional[str] = None, order_direction: str = "asc", **filters
) -> t.AsyncIterable[types.Schema]:
) -> t.AsyncIterable[dict[str, t.Any]]:
"""Lists all the elements in the table.
If no elements are found, it returns an empty list. If no clauses or filters are given, it returns all the
Expand Down Expand Up @@ -156,7 +155,7 @@ async def list(
result = await self._connection.stream(query)

async for row in result:
yield types.Schema(row._asdict())
yield dict[str, t.Any](row._asdict())

async def drop(self, *clauses, **filters) -> int:
"""Drops elements in the table.
Expand Down Expand Up @@ -215,7 +214,7 @@ def __init__(self, connection: "AsyncConnection", *args, **kwargs):
def __eq__(self, other):
return isinstance(other, SQLAlchemyTableRepository) and self._table == other._table and super().__eq__(other)

async def create(self, *data: t.Union[dict[str, t.Any], types.Schema]) -> list[types.Schema]:
async def create(self, *data: dict[str, t.Any]) -> list[dict[str, t.Any]]:
"""Creates new elements in the repository.
If the element already exists, it raises an `exceptions.IntegrityError`. If the element is created, it returns
Expand All @@ -227,7 +226,7 @@ async def create(self, *data: t.Union[dict[str, t.Any], types.Schema]) -> list[t
"""
return await self._table_manager.create(*data)

async def retrieve(self, *clauses, **filters) -> types.Schema:
async def retrieve(self, *clauses, **filters) -> dict[str, t.Any]:
"""Retrieves an element from the repository.
If the element does not exist, it raises a `NotFoundError`. If more than one element is found, it raises a
Expand All @@ -247,7 +246,7 @@ async def retrieve(self, *clauses, **filters) -> types.Schema:
"""
return await self._table_manager.retrieve(*clauses, **filters)

async def update(self, data: t.Union[dict[str, t.Any], types.Schema], *clauses, **filters) -> list[types.Schema]:
async def update(self, data: dict[str, t.Any], *clauses, **filters) -> list[dict[str, t.Any]]:
"""Updates an element in the repository.
If the element does not exist, it raises a `NotFoundError`. If the element is updated, it returns the updated
Expand Down Expand Up @@ -280,7 +279,7 @@ async def delete(self, *clauses, **filters) -> None:

def list(
self, *clauses, order_by: t.Optional[str] = None, order_direction: str = "asc", **filters
) -> t.AsyncIterable[types.Schema]:
) -> t.AsyncIterable[dict[str, t.Any]]:
"""Lists all the elements in the repository.
Lists all the elements in the repository that match the clauses and filters. If no clauses or filters are given,
Expand Down
11 changes: 6 additions & 5 deletions flama/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,6 @@ async def __call__( # type: ignore[override]
await super().__call__(scope, receive, send) # type: ignore[arg-type]

def render(self, content: t.Any) -> bytes:
if isinstance(content, types.Schema):
content = dict(content)

return json.dumps(
content, ensure_ascii=False, allow_nan=False, indent=None, separators=(",", ":"), cls=EnhancedJSONEncoder
).encode("utf-8")
Expand Down Expand Up @@ -147,7 +144,7 @@ async def __call__( # type: ignore[override]
class APIResponse(JSONResponse):
media_type = "application/json"

def __init__(self, content: t.Any = None, schema: t.Optional[type["types.Schema"]] = None, *args, **kwargs):
def __init__(self, content: t.Any = None, schema: t.Any = None, *args, **kwargs):
self.schema = schema
super().__init__(content, *args, **kwargs)

Expand Down Expand Up @@ -182,7 +179,11 @@ def __init__(
}

super().__init__(
content, schema=types.Schema[schemas.schemas.APIError], status_code=status_code, *args, **kwargs
content,
schema=t.Annotated[schemas.Schema, schemas.SchemaMetadata(schemas.schemas.APIError)],
status_code=status_code,
*args,
**kwargs,
)

self.detail = detail
Expand Down
8 changes: 4 additions & 4 deletions flama/models/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import typing as t

import flama.schemas
from flama import types
from flama import schemas
from flama.models.components import ModelComponentBuilder
from flama.resources import data_structures
from flama.resources.exceptions import ResourceAttributeError
Expand Down Expand Up @@ -46,9 +46,9 @@ def _add_predict(cls, name: str, verbose_name: str, model_model_type: type["Mode
async def predict(
self,
model: model_model_type, # type: ignore[valid-type]
data: types.Schema[flama.schemas.schemas.MLModelInput], # type: ignore[type-arg]
) -> types.Schema[flama.schemas.schemas.MLModelOutput]: # type: ignore[type-arg]
return types.Schema({"output": model.predict(data["input"])}) # type: ignore[attr-defined]
data: t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(flama.schemas.schemas.MLModelInput)],
) -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(flama.schemas.schemas.MLModelOutput)]:
return {"output": model.predict(data["input"])}

predict.__doc__ = f"""
tags:
Expand Down
10 changes: 4 additions & 6 deletions flama/pagination/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
import inspect
import typing as t

from flama import types


class PaginationDecoratorFactory:
PARAMETERS: list[inspect.Parameter]

@classmethod
def decorate(cls, func: t.Callable, schema: type[types.Schema]) -> t.Callable:
def decorate(cls, func: t.Callable, schema: t.Any) -> t.Callable:
func_signature = inspect.signature(func)
if "kwargs" not in func_signature.parameters:
raise TypeError("Paginated views must define **kwargs param")
Expand All @@ -24,17 +22,17 @@ def decorate(cls, func: t.Callable, schema: type[types.Schema]) -> t.Callable:
*[v for k, v in func_signature.parameters.items() if k != "kwargs"],
*cls.PARAMETERS,
],
return_annotation=types.Schema[schema], # type: ignore
return_annotation=schema,
)

return decorated_func

@classmethod
@abc.abstractmethod
def _decorate_async(cls, func: t.Callable, schema: type[types.Schema]) -> t.Callable:
def _decorate_async(cls, func: t.Callable, schema: t.Any) -> t.Callable:
...

@classmethod
@abc.abstractmethod
def _decorate_sync(cls, func: t.Callable, schema: type[types.Schema]) -> t.Callable:
def _decorate_sync(cls, func: t.Callable, schema: t.Any) -> t.Callable:
...
8 changes: 4 additions & 4 deletions flama/pagination/mixins/limit_offset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import inspect
import typing as t

from flama import http, schemas, types
from flama import http, schemas

__all__ = ["LimitOffsetMixin", "LimitOffsetResponse"]

Expand All @@ -23,7 +23,7 @@ class LimitOffsetResponse(http.APIResponse):

def __init__(
self,
schema: type[types.Schema],
schema: t.Any,
offset: t.Optional[t.Union[int, str]] = None,
limit: t.Optional[t.Union[int, str]] = None,
count: t.Optional[bool] = True,
Expand Down Expand Up @@ -59,7 +59,7 @@ class LimitOffsetDecoratorFactory(PaginationDecoratorFactory):
]

@classmethod
def _decorate_async(cls, func: t.Callable, schema: type[types.Schema]) -> t.Callable:
def _decorate_async(cls, func: t.Callable, schema: t.Any) -> t.Callable:
@functools.wraps(func)
async def decorator(
*args,
Expand All @@ -75,7 +75,7 @@ async def decorator(
return decorator

@classmethod
def _decorate_sync(cls, func: t.Callable, schema: type[types.Schema]) -> t.Callable:
def _decorate_sync(cls, func: t.Callable, schema: t.Any) -> t.Callable:
@functools.wraps(func)
def decorator(
*args,
Expand Down
Loading

0 comments on commit de06aec

Please sign in to comment.