Skip to content

Commit

Permalink
fix: LifespanHandler memory stream cleanup (#3836)
Browse files Browse the repository at this point in the history
  • Loading branch information
provinzkraut authored Oct 21, 2024
1 parent b2adb0d commit 5255ec3
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 4 deletions.
1 change: 1 addition & 0 deletions litestar/testing/client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ async def __aenter__(self) -> Self:
async with AsyncExitStack() as stack:
self.blocking_portal = portal = stack.enter_context(self.portal())
self.lifespan_handler = LifeSpanHandler(client=self)
stack.enter_context(self.lifespan_handler)

@stack.callback
def reset_portal() -> None:
Expand Down
1 change: 1 addition & 0 deletions litestar/testing/client/sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __enter__(self) -> Self:
with ExitStack() as stack:
self.blocking_portal = portal = stack.enter_context(self.portal())
self.lifespan_handler = LifeSpanHandler(client=self)
stack.enter_context(self.lifespan_handler)

@stack.callback
def reset_portal() -> None:
Expand Down
54 changes: 53 additions & 1 deletion litestar/testing/life_span_handler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from math import inf
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, cast

Expand All @@ -9,6 +10,8 @@
from litestar.testing.client.base import BaseTestClient

if TYPE_CHECKING:
from types import TracebackType

from litestar.types import (
LifeSpanReceiveMessage, # noqa: F401
LifeSpanSendMessage,
Expand All @@ -20,24 +23,69 @@


class LifeSpanHandler(Generic[T]):
__slots__ = "stream_send", "stream_receive", "client", "task"
__slots__ = (
"stream_send",
"stream_receive",
"client",
"task",
"_startup_done",
)

def __init__(self, client: T) -> None:
self.client = client
self.stream_send = StapledObjectStream[Optional["LifeSpanSendMessage"]](*create_memory_object_stream(inf)) # type: ignore[arg-type]
self.stream_receive = StapledObjectStream["LifeSpanReceiveMessage"](*create_memory_object_stream(inf)) # type: ignore[arg-type]
self._startup_done = False

def _ensure_setup(self, is_safe: bool = False) -> None:
if self._startup_done:
return

if not is_safe:
warnings.warn(
"LifeSpanHandler used with implicit startup; Use LifeSpanHandler as a context manager instead. "
"Implicit startup will be deprecated in version 3.0.",
category=DeprecationWarning,
stacklevel=2,
)

self._startup_done = True
with self.client.portal() as portal:
self.task = portal.start_task_soon(self.lifespan)
portal.call(self.wait_startup)

def close(self) -> None:
with self.client.portal() as portal:
portal.call(self.stream_send.aclose)
portal.call(self.stream_receive.aclose)

def __enter__(self) -> LifeSpanHandler:
try:
self._ensure_setup(is_safe=True)
except Exception as exc:
self.close()
raise exc
return self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
self.close()

async def receive(self) -> LifeSpanSendMessage:
self._ensure_setup()

message = await self.stream_send.receive()
if message is None:
self.task.result()
return cast("LifeSpanSendMessage", message)

async def wait_startup(self) -> None:
self._ensure_setup()

event: LifeSpanStartupEvent = {"type": "lifespan.startup"}
await self.stream_receive.send(event)

Expand All @@ -54,6 +102,8 @@ async def wait_startup(self) -> None:
await self.receive()

async def wait_shutdown(self) -> None:
self._ensure_setup()

async with self.stream_send:
lifespan_shutdown_event: LifeSpanShutdownEvent = {"type": "lifespan.shutdown"}
await self.stream_receive.send(lifespan_shutdown_event)
Expand All @@ -71,6 +121,8 @@ async def wait_shutdown(self) -> None:
await self.receive()

async def lifespan(self) -> None:
self._ensure_setup()

scope = {"type": "lifespan"}
try:
await self.client.app(scope, self.stream_receive.receive, self.stream_send.send)
Expand Down
19 changes: 16 additions & 3 deletions tests/unit/test_testing/test_lifespan_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,34 @@
from litestar.testing.life_span_handler import LifeSpanHandler
from litestar.types import Receive, Scope, Send

pytestmark = pytest.mark.filterwarnings("default")


async def test_wait_startup_invalid_event() -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
await send({"type": "lifespan.startup.something_unexpected"}) # type: ignore[typeddict-item]

with pytest.raises(RuntimeError, match="Received unexpected ASGI message type"):
LifeSpanHandler(TestClient(app))
with LifeSpanHandler(TestClient(app)):
pass


async def test_wait_shutdown_invalid_event() -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
await send({"type": "lifespan.startup.complete"}) # type: ignore[typeddict-item]
await send({"type": "lifespan.shutdown.something_unexpected"}) # type: ignore[typeddict-item]

handler = LifeSpanHandler(TestClient(app))
with LifeSpanHandler(TestClient(app)) as handler:
with pytest.raises(RuntimeError, match="Received unexpected ASGI message type"):
await handler.wait_shutdown()

with pytest.raises(RuntimeError, match="Received unexpected ASGI message type"):

async def test_implicit_startup() -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
await send({"type": "lifespan.startup.complete"}) # type: ignore[typeddict-item]
await send({"type": "lifespan.shutdown.complete"}) # type: ignore[typeddict-item]

with pytest.warns(DeprecationWarning):
handler = LifeSpanHandler(TestClient(app))
await handler.wait_shutdown()
handler.close()

0 comments on commit 5255ec3

Please sign in to comment.