Skip to content

Commit

Permalink
Avoid raising ConnectionLostError in ack/nack (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
vrslev authored Jun 21, 2024
1 parent 4a701eb commit fc2d7e3
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 16 deletions.
32 changes: 17 additions & 15 deletions stompman/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
from contextlib import AsyncExitStack, asynccontextmanager
from contextlib import AsyncExitStack, asynccontextmanager, suppress
from dataclasses import dataclass, field
from types import TracebackType
from typing import Any, ClassVar, NamedTuple, Self, TypedDict
Expand Down Expand Up @@ -291,25 +291,27 @@ def __post_init__(self) -> None:

async def ack(self) -> None:
if self._client._connection.active:
await self._client._connection.write_frame(
AckFrame(
headers={
"id": self._frame.headers["message-id"],
"subscription": self._frame.headers["subscription"],
},
with suppress(ConnectionLostError):
await self._client._connection.write_frame(
AckFrame(
headers={
"id": self._frame.headers["message-id"],
"subscription": self._frame.headers["subscription"],
},
)
)
)

async def nack(self) -> None:
if self._client._connection.active:
await self._client._connection.write_frame(
NackFrame(
headers={
"id": self._frame.headers["message-id"],
"subscription": self._frame.headers["subscription"],
}
with suppress(ConnectionLostError):
await self._client._connection.write_frame(
NackFrame(
headers={
"id": self._frame.headers["message-id"],
"subscription": self._frame.headers["subscription"],
}
)
)
)

async def with_auto_ack(
self,
Expand Down
20 changes: 19 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ async def test_client_listen_to_events_unreachable(frame: ConnectedFrame | Recei
[event async for event in client.listen()]


async def test_ack_nack() -> None:
async def test_ack_nack_ok() -> None:
subscription = "subscription-id"
message_id = "message-id"

Expand All @@ -391,6 +391,24 @@ async def test_ack_nack() -> None:
assert_frames_between_lifespan_match(collected_frames, [message_frame, nack_frame, ack_frame])


async def test_ack_nack_connection_lost_error() -> None:
message_frame = MessageFrame(headers={"subscription": "", "message-id": "", "destination": ""}, body=b"")
connection_class, _ = create_spying_connection(get_read_frames_with_lifespan([[message_frame]]))

class MockConnection(connection_class): # type: ignore[valid-type, misc]
async def write_frame(self, frame: AnyClientFrame) -> None:
if isinstance(frame, AckFrame | NackFrame):
raise ConnectionLostError

async with EnrichedClient(connection_class=MockConnection) as client:
events = [event async for event in client.listen()]
event = events[0]
assert isinstance(event, MessageEvent)

await event.nack()
await event.ack()


def get_mocked_message_event() -> tuple[MessageEvent, mock.AsyncMock, mock.AsyncMock, mock.Mock]:
ack_mock, nack_mock, on_suppressed_exception_mock = mock.AsyncMock(), mock.AsyncMock(), mock.Mock()

Expand Down

0 comments on commit fc2d7e3

Please sign in to comment.