diff --git a/src/dishka/integrations/litestar.py b/src/dishka/integrations/litestar.py index 025386b7..abe93880 100644 --- a/src/dishka/integrations/litestar.py +++ b/src/dishka/integrations/litestar.py @@ -2,8 +2,11 @@ from typing import Optional, Sequence, get_type_hints from litestar import Litestar, Request +from litestar.types import ASGIApp, Receive, Scope, Send from dishka import Provider, make_async_container +from dishka.async_container import AsyncContainer, AsyncContextWrapper +from dishka.integrations.asgi import BaseDishkaApp from dishka.integrations.base import wrap_injection @@ -32,22 +35,16 @@ def inject(func): ) -async def make_dishka_container(request: Request): - request_container = await request.app.state.dishka_container().__aenter__() - request.state.dishka_container = request_container +def make_add_request_container_middleware(app: ASGIApp): + async def middleware(scope: Scope, receive: Receive, send: Send) -> None: + request = Request(scope) + async with request.app.state.dishka_container( + {Request: request}, + ) as request_container: + request.state.dishka_container = request_container + await app(scope, receive, send) - -async def close_dishka_container(request: Request): - await request.app.state.dishka_container().__aexit__(None, None, None) - - -async def startup_dishka(app: Litestar): - container = await app.state.dishka_container_wrapper.__aenter__() - app.state.dishka_container = container - - -async def shutdown_dishka(app: Litestar): - await app.state.dishka_container_wrapper.__aexit__(None, None, None) + return middleware def setup_dishka(app: Litestar, providers: Sequence[Provider]) -> Litestar: @@ -55,14 +52,13 @@ def setup_dishka(app: Litestar, providers: Sequence[Provider]) -> Litestar: return app -class DishkaApp: - def __init__(self, providers: Sequence[Provider], app: Litestar): - self.app = app - self.app.state.dishka_container_wrapper = make_async_container(*providers) - self.app.on_startup.append(startup_dishka) - self.app.on_shutdown.append(shutdown_dishka) - self.app.before_request = make_dishka_container - self.app.after_response = close_dishka_container +class DishkaApp(BaseDishkaApp): + def _init_request_middleware( + self, app, container_wrapper: AsyncContextWrapper, + ): + app.asgi_handler = make_add_request_container_middleware( + app.asgi_handler, + ) - async def __call__(self, scope, receive, send): - await self.app(scope, receive, send) + def _app_startup(self, app, container: AsyncContainer): + app.state.dishka_container = container