Skip to content

Commit

Permalink
🐛 Include nested applications in Lifespan (#130)
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy authored and migduroli committed Sep 3, 2024
1 parent 518a4eb commit 3e1c0a8
Show file tree
Hide file tree
Showing 10 changed files with 620 additions and 197 deletions.
25 changes: 23 additions & 2 deletions flama/applications.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import functools
import logging
import threading
import typing as t

from flama import asgi, http, injection, types, url, validation, websockets
from flama import asgi, exceptions, http, injection, types, url, validation, websockets
from flama.ddd.components import WorkerComponent
from flama.events import Events
from flama.middleware import MiddlewareStack
Expand All @@ -20,6 +22,8 @@

__all__ = ["Flama"]

logger = logging.getLogger(__name__)


class Flama:
def __init__(
Expand Down Expand Up @@ -57,7 +61,7 @@ def __init__(
:param schema_library: Schema library to use.
"""
self._debug = debug
self._status = types.AppStatus.NOT_INITIALIZED
self._status = types.AppStatus.NOT_STARTED
self._shutdown = False

# Create Dependency Injector
Expand Down Expand Up @@ -131,9 +135,26 @@ async def __call__(self, scope: types.Scope, receive: types.Receive, send: types
:param receive: ASGI receive event.
:param send: ASGI send event.
"""
if scope["type"] != "lifespan" and self.status in (types.AppStatus.NOT_STARTED, types.AppStatus.STARTING):
raise exceptions.ApplicationError("Application is not ready to process requests yet.")

if scope["type"] != "lifespan" and self.status in (types.AppStatus.SHUT_DOWN, types.AppStatus.SHUTTING_DOWN):
raise exceptions.ApplicationError("Application is already shut down.")

scope["app"] = self
await self.middleware(scope, receive, send)

@property
def status(self) -> types.AppStatus:
return self._status

@status.setter
def status(self, s: types.AppStatus) -> None:
logger.debug("Transitioning %s from %s to %s", self, self._status, s)

with threading.Lock():
self._status = s

@property
def components(self) -> injection.Components:
"""Components register.
Expand Down
42 changes: 16 additions & 26 deletions flama/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import contextlib
import functools
import importlib.metadata
import logging
Expand Down Expand Up @@ -46,36 +45,23 @@ async def _send(self, message: types.Message) -> None:
self._shutdown_complete.set()

async def _app_task(self) -> None:
with contextlib.suppress(asyncio.CancelledError):
scope = types.Scope({"type": "lifespan"})
scope = types.Scope({"type": "lifespan"})

try:
await self.app(scope, self._receive, self._send)
except BaseException as exc:
self._exception = exc
self._startup_complete.set()
self._shutdown_complete.set()

raise

def _run_app(self) -> None:
self._task = asyncio.get_event_loop().create_task(self._app_task())

async def _stop_app(self) -> None:
assert self._task is not None

if not self._task.done():
self._task.cancel()
try:
await self.app(scope, self._receive, self._send)
except BaseException as exc:
self._exception = exc
self._startup_complete.set()
self._shutdown_complete.set()

await self._task
raise

async def __aenter__(self) -> "LifespanContextManager":
self._run_app()
asyncio.create_task(self._app_task())

try:
await self._startup()
except BaseException:
await self._stop_app()
raise

return self
Expand All @@ -86,8 +72,12 @@ async def __aexit__(
exc_value: t.Optional[BaseException] = None,
traceback: t.Optional[TracebackType] = None,
):
await self._shutdown()
await self._stop_app()
asyncio.create_task(self._app_task())

try:
await self._shutdown()
except BaseException:
raise


class _BaseClient:
Expand Down Expand Up @@ -193,7 +183,7 @@ async def __aexit__(
await self.lifespan.__aexit__(exc_type, exc_value, traceback)
await super().__aexit__(exc_type, exc_value, traceback)

async def model_request(self, model: str, method: str, url: str, **kwargs) -> t.Awaitable[httpx.Response]:
def model_request(self, model: str, method: str, url: str, **kwargs) -> t.Awaitable[httpx.Response]:
assert self.models, "No models found for request."
return self.request(method, f"{self.models[model].rstrip('/')}{url}", **kwargs)

Expand Down
5 changes: 4 additions & 1 deletion flama/debug/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ async def sender(message: types.Message) -> None:
try:
await concurrency.run(self.app, scope, receive, sender)
except Exception as exc:
await self.process_exception(scope, receive, send, exc, response_started)
if scope["type"] in ("http", "websocket"):
await self.process_exception(scope, receive, send, exc, response_started)
else:
raise

@abc.abstractmethod
async def process_exception(
Expand Down
60 changes: 49 additions & 11 deletions flama/lifespan.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
import typing as t

Expand All @@ -13,7 +14,14 @@

class Lifespan(types.AppClass):
def __init__(self, lifespan: t.Optional[t.Callable[[t.Optional["Flama"]], t.AsyncContextManager]] = None):
"""A class that handles the lifespan of an application.
It is responsible for calling the startup and shutdown events and the user defined lifespan.
:param lifespan: A user defined lifespan. It must be a callable that returns an async context manager.
"""
self.lifespan = lifespan
self.lock = asyncio.Lock()

async def __call__(self, scope: types.Scope, receive: types.Receive, send: types.Send) -> None:
"""Handles a lifespan request by initialising and finalising all modules and running a user defined lifespan.
Expand All @@ -22,43 +30,73 @@ async def __call__(self, scope: types.Scope, receive: types.Receive, send: types
:param receive: ASGI receive.
:param send: ASGI send.
"""
app = scope["app"]
while True:
async with self.lock:
app = scope["app"]
message = await receive()
logger.debug("Start lifespan for app '%s' from status '%s' with message '%s'", app, app.status, message)
if message["type"] == "lifespan.startup":
if app.status not in (types.AppStatus.NOT_STARTED, types.AppStatus.SHUT_DOWN):
msg = f"Trying to start application from '{app._status}' state"
await send(types.Message({"type": "lifespan.startup.failed", "message": msg}))
raise exceptions.ApplicationError(msg)

try:
logger.info("Application starting")
app._status = types.AppStatus.STARTING
app.status = types.AppStatus.STARTING
await self._startup(app)
await self._child_propagation(app, scope, message)
app.status = types.AppStatus.READY
await send(types.Message({"type": "lifespan.startup.complete"}))
app._status = types.AppStatus.READY
logger.info("Application ready")
except BaseException as e:
logger.exception("Application start failed")
app._status = types.AppStatus.FAILED
app.status = types.AppStatus.FAILED
await send(types.Message({"type": "lifespan.startup.failed", "message": str(e)}))
raise exceptions.ApplicationError("Lifespan startup failed") from e
elif message["type"] == "lifespan.shutdown":
if app.status != types.AppStatus.READY:
msg = f"Trying to shutdown application from '{app._status}' state"
await send(types.Message({"type": "lifespan.shutdown.failed", "message": msg}))
raise exceptions.ApplicationError(msg)

try:
logger.info("Application shutting down")
app._status = types.AppStatus.SHUTTING_DOWN
app.status = types.AppStatus.SHUTTING_DOWN
await self._child_propagation(app, scope, message)
await self._shutdown(app)
app.status = types.AppStatus.SHUT_DOWN
await send(types.Message({"type": "lifespan.shutdown.complete"}))
app._status = types.AppStatus.SHUT_DOWN
logger.info("Application shut down")
return
except BaseException as e:
await send(types.Message({"type": "lifespan.shutdown.failed", "message": str(e)}))
app._status = types.AppStatus.FAILED
app.status = types.AppStatus.FAILED
logger.exception("Application shutdown failed")
raise exceptions.ApplicationError("Lifespan shutdown failed") from e
else:
logger.warning("Unknown lifespan message received: %s", str(message))

logger.debug("End lifespan for app '%s' with status '%s'", app, app.status)

async def _startup(self, app: "Flama") -> None:
await concurrency.run_task_group(*(f() for f in app.events.startup))
if app.events.startup:
await concurrency.run_task_group(*(f() for f in app.events.startup))

if self.lifespan:
await self.lifespan(app).__aenter__()

async def _shutdown(self, app: "Flama") -> None:
if self.lifespan:
await self.lifespan(app).__aexit__(None, None, None)
await concurrency.run_task_group(*(f() for f in app.events.shutdown))

if app.events.shutdown:
await concurrency.run_task_group(*(f() for f in app.events.shutdown))

async def _child_propagation(self, app: "Flama", scope: types.Scope, message: types.Message) -> None:
async def child_receive() -> types.Message:
return message

async def child_send(message: types.Message) -> None:
...

if app.routes:
await concurrency.run_task_group(*(route(scope, child_receive, child_send) for route in app.routes))
78 changes: 46 additions & 32 deletions flama/routing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import abc
import enum
import functools
import inspect
Expand Down Expand Up @@ -190,7 +191,7 @@ def __eq__(self, other) -> bool:
return isinstance(other, EndpointWrapper) and self.handler == other.handler


class BaseRoute(RouteParametersMixin):
class BaseRoute(abc.ABC, RouteParametersMixin):
def __init__(
self,
path: t.Union[str, url.RegexPath],
Expand All @@ -216,8 +217,9 @@ def __init__(
self.tags = tags or {}
super().__init__()

@abc.abstractmethod
async def __call__(self, scope: types.Scope, receive: types.Receive, send: types.Send) -> None:
await self.handle(types.Scope({**scope, **self.route_scope(scope)}), receive, send)
...

def __eq__(self, other: t.Any) -> bool:
return (
Expand Down Expand Up @@ -344,6 +346,10 @@ def __init__(

self.app: EndpointWrapper

async def __call__(self, scope: types.Scope, receive: types.Receive, send: types.Send) -> None:
if scope["type"] == "http":
await self.handle(types.Scope({**scope, **self.route_scope(scope)}), receive, send)

def __eq__(self, other: t.Any) -> bool:
return super().__eq__(other) and isinstance(other, Route) and self.methods == other.methods

Expand Down Expand Up @@ -427,6 +433,10 @@ def __init__(

self.app: EndpointWrapper

async def __call__(self, scope: types.Scope, receive: types.Receive, send: types.Send) -> None:
if scope["type"] == "websocket":
await self.handle(types.Scope({**scope, **self.route_scope(scope)}), receive, send)

def __eq__(self, other: t.Any) -> bool:
return super().__eq__(other) and isinstance(other, WebSocketRoute)

Expand Down Expand Up @@ -489,6 +499,12 @@ def __init__(

super().__init__(url.RegexPath(path.rstrip("/") + "{path:path}"), app, name=name, tags=tags)

async def __call__(self, scope: types.Scope, receive: types.Receive, send: types.Send) -> None:
if scope["type"] in ("http", "websocket") or (
scope["type"] == "lifespan" and types.is_flama_instance(self.app)
):
await self.handle(types.Scope({**scope, **self.route_scope(scope)}), receive, send)

def __eq__(self, other: t.Any) -> bool:
return super().__eq__(other) and isinstance(other, Mount)

Expand All @@ -499,12 +515,10 @@ def build(self, app: t.Optional["Flama"] = None) -> None:
:param app: Flama app.
"""
from flama import Flama

if app and isinstance(self.app, Flama):
if app and types.is_flama_instance(self.app):
self.app.router.components = Components(self.app.router.components + app.components)

if root := (self.app if isinstance(self.app, Flama) else app):
if root := (self.app if types.is_flama_instance(self.app) else app):
for route in self.routes:
route.build(root)

Expand All @@ -531,30 +545,33 @@ async def handle(self, scope: types.Scope, receive: types.Receive, send: types.S
def route_scope(self, scope: types.Scope) -> types.Scope:
"""Build route scope from given scope.
It generates an updated scope parameters for the route:
* app: The app of this mount point. If it's mounting a Flama app, it will replace the app with this one
* path_params: The matched path parameters of this mount point
* endpoint: The endpoint of this mount point
* root_path: The root path of this mount point (if it's mounting a Flama app, it will be empty)
* path: The remaining path to be matched
:param scope: ASGI scope.
:return: Route scope.
"""
from flama import Flama
result = {"app": self.app if types.is_flama_instance(self.app) else scope["app"]}

if "path" in scope:
path = scope["path"]
matched_params = self.path.values(path)
remaining_path = matched_params.pop("path")
matched_path = path[: -len(remaining_path)]
result.update(
{
"path_params": {**dict(scope.get("path_params", {})), **matched_params},
"endpoint": self.endpoint,
"root_path": "" if types.is_flama_instance(self.app) else scope.get("root_path", "") + matched_path,
"path": remaining_path,
}
)

path = scope["path"]
matched_params = self.path.values(path)
remaining_path = matched_params.pop("path")
matched_path = path[: -len(remaining_path)]
if isinstance(self.app, Flama):
app = self.app
root_path = ""
else:
app = scope["app"]
root_path = scope.get("root_path", "") + matched_path
return types.Scope(
{
"app": app,
"path_params": {**dict(scope.get("path_params", {})), **matched_params},
"endpoint": self.endpoint,
"root_path": root_path,
"path": remaining_path,
}
)
return types.Scope(result)

def resolve_url(self, name: str, **params: t.Any) -> url.URL:
"""Builds URL path for given name and params.
Expand Down Expand Up @@ -620,16 +637,13 @@ async def __call__(self, scope: types.Scope, receive: types.Receive, send: types
logger.debug("Request: %s", str(scope))
assert scope["type"] in ("http", "websocket", "lifespan")

if "app" in scope and scope["app"]._status != types.AppStatus.READY and scope["type"] != "lifespan":
raise exceptions.ApplicationError("Application is not ready to process requests yet.")

if "router" not in scope:
scope["router"] = self

if scope["type"] == "lifespan":
await self.lifespan(scope, receive, send)
return

if "router" not in scope:
scope["router"] = self

route, route_scope = self.resolve_route(scope)
await route(route_scope, receive, send)

Expand Down
Loading

0 comments on commit 3e1c0a8

Please sign in to comment.