diff --git a/flama/ddd/repositories/http.py b/flama/ddd/repositories/http.py index cc0c846c..28cd0bfb 100644 --- a/flama/ddd/repositories/http.py +++ b/flama/ddd/repositories/http.py @@ -5,7 +5,6 @@ import httpx -from flama import types from flama.ddd import exceptions from flama.ddd.repositories import AbstractRepository @@ -34,7 +33,7 @@ def __eq__(self, other): isinstance(other, HTTPResourceManager) and self._client == other._client and self.resource == other.resource ) - async def create(self, data: t.Union[dict[str, t.Any], types.Schema]) -> types.Schema: + async def create(self, data: dict[str, t.Any]) -> dict[str, t.Any]: """Create a new element in the collection. :param data: The data to create the element. @@ -49,9 +48,9 @@ async def create(self, data: t.Union[dict[str, t.Any], types.Schema]) -> types.S raise exceptions.IntegrityError() raise - return types.Schema(response.json()) + return response.json() - async def retrieve(self, id: t.Union[str, uuid.UUID]) -> types.Schema: + async def retrieve(self, id: t.Union[str, uuid.UUID]) -> dict[str, t.Any]: """Retrieve an element from the collection. :param id: The id of the element. @@ -66,9 +65,9 @@ async def retrieve(self, id: t.Union[str, uuid.UUID]) -> types.Schema: raise exceptions.NotFoundError() raise - return types.Schema(response.json()) + return response.json() - async def update(self, id: t.Union[str, uuid.UUID], data: t.Union[dict[str, t.Any], types.Schema]) -> types.Schema: + async def update(self, id: t.Union[str, uuid.UUID], data: dict[str, t.Any]) -> dict[str, t.Any]: """Update an element in the collection. :param id: The id of the element. @@ -86,11 +85,9 @@ async def update(self, id: t.Union[str, uuid.UUID], data: t.Union[dict[str, t.An if e.response.status_code == http.HTTPStatus.BAD_REQUEST: raise exceptions.IntegrityError() raise - return types.Schema(response.json()) + return response.json() - async def partial_update( - self, id: t.Union[str, uuid.UUID], data: t.Union[dict[str, t.Any], types.Schema] - ) -> types.Schema: + async def partial_update(self, id: t.Union[str, uuid.UUID], data: dict[str, t.Any]) -> dict[str, t.Any]: """Partially update an element in the collection. :param id: The id of the element. @@ -108,7 +105,7 @@ async def partial_update( if e.response.status_code == http.HTTPStatus.BAD_REQUEST: raise exceptions.IntegrityError() raise - return types.Schema(response.json()) + return response.json() async def delete(self, id: t.Union[str, uuid.UUID]) -> None: """Delete an element from the collection. @@ -168,7 +165,7 @@ async def _limit_offset_paginated(self) -> t.AsyncIterable[dict[str, t.Any]]: except exceptions.Empty: break - async def list(self, *, pagination: str = "page_number") -> t.AsyncIterable[types.Schema]: + async def list(self, *, pagination: str = "page_number") -> t.AsyncIterable[dict[str, t.Any]]: """List all the elements in the collection. :param pagination: The pagination technique. @@ -178,9 +175,9 @@ async def list(self, *, pagination: str = "page_number") -> t.AsyncIterable[type iterator = self._page_number_paginated() if pagination == "page_number" else self._limit_offset_paginated() async for element in iterator: - yield types.Schema(element) + yield element - async def replace(self, data: builtins.list[dict[str, t.Any]]) -> builtins.list[types.Schema]: + async def replace(self, data: builtins.list[dict[str, t.Any]]) -> builtins.list[dict[str, t.Any]]: """Replace elements in the collection. :param data: The data to replace the elements. @@ -194,9 +191,9 @@ async def replace(self, data: builtins.list[dict[str, t.Any]]) -> builtins.list[ raise exceptions.IntegrityError() raise - return [types.Schema(element) for element in response.json()] + return [element for element in response.json()] - async def partial_replace(self, data: builtins.list[dict[str, t.Any]]) -> builtins.list[types.Schema]: + async def partial_replace(self, data: builtins.list[dict[str, t.Any]]) -> builtins.list[dict[str, t.Any]]: """Partially replace elements in the collection. :param data: The data to replace the elements. @@ -210,7 +207,7 @@ async def partial_replace(self, data: builtins.list[dict[str, t.Any]]) -> builti raise exceptions.IntegrityError() raise - return [types.Schema(element) for element in response.json()] + return [element for element in response.json()] async def drop(self) -> int: """Drop the collection. @@ -238,7 +235,7 @@ def __init__(self, client: "Client"): super().__init__(client) self._resource_manager = HTTPResourceManager(self._resource, client) - async def create(self, data: dict[str, t.Any]) -> types.Schema: + async def create(self, data: dict[str, t.Any]) -> dict[str, t.Any]: """Create a new element in the collection. :param data: The data to create the element. @@ -246,7 +243,7 @@ async def create(self, data: dict[str, t.Any]) -> types.Schema: """ return await self._resource_manager.create(data) - async def retrieve(self, id: uuid.UUID) -> types.Schema: + async def retrieve(self, id: uuid.UUID) -> dict[str, t.Any]: """Retrieve an element from the collection. :param id: The id of the element. @@ -254,7 +251,7 @@ async def retrieve(self, id: uuid.UUID) -> types.Schema: """ return await self._resource_manager.retrieve(id) - async def update(self, id: uuid.UUID, data: dict[str, t.Any]) -> types.Schema: + async def update(self, id: uuid.UUID, data: dict[str, t.Any]) -> dict[str, t.Any]: """Update an element in the collection. :param id: The id of the element. @@ -263,7 +260,7 @@ async def update(self, id: uuid.UUID, data: dict[str, t.Any]) -> types.Schema: """ return await self._resource_manager.update(id, data) - async def partial_update(self, id: uuid.UUID, data: dict[str, t.Any]) -> types.Schema: + async def partial_update(self, id: uuid.UUID, data: dict[str, t.Any]) -> dict[str, t.Any]: """Partially update an element in the collection. :param id: The id of the element. @@ -279,7 +276,7 @@ async def delete(self, id: uuid.UUID) -> None: """ return await self._resource_manager.delete(id) - def list(self, *, pagination: str = "page_number") -> t.AsyncIterable[types.Schema]: + def list(self, *, pagination: str = "page_number") -> t.AsyncIterable[dict[str, t.Any]]: """List all the elements in the collection. :param pagination: The pagination technique. @@ -287,7 +284,7 @@ def list(self, *, pagination: str = "page_number") -> t.AsyncIterable[types.Sche """ return self._resource_manager.list(pagination=pagination) - async def replace(self, data: builtins.list[dict[str, t.Any]]) -> builtins.list[types.Schema]: + async def replace(self, data: builtins.list[dict[str, t.Any]]) -> builtins.list[dict[str, t.Any]]: """Replace elements in the collection. :param data: The data to replace the elements. @@ -295,7 +292,7 @@ async def replace(self, data: builtins.list[dict[str, t.Any]]) -> builtins.list[ """ return await self._resource_manager.replace(data) - async def partial_replace(self, data: builtins.list[dict[str, t.Any]]) -> builtins.list[types.Schema]: + async def partial_replace(self, data: builtins.list[dict[str, t.Any]]) -> builtins.list[dict[str, t.Any]]: """Partially replace elements in the collection. :param data: The data to replace the elements. diff --git a/flama/ddd/repositories/sqlalchemy.py b/flama/ddd/repositories/sqlalchemy.py index c5ccd654..e31c9c57 100644 --- a/flama/ddd/repositories/sqlalchemy.py +++ b/flama/ddd/repositories/sqlalchemy.py @@ -1,6 +1,5 @@ import typing as t -from flama import types from flama.ddd import exceptions from flama.ddd.repositories import AbstractRepository @@ -41,7 +40,7 @@ def __eq__(self, other): and self.table == other.table ) - async def create(self, *data: t.Union[dict[str, t.Any], types.Schema]) -> list[types.Schema]: + async def create(self, *data: dict[str, t.Any]) -> list[dict[str, t.Any]]: """Creates new elements in the table. If the element already exists, it raises an `IntegrityError`. If the element is created, it returns @@ -55,9 +54,9 @@ async def create(self, *data: t.Union[dict[str, t.Any], types.Schema]) -> list[t result = await self._connection.execute(sqlalchemy.insert(self.table).values(data).returning(self.table)) except sqlalchemy.exc.IntegrityError as e: raise exceptions.IntegrityError(str(e)) - return [types.Schema(element._asdict()) for element in result] + return [dict[str, t.Any](element._asdict()) for element in result] - async def retrieve(self, *clauses, **filters) -> types.Schema: + async def retrieve(self, *clauses, **filters) -> dict[str, t.Any]: """Retrieves an element from the table. If the element does not exist, it raises a `NotFoundError`. If more than one element is found, it raises a @@ -82,9 +81,9 @@ async def retrieve(self, *clauses, **filters) -> types.Schema: except sqlalchemy.exc.MultipleResultsFound: raise exceptions.MultipleRecordsError() - return types.Schema(element._asdict()) + return dict[str, t.Any](element._asdict()) - async def update(self, data: t.Union[dict[str, t.Any], types.Schema], *clauses, **filters) -> list[types.Schema]: + async def update(self, data: dict[str, t.Any], *clauses, **filters) -> list[dict[str, t.Any]]: """Updates elements in the table. Using clauses and filters, it filters the elements to update. If no clauses or filters are given, it updates @@ -103,7 +102,7 @@ async def update(self, data: t.Union[dict[str, t.Any], types.Schema], *clauses, except sqlalchemy.exc.IntegrityError: raise exceptions.IntegrityError - return [types.Schema(element._asdict()) for element in result] + return [dict[str, t.Any](element._asdict()) for element in result] async def delete(self, *clauses, **filters) -> None: """Delete an element from the table. @@ -127,7 +126,7 @@ async def delete(self, *clauses, **filters) -> None: async def list( self, *clauses, order_by: t.Optional[str] = None, order_direction: str = "asc", **filters - ) -> t.AsyncIterable[types.Schema]: + ) -> t.AsyncIterable[dict[str, t.Any]]: """Lists all the elements in the table. If no elements are found, it returns an empty list. If no clauses or filters are given, it returns all the @@ -156,7 +155,7 @@ async def list( result = await self._connection.stream(query) async for row in result: - yield types.Schema(row._asdict()) + yield dict[str, t.Any](row._asdict()) async def drop(self, *clauses, **filters) -> int: """Drops elements in the table. @@ -215,7 +214,7 @@ def __init__(self, connection: "AsyncConnection", *args, **kwargs): def __eq__(self, other): return isinstance(other, SQLAlchemyTableRepository) and self._table == other._table and super().__eq__(other) - async def create(self, *data: t.Union[dict[str, t.Any], types.Schema]) -> list[types.Schema]: + async def create(self, *data: dict[str, t.Any]) -> list[dict[str, t.Any]]: """Creates new elements in the repository. If the element already exists, it raises an `exceptions.IntegrityError`. If the element is created, it returns @@ -227,7 +226,7 @@ async def create(self, *data: t.Union[dict[str, t.Any], types.Schema]) -> list[t """ return await self._table_manager.create(*data) - async def retrieve(self, *clauses, **filters) -> types.Schema: + async def retrieve(self, *clauses, **filters) -> dict[str, t.Any]: """Retrieves an element from the repository. If the element does not exist, it raises a `NotFoundError`. If more than one element is found, it raises a @@ -247,7 +246,7 @@ async def retrieve(self, *clauses, **filters) -> types.Schema: """ return await self._table_manager.retrieve(*clauses, **filters) - async def update(self, data: t.Union[dict[str, t.Any], types.Schema], *clauses, **filters) -> list[types.Schema]: + async def update(self, data: dict[str, t.Any], *clauses, **filters) -> list[dict[str, t.Any]]: """Updates an element in the repository. If the element does not exist, it raises a `NotFoundError`. If the element is updated, it returns the updated @@ -280,7 +279,7 @@ async def delete(self, *clauses, **filters) -> None: def list( self, *clauses, order_by: t.Optional[str] = None, order_direction: str = "asc", **filters - ) -> t.AsyncIterable[types.Schema]: + ) -> t.AsyncIterable[dict[str, t.Any]]: """Lists all the elements in the repository. Lists all the elements in the repository that match the clauses and filters. If no clauses or filters are given, diff --git a/flama/http.py b/flama/http.py index 12b7a252..926a2412 100644 --- a/flama/http.py +++ b/flama/http.py @@ -115,9 +115,6 @@ async def __call__( # type: ignore[override] await super().__call__(scope, receive, send) # type: ignore[arg-type] def render(self, content: t.Any) -> bytes: - if isinstance(content, types.Schema): - content = dict(content) - return json.dumps( content, ensure_ascii=False, allow_nan=False, indent=None, separators=(",", ":"), cls=EnhancedJSONEncoder ).encode("utf-8") @@ -147,7 +144,7 @@ async def __call__( # type: ignore[override] class APIResponse(JSONResponse): media_type = "application/json" - def __init__(self, content: t.Any = None, schema: t.Optional[type["types.Schema"]] = None, *args, **kwargs): + def __init__(self, content: t.Any = None, schema: t.Any = None, *args, **kwargs): self.schema = schema super().__init__(content, *args, **kwargs) @@ -182,7 +179,11 @@ def __init__( } super().__init__( - content, schema=types.Schema[schemas.schemas.APIError], status_code=status_code, *args, **kwargs + content, + schema=t.Annotated[schemas.Schema, schemas.SchemaMetadata(schemas.schemas.APIError)], + status_code=status_code, + *args, + **kwargs, ) self.detail = detail diff --git a/flama/models/resource.py b/flama/models/resource.py index 89b39673..8488975a 100644 --- a/flama/models/resource.py +++ b/flama/models/resource.py @@ -2,7 +2,7 @@ import typing as t import flama.schemas -from flama import types +from flama import schemas from flama.models.components import ModelComponentBuilder from flama.resources import data_structures from flama.resources.exceptions import ResourceAttributeError @@ -46,9 +46,9 @@ def _add_predict(cls, name: str, verbose_name: str, model_model_type: type["Mode async def predict( self, model: model_model_type, # type: ignore[valid-type] - data: types.Schema[flama.schemas.schemas.MLModelInput], # type: ignore[type-arg] - ) -> types.Schema[flama.schemas.schemas.MLModelOutput]: # type: ignore[type-arg] - return types.Schema({"output": model.predict(data["input"])}) # type: ignore[attr-defined] + data: t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(flama.schemas.schemas.MLModelInput)], + ) -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(flama.schemas.schemas.MLModelOutput)]: + return {"output": model.predict(data["input"])} predict.__doc__ = f""" tags: diff --git a/flama/pagination/decorators.py b/flama/pagination/decorators.py index 3a5ccf1a..75061be7 100644 --- a/flama/pagination/decorators.py +++ b/flama/pagination/decorators.py @@ -3,14 +3,12 @@ import inspect import typing as t -from flama import types - class PaginationDecoratorFactory: PARAMETERS: list[inspect.Parameter] @classmethod - def decorate(cls, func: t.Callable, schema: type[types.Schema]) -> t.Callable: + def decorate(cls, func: t.Callable, schema: t.Any) -> t.Callable: func_signature = inspect.signature(func) if "kwargs" not in func_signature.parameters: raise TypeError("Paginated views must define **kwargs param") @@ -24,17 +22,17 @@ def decorate(cls, func: t.Callable, schema: type[types.Schema]) -> t.Callable: *[v for k, v in func_signature.parameters.items() if k != "kwargs"], *cls.PARAMETERS, ], - return_annotation=types.Schema[schema], # type: ignore + return_annotation=schema, ) return decorated_func @classmethod @abc.abstractmethod - def _decorate_async(cls, func: t.Callable, schema: type[types.Schema]) -> t.Callable: + def _decorate_async(cls, func: t.Callable, schema: t.Any) -> t.Callable: ... @classmethod @abc.abstractmethod - def _decorate_sync(cls, func: t.Callable, schema: type[types.Schema]) -> t.Callable: + def _decorate_sync(cls, func: t.Callable, schema: t.Any) -> t.Callable: ... diff --git a/flama/pagination/mixins/limit_offset.py b/flama/pagination/mixins/limit_offset.py index f84b18c9..7395edfe 100644 --- a/flama/pagination/mixins/limit_offset.py +++ b/flama/pagination/mixins/limit_offset.py @@ -2,7 +2,7 @@ import inspect import typing as t -from flama import http, schemas, types +from flama import http, schemas __all__ = ["LimitOffsetMixin", "LimitOffsetResponse"] @@ -23,7 +23,7 @@ class LimitOffsetResponse(http.APIResponse): def __init__( self, - schema: type[types.Schema], + schema: t.Any, offset: t.Optional[t.Union[int, str]] = None, limit: t.Optional[t.Union[int, str]] = None, count: t.Optional[bool] = True, @@ -59,7 +59,7 @@ class LimitOffsetDecoratorFactory(PaginationDecoratorFactory): ] @classmethod - def _decorate_async(cls, func: t.Callable, schema: type[types.Schema]) -> t.Callable: + def _decorate_async(cls, func: t.Callable, schema: t.Any) -> t.Callable: @functools.wraps(func) async def decorator( *args, @@ -75,7 +75,7 @@ async def decorator( return decorator @classmethod - def _decorate_sync(cls, func: t.Callable, schema: type[types.Schema]) -> t.Callable: + def _decorate_sync(cls, func: t.Callable, schema: t.Any) -> t.Callable: @functools.wraps(func) def decorator( *args, diff --git a/flama/pagination/mixins/page_number.py b/flama/pagination/mixins/page_number.py index d06520ef..81567747 100644 --- a/flama/pagination/mixins/page_number.py +++ b/flama/pagination/mixins/page_number.py @@ -2,7 +2,7 @@ import inspect import typing as t -from flama import http, schemas, types +from flama import http, schemas __all__ = ["PageNumberMixin", "PageNumberResponse"] @@ -25,7 +25,7 @@ class PageNumberResponse(http.APIResponse): def __init__( self, - schema: type[types.Schema], + schema: t.Any, page: t.Optional[t.Union[int, str]] = None, page_size: t.Optional[t.Union[int, str]] = None, count: t.Optional[bool] = True, @@ -66,7 +66,7 @@ class PageNumberDecoratorFactory(PaginationDecoratorFactory): ] @classmethod - def _decorate_async(cls, func: t.Callable, schema: type[types.Schema]) -> t.Callable: + def _decorate_async(cls, func: t.Callable, schema: t.Any) -> t.Callable: @functools.wraps(func) async def decorator( *args, @@ -82,7 +82,7 @@ async def decorator( return decorator @classmethod - def _decorate_sync(cls, func: t.Callable, schema: type[types.Schema]) -> t.Callable: + def _decorate_sync(cls, func: t.Callable, schema: t.Any) -> t.Callable: @functools.wraps(func) def decorator( *args, diff --git a/flama/resources/crud.py b/flama/resources/crud.py index e1be1358..5b6c2a48 100644 --- a/flama/resources/crud.py +++ b/flama/resources/crud.py @@ -1,7 +1,7 @@ import typing as t from http import HTTPStatus -from flama import exceptions, http, schemas, types +from flama import exceptions, http, schemas from flama.ddd import exceptions as ddd_exceptions from flama.resources import data_structures from flama.resources.rest import RESTResource, RESTResourceType @@ -25,8 +25,8 @@ def _add_create( async def create( self, worker: FlamaWorker, - resource: types.Schema[rest_schemas.input.schema], # type: ignore - ) -> types.Schema[rest_schemas.output.schema]: # type: ignore + resource: t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(rest_schemas.input.schema)], + ) -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(rest_schemas.output.schema)]: if resource.get(rest_model.primary_key.name) is None: resource.pop(rest_model.primary_key.name, None) @@ -77,7 +77,7 @@ async def retrieve( self, worker: FlamaWorker, resource_id: rest_model.primary_key.type, # type: ignore - ) -> types.Schema[rest_schemas.output.schema]: # type: ignore + ) -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(rest_schemas.output.schema)]: # type: ignore try: async with worker: repository = worker.repositories[self._meta.name] @@ -119,8 +119,8 @@ async def update( self, worker: FlamaWorker, resource_id: rest_model.primary_key.type, # type: ignore - resource: types.Schema[rest_schemas.input.schema], # type: ignore - ) -> types.Schema[rest_schemas.output.schema]: # type: ignore + resource: t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(rest_schemas.input.schema)], + ) -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(rest_schemas.output.schema)]: resource[rest_model.primary_key.name] = resource_id async with worker: try: @@ -134,7 +134,7 @@ async def update( except ddd_exceptions.IntegrityError: raise exceptions.HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail="Wrong input data") - return types.Schema[rest_schemas.output.schema](result[0]) + return result[0] update.__doc__ = f""" tags: @@ -173,8 +173,8 @@ async def partial_update( self, worker: FlamaWorker, resource_id: rest_model.primary_key.type, # type: ignore - resource: types.PartialSchema[rest_schemas.input.schema], # type: ignore - ) -> types.Schema[rest_schemas.output.schema]: # type: ignore + resource: t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(rest_schemas.input.schema, partial=True)], + ) -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(rest_schemas.output.schema)]: resource[rest_model.primary_key.name] = resource_id async with worker: repository = worker.repositories[self._meta.name] @@ -186,7 +186,7 @@ async def partial_update( if not result: raise exceptions.HTTPException(status_code=HTTPStatus.NOT_FOUND) - return types.Schema[rest_schemas.output.schema](result[0]) + return result[0] partial_update.__doc__ = f""" tags: @@ -256,7 +256,7 @@ async def list( order_by: t.Optional[str] = None, order_direction: str = "asc", **kwargs, - ) -> types.Schema[rest_schemas.output.schema]: # type: ignore + ) -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(rest_schemas.output.schema)]: async with worker: repository = worker.repositories[self._meta.name] return [ # type: ignore[return-value] @@ -293,8 +293,8 @@ def _add_replace( async def replace( self, worker: FlamaWorker, - resources: list[types.Schema[rest_schemas.input.schema]], # type: ignore - ) -> list[types.Schema[rest_schemas.output.schema]]: # type: ignore + resources: t.Annotated[list[schemas.SchemaType], schemas.SchemaMetadata(rest_schemas.input.schema)], + ) -> t.Annotated[list[schemas.SchemaType], schemas.SchemaMetadata(rest_schemas.output.schema)]: async with worker: repository = worker.repositories[self._meta.name] await repository.drop() @@ -336,8 +336,8 @@ def _add_partial_replace( async def partial_replace( self, worker: FlamaWorker, - resources: list[types.Schema[rest_schemas.input.schema]], # type: ignore - ) -> list[types.Schema[rest_schemas.output.schema]]: # type: ignore + resources: t.Annotated[list[schemas.SchemaType], schemas.SchemaMetadata(rest_schemas.input.schema)], + ) -> t.Annotated[list[schemas.SchemaType], schemas.SchemaMetadata(rest_schemas.output.schema)]: async with worker: repository = worker.repositories[self._meta.name] await repository.drop( @@ -373,7 +373,9 @@ class DropMixin: @classmethod def _add_drop(cls, name: str, verbose_name: str, **kwargs) -> dict[str, t.Any]: @resource_method("/", methods=["DELETE"], name="drop") - async def drop(self, worker: FlamaWorker) -> types.Schema[schemas.schemas.DropCollection]: # type: ignore + async def drop( + self, worker: FlamaWorker + ) -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(schemas.schemas.DropCollection)]: async with worker: repository = worker.repositories[self._meta.name] result = await repository.drop() diff --git a/flama/routing.py b/flama/routing.py index c042bc36..7470afd1 100644 --- a/flama/routing.py +++ b/flama/routing.py @@ -177,7 +177,7 @@ def _build_api_response(self, handler: t.Callable, response: t.Union[http.Respon :param response: The current response. :return: An API response. """ - if isinstance(response, (types.Schema, dict, list)): + if isinstance(response, (dict, list)): try: schema = schemas.Schema.from_type(inspect.signature(handler).return_annotation).unique_schema except Exception: diff --git a/flama/schemas/__init__.py b/flama/schemas/__init__.py index b126eed8..46771c49 100644 --- a/flama/schemas/__init__.py +++ b/flama/schemas/__init__.py @@ -5,6 +5,14 @@ from flama.schemas.data_structures import Field, Parameter, Schema from flama.schemas.exceptions import SchemaParseError, SchemaValidationError +from flama.schemas.types import ( + SchemaMetadata, + SchemaType, + get_schema_metadata, + is_schema, + is_schema_multiple, + is_schema_partial, +) if t.TYPE_CHECKING: from flama.schemas.adapter import Adapter @@ -18,6 +26,12 @@ "fields", "lib", "schemas", + "SchemaMetadata", + "SchemaType", + "get_schema_metadata", + "is_schema", + "is_schema_multiple", + "is_schema_partial", ] adapter: "Adapter" diff --git a/flama/schemas/adapter.py b/flama/schemas/adapter.py index bf03b77d..aba52d60 100644 --- a/flama/schemas/adapter.py +++ b/flama/schemas/adapter.py @@ -2,8 +2,8 @@ import sys import typing as t +from flama.schemas.types import _T_Field, _T_Schema from flama.types import JSONSchema -from flama.types.schema import _T_Field, _T_Schema if sys.version_info < (3, 10): # PORT: Remove when stop supporting 3.9 # pragma: no cover from typing_extensions import TypeGuard diff --git a/flama/schemas/data_structures.py b/flama/schemas/data_structures.py index fc02c500..023ff944 100644 --- a/flama/schemas/data_structures.py +++ b/flama/schemas/data_structures.py @@ -50,7 +50,7 @@ class Field: def __post_init__(self) -> None: object.__setattr__(self, "nullable", type(None) in t.get_args(self.type) or self.default is None) - field_type = t.get_args(self.type)[0] if t.get_origin(self.type) in (list, t.Union) else self.type + field_type = t.get_args(self.type)[0] if t.get_origin(self.type) in UnionType else self.type if not Schema.is_schema(field_type) and self.multiple is None: object.__setattr__(self, "multiple", t.get_origin(self.type) is list) @@ -106,16 +106,13 @@ class Schema: @classmethod def from_type(cls, type_: t.Optional[type]) -> "Schema": - if types.Schema.is_schema(type_): - schema = ( - type_.schema - if not type_.partial - else schemas.adapter.build_schema( - name=schemas.adapter.name(type_.schema, prefix="Partial").rsplit(".", 1)[1], - schema=type_.schema, - partial=True, + if schemas.is_schema(type_): + schema = schemas.get_schema_metadata(type_).schema + + if schemas.is_schema_partial(type_): + schema = schemas.adapter.build_schema( + name=schemas.adapter.name(schema, prefix="Partial").rsplit(".", 1)[1], schema=schema, partial=True ) - ) elif t.get_origin(type_) in (list, tuple, set): return cls.from_type(t.get_args(type_)[0]) else: @@ -255,7 +252,7 @@ def dump(self, values): class Parameter: name: str location: ParameterLocation - type: type + type: t.Any required: bool = True default: t.Any = InjectionParameter.empty nullable: bool = dataclasses.field(init=False) diff --git a/flama/schemas/generator.py b/flama/schemas/generator.py index 8dc6c16f..2aa60aa9 100644 --- a/flama/schemas/generator.py +++ b/flama/schemas/generator.py @@ -31,7 +31,7 @@ class EndpointInfo: @dataclasses.dataclass(frozen=True) class SchemaInfo: name: str - schema: types.Schema + schema: t.Any @property def ref(self) -> str: diff --git a/flama/schemas/routing.py b/flama/schemas/routing.py index 0d70b312..85c8cd70 100644 --- a/flama/schemas/routing.py +++ b/flama/schemas/routing.py @@ -1,7 +1,7 @@ import inspect import typing as t -from flama import types +from flama import schemas from flama.injection.resolver import Return from flama.schemas.data_structures import Field, Parameter, Parameters @@ -68,7 +68,7 @@ def body(self) -> dict[str, t.Optional[Parameter]]: ( Parameter.build("body", p) for p in parameters - if (types.Schema.is_schema(p.annotation) or t.get_origin(p.annotation) == list) + if (schemas.is_schema(p.annotation) or t.get_origin(p.annotation) == list) and p.name not in self._route.path.parameters ), None, diff --git a/flama/schemas/types.py b/flama/schemas/types.py new file mode 100644 index 00000000..e89cbf4e --- /dev/null +++ b/flama/schemas/types.py @@ -0,0 +1,41 @@ +import dataclasses +import typing as t + +__all__ = [ + "_T_Field", + "_T_Schema", + "SchemaType", + "SchemaMetadata", + "get_schema_metadata", + "is_schema", + "is_schema_partial", + "is_schema_multiple", +] + +_T_Field = t.TypeVar("_T_Field") +_T_Schema = t.TypeVar("_T_Schema") + +SchemaType = dict[str, t.Any] + + +@dataclasses.dataclass(frozen=True) +class SchemaMetadata: + schema: t.Any + partial: bool = False + multiple: bool = False + + +def get_schema_metadata(obj: t.Any) -> SchemaMetadata: + return getattr(obj, "__metadata__", [None])[0] + + +def is_schema(obj: t.Any) -> bool: + return isinstance(get_schema_metadata(obj), SchemaMetadata) + + +def is_schema_partial(obj: t.Any) -> bool: + return is_schema(obj) and get_schema_metadata(obj).partial + + +def is_schema_multiple(obj: t.Any) -> bool: + return is_schema(obj) and get_schema_metadata(obj).multiple diff --git a/flama/types/__init__.py b/flama/types/__init__.py index bec66f2d..6a49bdd1 100644 --- a/flama/types/__init__.py +++ b/flama/types/__init__.py @@ -2,5 +2,4 @@ from flama.types.asgi import * # noqa from flama.types.http import * # noqa from flama.types.json import * # noqa -from flama.types.schema import * # noqa from flama.types.websockets import * # noqa diff --git a/flama/types/schema.py b/flama/types/schema.py deleted file mode 100644 index 22f160c7..00000000 --- a/flama/types/schema.py +++ /dev/null @@ -1,48 +0,0 @@ -import inspect -import sys -import typing as t - -if sys.version_info < (3, 10): # PORT: Remove when stop supporting 3.9 # pragma: no cover - from typing_extensions import TypeGuard - - t.TypeGuard = TypeGuard # type: ignore - -__all__ = ["_T_Field", "_T_Schema", "Schema", "PartialSchema"] - -_T_Field = t.TypeVar("_T_Field") -_T_Schema = t.TypeVar("_T_Schema") - - -def _is_schema( - obj: t.Any, -) -> t.TypeGuard[type["Schema"]]: # type: ignore # PORT: Remove this comment when stop supporting 3.9 - return inspect.isclass(obj) and issubclass(obj, Schema) - - -class _SchemaMeta(type): - def __eq__(self, other) -> bool: - return _is_schema(other) and self.schema == other.schema # type: ignore[attr-defined] - - def __hash__(self) -> int: - return id(self) - - -class Schema(dict, t.Generic[_T_Schema], metaclass=_SchemaMeta): # type: ignore[misc] - schema: t.ClassVar[t.Any] = None - partial: bool = False - - def __class_getitem__(cls, schema_cls: _T_Schema): # type: ignore[override] - return _SchemaMeta("_SchemaAlias", (Schema,), {"schema": schema_cls}) # type: ignore[return-value] - - @staticmethod - def is_schema( - obj: t.Any, - ) -> t.TypeGuard[type["Schema"]]: # type: ignore # PORT: Remove this comment when stop supporting 3.9 - return _is_schema(obj) - - -class PartialSchema(Schema, t.Generic[_T_Schema]): - partial = True - - def __class_getitem__(cls, schema_cls: _T_Schema): # type: ignore[override] - return _SchemaMeta("_SchemaAlias", (PartialSchema,), {"schema": schema_cls}) # type: ignore[return-value] diff --git a/flama/validation.py b/flama/validation.py index dfaa772f..17edaefe 100644 --- a/flama/validation.py +++ b/flama/validation.py @@ -1,6 +1,6 @@ import typing as t -from flama import codecs, exceptions, http, types +from flama import codecs, exceptions, http, schemas, types from flama.injection import Component, Components from flama.injection.resolver import Parameter from flama.negotiation import ContentTypeNegotiator, WebSocketEncodingNegotiator @@ -102,7 +102,7 @@ def can_handle_parameter(self, parameter: Parameter): schema = ( t.get_args(parameter.annotation)[0] if t.get_origin(parameter.annotation) == list else parameter.annotation ) - return types.Schema.is_schema(schema) + return schemas.is_schema(schema) def resolve(self, parameter: Parameter, request: http.Request, route: BaseRoute, data: types.RequestData): body_param = route.parameters.body[request.method] @@ -112,9 +112,7 @@ def resolve(self, parameter: Parameter, request: http.Request, route: BaseRoute, ), f"Body schema parameter not defined for route '{route}' and method '{request.method}'" try: - return body_param.schema.validate( - data, partial=types.Schema.is_schema(parameter.annotation) and parameter.annotation.partial - ) + return body_param.schema.validate(data, partial=schemas.is_schema_partial(parameter.annotation)) except SchemaValidationError as exc: # noqa: safety net, just should not happen raise exceptions.ValidationError(detail=exc.errors) diff --git a/tests/schemas/test_data_structures.py b/tests/schemas/test_data_structures.py index ec5090ec..a6ccd88e 100644 --- a/tests/schemas/test_data_structures.py +++ b/tests/schemas/test_data_structures.py @@ -1,6 +1,5 @@ import datetime import functools -import inspect import typing as t import uuid from copy import deepcopy @@ -10,11 +9,13 @@ import pytest import typesystem -from flama import types +from flama import schemas, types from flama.injection import Parameter as InjectionParameter from flama.schemas.data_structures import Field, Parameter, ParameterLocation, Schema from tests.schemas.test_generator import assert_recursive_contains +Unknown = t.NewType("Unknown", None) + class TestCaseField: @pytest.mark.parametrize( @@ -113,23 +114,23 @@ def schema_type( # noqa: C901 elif request.param == "bare_schema": return foo_schema.schema elif request.param == "schema": - return types.Schema[foo_schema.schema] + return t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(foo_schema.schema)] elif request.param == "list_of_schema": - return list[types.Schema[foo_schema.schema]] if inspect.isclass(foo_schema.schema) else foo_schema.schema + return t.Annotated[list[schemas.SchemaType], schemas.SchemaMetadata(foo_schema.schema)] elif request.param == "schema_partial": if app.schema.schema_library.lib in (typesystem,): pytest.skip("Library does not support optional partial schemas") - return types.PartialSchema[foo_schema.schema] + return t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(foo_schema.schema, partial=True)] elif request.param == "schema_nested": - return types.Schema[bar_schema.schema] + return t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(bar_schema.schema)] elif request.param == "schema_nested_optional": if app.schema.schema_library.lib in (typesystem, marshmallow): pytest.skip("Library does not support optional nested schemas") - return types.Schema[bar_optional_schema.schema] + return t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(bar_optional_schema.schema)] elif request.param == "schema_nested_list": - return types.Schema[bar_list_schema.schema] + return t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(bar_list_schema.schema)] elif request.param == "schema_nested_dict": - return types.Schema[bar_dict_schema.schema] + return t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(bar_dict_schema.schema)] else: raise ValueError("Wrong schema type") @@ -351,25 +352,29 @@ class TestCaseParameter: ), pytest.param( "body", - InjectionParameter("foo", types.Schema), - Parameter(name="foo", location=ParameterLocation.body, type=types.Schema), + InjectionParameter("foo", Unknown), + Parameter(name="foo", location=ParameterLocation.body, type=Unknown), id="body", ), pytest.param( "response", - InjectionParameter("foo", types.Schema), - Parameter(name="foo", location=ParameterLocation.response, type=types.Schema), + InjectionParameter("foo", Unknown), + Parameter(name="foo", location=ParameterLocation.response, type=Unknown), id="response", ), ), ) def test_build(self, foo_schema, type_, parameter, result): - if parameter.annotation == types.Schema: - parameter = InjectionParameter(parameter.name, types.Schema[foo_schema.schema], parameter.default) + if parameter.annotation == Unknown: + parameter = InjectionParameter( + parameter.name, + t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(foo_schema.schema)], + parameter.default, + ) result = Parameter( name=result.name, location=result.location, - type=types.Schema[foo_schema.schema], + type=t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(foo_schema.schema)], required=result.required, default=result.default, ) diff --git a/tests/schemas/test_generator.py b/tests/schemas/test_generator.py index b56e143b..6b9927e1 100644 --- a/tests/schemas/test_generator.py +++ b/tests/schemas/test_generator.py @@ -1,4 +1,5 @@ import contextlib +import typing as t from collections import namedtuple import marshmallow @@ -7,7 +8,7 @@ import typesystem import typesystem.fields -from flama import types +from flama import schemas from flama.endpoints import HTTPEndpoint from flama.routing import Router from flama.schemas import openapi @@ -899,7 +900,7 @@ def schemas(self, owner_schema, puppy_schema, body_param_schema): def add_endpoints(self, app, puppy_schema, body_param_schema): # noqa: C901 @app.route("/endpoint/", methods=["GET"]) class PuppyEndpoint(HTTPEndpoint): - async def get(self) -> types.Schema[puppy_schema.schema]: + async def get(self) -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(puppy_schema.schema)]: """ description: Endpoint. responses: @@ -909,7 +910,7 @@ async def get(self) -> types.Schema[puppy_schema.schema]: return {"name": "Canna"} @app.route("/custom-component/", methods=["GET"]) - async def get() -> types.Schema[puppy_schema.schema]: + async def get() -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(puppy_schema.schema)]: """ description: Custom component. responses: @@ -919,7 +920,9 @@ async def get() -> types.Schema[puppy_schema.schema]: return {"name": "Canna"} @app.route("/many-components/", methods=["GET"]) - async def many_components() -> types.Schema[puppy_schema.schema]: + async def many_components() -> t.Annotated[ + list[schemas.SchemaType], schemas.SchemaMetadata(puppy_schema.schema) + ]: """ description: Many custom components. responses: @@ -929,7 +932,7 @@ async def many_components() -> types.Schema[puppy_schema.schema]: return [{"name": "foo"}, {"name": "bar"}] @app.route("/query-param/", methods=["GET"]) - async def query_param(param1: int, param2: str = None, param3: bool = True): + async def query_param(param1: int, param2: t.Optional[str] = None, param3: bool = True): """ description: Query param. responses: @@ -949,7 +952,7 @@ async def path_param(param: int): return {"name": param} @app.route("/body-param/", methods=["POST"]) - async def body_param(param: types.Schema[body_param_schema.schema]): + async def body_param(param: t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(body_param_schema.schema)]): """ description: Body param. responses: diff --git a/tests/schemas/test_routing.py b/tests/schemas/test_routing.py index d963e58a..9fed680d 100644 --- a/tests/schemas/test_routing.py +++ b/tests/schemas/test_routing.py @@ -7,7 +7,7 @@ import typesystem.fields import flama.types.websockets -from flama import Component, HTTPEndpoint, Route, WebSocketEndpoint, WebSocketRoute, types, websockets +from flama import Component, HTTPEndpoint, Route, WebSocketEndpoint, WebSocketRoute, schemas, websockets from flama.schemas.data_structures import Parameter, ParameterLocation @@ -50,8 +50,12 @@ def route(self, request, foo_schema): if request.param == "http_function": def foo( - w: int, a: Custom, z: types.Schema[foo_schema], x: int = 1, y: t.Optional[str] = None - ) -> types.Schema[foo_schema]: + w: int, + a: Custom, + z: t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(foo_schema)], + x: int = 1, + y: t.Optional[str] = None, + ) -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(foo_schema)]: ... return Route("/foo/{w:int}/", endpoint=foo, methods=["GET"]) @@ -60,8 +64,13 @@ def foo( class BarEndpoint(HTTPEndpoint): def get( - self, w: int, a: Custom, z: types.Schema[foo_schema], x: int = 1, y: t.Optional[str] = None - ) -> types.Schema[foo_schema]: + self, + w: int, + a: Custom, + z: t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(foo_schema)], + x: int = 1, + y: t.Optional[str] = None, + ) -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(foo_schema)]: ... return Route("/bar/{w:int}/", endpoint=BarEndpoint, methods=["GET"]) @@ -73,7 +82,7 @@ def foo( data: flama.types.websockets.Data, w: int, a: Custom, - z: types.Schema[foo_schema], + z: t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(foo_schema)], x: int = 1, y: t.Optional[str] = None, ) -> None: @@ -90,7 +99,7 @@ def on_receive( data: flama.types.websockets.Data, w: int, a: Custom, - z: types.Schema[foo_schema], + z: t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(foo_schema)], x: int = 1, y: t.Optional[str] = None, ) -> None: @@ -336,7 +345,9 @@ def test_path(self, route, expected_params): ) def test_body(self, route, expected_params, foo_schema): expected_params = { - k: Parameter(**{**param, "type": types.Schema[foo_schema]}) if param else None + k: Parameter(**{**param, "type": t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(foo_schema)]}) + if param + else None for k, param in expected_params.items() } assert route.parameters.body == expected_params @@ -381,7 +392,14 @@ def test_body(self, route, expected_params, foo_schema): ) def test_response(self, route, expected_params, foo_schema): expected_params = { - k: Parameter(**{**param, "type": types.Schema[foo_schema] if param["type"] else None}) + k: Parameter( + **{ + **param, + "type": t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(foo_schema)] + if param["type"] + else None, + } + ) for k, param in expected_params.items() } assert route.parameters.response == expected_params diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 3b788210..616a1103 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -1,3 +1,4 @@ +import typing as t import warnings from unittest.mock import AsyncMock, MagicMock, PropertyMock, call, patch @@ -8,7 +9,7 @@ import typesystem import typesystem.fields -from flama import Component, Flama, exceptions, types, websockets +from flama import Component, Flama, exceptions, schemas, types, websockets from flama.endpoints import HTTPEndpoint, WebSocketEndpoint @@ -56,11 +57,13 @@ def puppy_schema(self, app): def puppy_endpoint(self, app, puppy_schema): @app.route("/puppy/") class PuppyEndpoint(HTTPEndpoint): - def get(self, puppy: Puppy) -> types.Schema[puppy_schema]: - return types.Schema({"name": puppy.name}) + def get(self, puppy: Puppy) -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(puppy_schema)]: + return {"name": puppy.name} - async def post(self, puppy: types.Schema[puppy_schema]) -> types.Schema[puppy_schema]: - return types.Schema(puppy) + async def post( + self, puppy: t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(puppy_schema)] + ) -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(puppy_schema)]: + return puppy return PuppyEndpoint diff --git a/tests/test_http.py b/tests/test_http.py index c9178288..b13d5ea4 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -13,7 +13,7 @@ import typesystem import typesystem.fields -from flama import exceptions, http, types +from flama import exceptions, http @dataclasses.dataclass @@ -87,7 +87,7 @@ def schema(self): ), pytest.param({"foo": Exception}, {"foo": "Exception"}, None, id="exception_class"), pytest.param({"foo": Exception("bar")}, {"foo": "Exception('bar')"}, None, id="exception_obj"), - pytest.param(types.Schema({"foo": "bar"}), {"foo": "bar"}, None, id="schema"), + pytest.param({"foo": "bar"}, {"foo": "bar"}, None, id="schema"), pytest.param({"foo": Foo(bar=1)}, {"foo": {"bar": 1}}, None, id="dataclass"), pytest.param({"foo": Mock()}, None, TypeError, id="error"), ), @@ -132,7 +132,7 @@ def test_init(self, schema): ) def test_render(self, schema, use_schema, content, expected, exception): with exception: - response = http.APIResponse(schema=types.Schema[schema] if use_schema else None, content=content) + response = http.APIResponse(schema=schema if use_schema else None, content=content) assert response.body.decode() == expected diff --git a/tests/test_pagination.py b/tests/test_pagination.py index e35bdd2c..0f95a51b 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -8,7 +8,7 @@ import typesystem.fields from pytest import param -from flama import types +from flama import schemas from flama.pagination import paginator from tests.asserts import assert_recursive_contains @@ -40,7 +40,9 @@ class TestCasePageNumberPagination: @pytest.fixture(scope="function", autouse=True) def add_endpoints(self, app, output_schema): @app.route("/page-number/", methods=["GET"], pagination="page_number") - def page_number(**kwargs) -> types.Schema[output_schema.schema]: + def page_number( + **kwargs, + ) -> t.Annotated[list[schemas.SchemaType], schemas.SchemaMetadata(output_schema.schema)]: return [{"value": i} for i in range(25)] def test_registered_schemas(self, app, output_schema): @@ -58,7 +60,7 @@ def test_invalid_view(self, output_schema): with pytest.raises(TypeError, match=r"Paginated views must define \*\*kwargs param"): @paginator._paginate_page_number - def invalid() -> types.Schema[output_schema.schema]: + def invalid() -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(output_schema.schema)]: ... def test_invalid_response(self): @@ -123,7 +125,9 @@ def test_pagination_schema_return(self, app, output_schema): async def test_async_function(self, app, client, output_schema): @app.route("/page-number-async/", methods=["GET"], pagination="page_number") - async def page_number_async(**kwargs) -> types.Schema[output_schema.schema]: + async def page_number_async( + **kwargs, + ) -> t.Annotated[list[schemas.SchemaType], schemas.SchemaMetadata(output_schema.schema)]: return [{"value": i} for i in range(25)] response = await client.get("/page-number-async/") @@ -178,7 +182,9 @@ class TestCaseLimitOffsetPagination: @pytest.fixture(scope="function", autouse=True) def add_endpoints(self, app, output_schema): @app.route("/limit-offset/", methods=["GET"], pagination="limit_offset") - def limit_offset(**kwargs) -> types.Schema[output_schema.schema]: + def limit_offset( + **kwargs, + ) -> t.Annotated[list[schemas.SchemaType], schemas.SchemaMetadata(output_schema.schema)]: return [{"value": i} for i in range(25)] def test_registered_schemas(self, app, output_schema): @@ -196,7 +202,7 @@ def test_invalid_view(self, output_schema): with pytest.raises(TypeError, match=r"Paginated views must define \*\*kwargs param"): @paginator._paginate_limit_offset - def invalid() -> types.Schema[output_schema.schema]: + def invalid() -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(output_schema.schema)]: ... def test_invalid_response(self): @@ -262,7 +268,9 @@ def test_pagination_schema_return(self, app, output_schema): async def test_async_function(self, app, client, output_schema): @app.route("/limit-offset-async/", methods=["GET"], pagination="limit_offset") - async def limit_offset_async(**kwargs) -> types.Schema[output_schema.schema]: + async def limit_offset_async( + **kwargs, + ) -> t.Annotated[list[schemas.SchemaType], schemas.SchemaMetadata(output_schema.schema)]: return [{"value": i} for i in range(25)] response = await client.get("/limit-offset-async/") diff --git a/tests/validation/test_return.py b/tests/validation/test_return.py index 220e31db..22dd2cc8 100644 --- a/tests/validation/test_return.py +++ b/tests/validation/test_return.py @@ -1,10 +1,12 @@ +import typing as t + import marshmallow import pydantic import pytest import typesystem import typesystem.fields -from flama import endpoints, http, types +from flama import endpoints, http, schemas, types class TestCaseReturnValidation: @@ -59,17 +61,17 @@ class Dummy: @app.route("/return-schema/", methods=["GET"]) class ReturnSchemaHTTPEndpoint(endpoints.HTTPEndpoint): - async def get(self) -> types.Schema[output_schema]: + async def get(self) -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(output_schema)]: return {"name": "Canna"} @app.route("/return-schema-many/", methods=["GET"]) class ReturnSchemaManyHTTPEndpoint(endpoints.HTTPEndpoint): - async def get(self) -> types.Schema[output_schema]: + async def get(self) -> t.Annotated[list[schemas.SchemaType], schemas.SchemaMetadata(output_schema)]: return [{"name": "Canna"}, {"name": "Sandy"}] @app.route("/return-schema-empty/", methods=["GET"]) class ReturnSchemaEmptyHTTPEndpoint(endpoints.HTTPEndpoint): - async def get(self) -> types.Schema[output_schema]: + async def get(self) -> None: return None else: @@ -98,15 +100,17 @@ class Dummy: return {"dummy": Dummy()} @app.route("/return-schema/", methods=["GET"]) - async def return_schema() -> types.Schema[output_schema]: + async def return_schema() -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(output_schema)]: return {"name": "Canna"} @app.route("/return-schema-many/", methods=["GET"]) - async def return_schema_many() -> types.Schema[output_schema]: + async def return_schema_many() -> t.Annotated[ + list[schemas.SchemaType], schemas.SchemaMetadata(output_schema) + ]: return [{"name": "Canna"}, {"name": "Sandy"}] @app.route("/return-schema-empty/", methods=["GET"]) - async def return_schema_empty() -> types.Schema[output_schema]: + async def return_schema_empty() -> None: return None @pytest.mark.parametrize( diff --git a/tests/validation/test_schemas.py b/tests/validation/test_schemas.py index 30d91a28..8f024f86 100644 --- a/tests/validation/test_schemas.py +++ b/tests/validation/test_schemas.py @@ -8,7 +8,7 @@ import typesystem import typesystem.fields -from flama import types +from flama import schemas from tests.asserts import assert_recursive_contains utc = datetime.timezone.utc @@ -125,29 +125,37 @@ def place_schema(self, app, location_schema): @pytest.fixture(scope="function", autouse=True) def add_endpoints(self, app, product_schema, reviewed_product_schema, place_schema): @app.route("/product", methods=["POST"]) - def product_identity(product: types.Schema[product_schema]) -> types.Schema[product_schema]: + def product_identity( + product: t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(product_schema)] + ) -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(product_schema)]: return product @app.route("/reviewed-product", methods=["POST"]) def reviewed_product_identity( - reviewed_product: types.Schema[reviewed_product_schema], - ) -> types.Schema[reviewed_product_schema]: + reviewed_product: t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(reviewed_product_schema)], + ) -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(reviewed_product_schema)]: return reviewed_product @app.route("/place", methods=["POST"]) - def place_identity(place: types.Schema[place_schema]) -> types.Schema[place_schema]: + def place_identity( + place: t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(place_schema)] + ) -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(place_schema)]: return place @app.route("/many-products", methods=["GET"]) - def many_products(products: list[types.Schema[product_schema]]) -> types.Schema[product_schema]: + def many_products( + products: t.Annotated[list[schemas.SchemaType], schemas.SchemaMetadata(product_schema)] + ) -> t.Annotated[list[schemas.SchemaType], schemas.SchemaMetadata(product_schema)]: return products @app.route("/partial-product", methods=["GET"]) - def partial_product(product: types.PartialSchema[product_schema]) -> types.PartialSchema[product_schema]: + def partial_product( + product: t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(product_schema, partial=True)] + ) -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(product_schema, partial=True)]: return product @app.route("/serialization-error") - def serialization_error() -> types.Schema[product_schema]: + def serialization_error() -> t.Annotated[schemas.SchemaType, schemas.SchemaMetadata(product_schema)]: return {"rating": "foo", "created": "bar"} @pytest.mark.parametrize(