Skip to content

Commit

Permalink
Avoid sending cleanup frames if ConnectionLostError was raised, also …
Browse files Browse the repository at this point in the history
…avoid raising it in ack/nack (#30)
  • Loading branch information
vrslev authored Jun 21, 2024
1 parent 4120df9 commit 4a701eb
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 54 deletions.
45 changes: 29 additions & 16 deletions stompman/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
from contextlib import AsyncExitStack, asynccontextmanager, suppress
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass, field
from types import TracebackType
from typing import Any, ClassVar, NamedTuple, Self, TypedDict
Expand Down Expand Up @@ -145,8 +145,7 @@ async def _connect_to_any_server(self) -> None:
for maybe_connection_future in asyncio.as_completed(
[self._connect_to_one_server(server) for server in self.servers]
):
maybe_result = await maybe_connection_future
if maybe_result:
if maybe_result := await maybe_connection_future:
self._connection, self._connection_parameters = maybe_result
return
raise FailedAllConnectAttemptsError(
Expand Down Expand Up @@ -199,10 +198,12 @@ async def _connection_lifespan(self) -> AsyncGenerator[None, None]:
)

async def send_heartbeats_forever() -> None:
while True:
while self._connection.active:
try:
self._connection.write_heartbeat()
except ConnectionLostError:
# Avoid raising the error in an exception group.
# ConnectionLostError should be raised in a way that user expects it.
return
await asyncio.sleep(heartbeat_interval)

Expand All @@ -213,8 +214,9 @@ async def send_heartbeats_forever() -> None:
finally:
task.cancel()

with suppress(ConnectionLostError):
if self._connection.active:
await self._connection.write_frame(DisconnectFrame(headers={"receipt": str(uuid4())}))
if self._connection.active:
await self._connection.read_frame_of_type(
ReceiptFrame, max_chunk_size=self.read_max_chunk_size, timeout=self.read_timeout
)
Expand All @@ -227,10 +229,12 @@ async def enter_transaction(self) -> AsyncGenerator[str, None]:
try:
yield transaction_id
except Exception:
await self._connection.write_frame(AbortFrame(headers={"transaction": transaction_id}))
if self._connection.active:
await self._connection.write_frame(AbortFrame(headers={"transaction": transaction_id}))
raise
else:
await self._connection.write_frame(CommitFrame(headers={"transaction": transaction_id}))
if self._connection.active:
await self._connection.write_frame(CommitFrame(headers={"transaction": transaction_id}))

async def send( # noqa: PLR0913
self,
Expand Down Expand Up @@ -258,7 +262,8 @@ async def subscribe(self, destination: str) -> AsyncGenerator[None, None]:
try:
yield
finally:
await self._connection.write_frame(UnsubscribeFrame(headers={"id": subscription_id}))
if self._connection.active:
await self._connection.write_frame(UnsubscribeFrame(headers={"id": subscription_id}))

async def listen(self) -> AsyncIterator["AnyListeningEvent"]:
async for frame in self._connection.read_frames(
Expand All @@ -285,18 +290,26 @@ def __post_init__(self) -> None:
self.body = self._frame.body

async def ack(self) -> None:
await self._client._connection.write_frame(
AckFrame(
headers={"id": self._frame.headers["message-id"], "subscription": self._frame.headers["subscription"]},
if self._client._connection.active:
await self._client._connection.write_frame(
AckFrame(
headers={
"id": self._frame.headers["message-id"],
"subscription": self._frame.headers["subscription"],
},
)
)
)

async def nack(self) -> None:
await self._client._connection.write_frame(
NackFrame(
headers={"id": self._frame.headers["message-id"], "subscription": self._frame.headers["subscription"]}
if self._client._connection.active:
await self._client._connection.write_frame(
NackFrame(
headers={
"id": self._frame.headers["message-id"],
"subscription": self._frame.headers["subscription"],
}
)
)
)

async def with_auto_ack(
self,
Expand Down
29 changes: 16 additions & 13 deletions stompman/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

@dataclass
class AbstractConnection(Protocol):
active: bool = True

@classmethod
async def connect(cls, host: str, port: int, timeout: int) -> Self | None: ...
async def close(self) -> None: ...
Expand All @@ -28,15 +30,7 @@ async def read_frame_of_type(self, type_: type[FrameType], max_chunk_size: int,
return frame


@contextmanager
def _reraise_connection_lost(*causes: type[Exception]) -> Generator[None, None, None]:
try:
yield
except causes as exception:
raise ConnectionLostError from exception


@dataclass
@dataclass(kw_only=True)
class Connection(AbstractConnection):
reader: asyncio.StreamReader
writer: asyncio.StreamWriter
Expand All @@ -54,15 +48,24 @@ async def close(self) -> None:
self.writer.close()
with suppress(ConnectionError):
await self.writer.wait_closed()
self.active = False

@contextmanager
def _reraise_connection_lost(self, *causes: type[Exception]) -> Generator[None, None, None]:
try:
yield
except causes as exception:
self.active = False
raise ConnectionLostError from exception

def write_heartbeat(self) -> None:
with _reraise_connection_lost(RuntimeError):
with self._reraise_connection_lost(RuntimeError):
return self.writer.write(NEWLINE)

async def write_frame(self, frame: AnyClientFrame) -> None:
with _reraise_connection_lost(RuntimeError):
with self._reraise_connection_lost(RuntimeError):
self.writer.write(dump_frame(frame))
with _reraise_connection_lost(ConnectionError):
with self._reraise_connection_lost(ConnectionError):
await self.writer.drain()

async def _read_non_empty_bytes(self, max_chunk_size: int) -> bytes:
Expand All @@ -74,7 +77,7 @@ async def read_frames(self, max_chunk_size: int, timeout: int) -> AsyncGenerator
parser = FrameParser()

while True:
with _reraise_connection_lost(ConnectionError, TimeoutError):
with self._reraise_connection_lost(ConnectionError, TimeoutError):
raw_frames = await asyncio.wait_for(self._read_non_empty_bytes(max_chunk_size), timeout=timeout)

for frame in cast(Iterator[AnyServerFrame], parser.parse_frames_from_chunk(raw_frames)):
Expand Down
89 changes: 64 additions & 25 deletions tests/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,27 @@
pytestmark = pytest.mark.anyio


@asynccontextmanager
async def create_client() -> AsyncGenerator[stompman.Client, None]:
server = stompman.ConnectionParameters(
host=os.environ["ARTEMIS_HOST"], port=61616, login="admin", passcode="%3D123"
)
async with stompman.Client(servers=[server], read_timeout=10, connection_confirmation_timeout=10) as client:
yield client


@pytest.fixture()
def server() -> stompman.ConnectionParameters:
return stompman.ConnectionParameters(host=os.environ["ARTEMIS_HOST"], port=61616, login="admin", passcode="%3D123")
async def client() -> AsyncGenerator[stompman.Client, None]:
async with create_client() as client:
yield client


async def test_ok(server: stompman.ConnectionParameters) -> None:
destination = "DLQ"
messages = [str(uuid4()).encode() for _ in range(10000)]
@pytest.fixture()
def destination() -> str:
return "DLQ"


async def test_ok(destination: str) -> None:
async def produce() -> None:
async with producer.enter_transaction() as transaction:
for message in messages:
Expand All @@ -42,35 +54,62 @@ async def consume() -> None:

assert sorted(received_messages) == sorted(messages)

async with (
stompman.Client(servers=[server], read_timeout=10, connection_confirmation_timeout=10) as consumer,
stompman.Client(servers=[server], read_timeout=10, connection_confirmation_timeout=10) as producer,
asyncio.TaskGroup() as task_group,
):
messages = [str(uuid4()).encode() for _ in range(10000)]

async with create_client() as consumer, create_client() as producer, asyncio.TaskGroup() as task_group:
task_group.create_task(consume())
task_group.create_task(produce())


@asynccontextmanager
async def closed_client(server: stompman.ConnectionParameters) -> AsyncGenerator[stompman.Client, None]:
async with stompman.Client(servers=[server], read_timeout=10, connection_confirmation_timeout=10) as client:
async def test_not_raises_connection_lost_error_in_aexit(client: stompman.Client) -> None:
await client._connection.close()


async def test_not_raises_connection_lost_error_in_write_frame(client: stompman.Client) -> None:
await client._connection.close()

with pytest.raises(ConnectionLostError):
await client._connection.write_frame(stompman.ConnectFrame(headers={"accept-version": "", "host": ""}))


@pytest.mark.parametrize("anyio_backend", [("asyncio", {"use_uvloop": True})])
async def test_not_raises_connection_lost_error_in_write_heartbeat(client: stompman.Client) -> None:
await client._connection.close()

with pytest.raises(ConnectionLostError):
client._connection.write_heartbeat()


async def test_not_raises_connection_lost_error_in_subscription(client: stompman.Client, destination: str) -> None:
async with client.subscribe(destination):
await client._connection.close()
yield client


async def test_not_raises_connection_lost_error_in_aexit(server: stompman.ConnectionParameters) -> None:
async with closed_client(server):
pass
async def test_not_raises_connection_lost_error_in_transaction_without_send(client: stompman.Client) -> None:
async with client.enter_transaction():
await client._connection.close()


async def test_not_raises_connection_lost_error_in_write_frame(server: stompman.ConnectionParameters) -> None:
async with closed_client(server) as client:
async def test_not_raises_connection_lost_error_in_transaction_with_send(
client: stompman.Client, destination: str
) -> None:
async with client.enter_transaction() as transaction:
await client.send(b"first", destination=destination, transaction=transaction)
await client._connection.close()

with pytest.raises(ConnectionLostError):
await client._connection.write_frame(stompman.ConnectFrame(headers={"accept-version": "", "host": ""}))
await client.send(b"second", destination=destination, transaction=transaction)


@pytest.mark.parametrize("anyio_backend", [("asyncio", {"use_uvloop": True})])
async def test_not_raises_connection_lost_error_in_write_heartbeat(server: stompman.ConnectionParameters) -> None:
async with closed_client(server) as client:
with pytest.raises(ConnectionLostError):
client._connection.write_heartbeat()
async def test_raises_connection_lost_error_in_send(client: stompman.Client, destination: str) -> None:
await client._connection.close()

with pytest.raises(ConnectionLostError):
await client.send(b"first", destination=destination)


async def test_raises_connection_lost_error_in_listen(client: stompman.Client) -> None:
await client._connection.close()
client.read_timeout = 0
with pytest.raises(ConnectionLostError):
[event async for event in client.listen()]

0 comments on commit 4a701eb

Please sign in to comment.