Skip to content

Commit

Permalink
Fix graceful shutdown when using uvloop (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
vrslev authored Jun 17, 2024
1 parent 9730d08 commit db05208
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 47 deletions.
6 changes: 3 additions & 3 deletions stompman/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ async def listen(self) -> AsyncIterator["AnyListeningEvent"]:
class MessageEvent:
body: bytes = field(init=False)
_frame: MessageFrame
_client: "Client" = field(repr=False)
_client: Client = field(repr=False)

def __post_init__(self) -> None:
self.body = self._frame.body
Expand Down Expand Up @@ -314,7 +314,7 @@ class ErrorEvent:
body: bytes = field(init=False)
"""Long description of the error."""
_frame: ErrorFrame
_client: "Client" = field(repr=False)
_client: Client = field(repr=False)

def __post_init__(self) -> None:
self.message_header = self._frame.headers["message"]
Expand All @@ -324,7 +324,7 @@ def __post_init__(self) -> None:
@dataclass
class HeartbeatEvent:
_frame: HeartbeatFrame
_client: "Client" = field(repr=False)
_client: Client = field(repr=False)


AnyListeningEvent = MessageEvent | ErrorEvent | HeartbeatEvent
3 changes: 2 additions & 1 deletion stompman/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def write_heartbeat(self) -> None:
return self.writer.write(NEWLINE)

async def write_frame(self, frame: AnyClientFrame) -> None:
self.writer.write(dump_frame(frame))
with _reraise_connection_lost(RuntimeError):
self.writer.write(dump_frame(frame))
with _reraise_connection_lost(ConnectionError):
await self.writer.drain()

Expand Down
1 change: 0 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
pytest.param(("asyncio", {"use_uvloop": True}), id="asyncio+uvloop"),
pytest.param(("asyncio", {"use_uvloop": False}), id="asyncio"),
],
autouse=True,
)
def anyio_backend(request: pytest.FixtureRequest) -> object:
return request.param
19 changes: 17 additions & 2 deletions tests/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,20 @@
import os
from uuid import uuid4

import pytest

import stompman
from stompman.errors import ConnectionLostError

pytestmark = pytest.mark.anyio


@pytest.fixture()
def server() -> stompman.ConnectionParameters:
return stompman.ConnectionParameters(host=os.environ["ARTEMIS_HOST"], port=61616, login="admin", passcode="admin")

async def test_integration() -> None:
server = stompman.ConnectionParameters(host=os.environ["ARTEMIS_HOST"], port=61616, login="admin", passcode="admin")

async def test_ok(server: stompman.ConnectionParameters) -> None:
destination = "DLQ"
messages = [str(uuid4()).encode() for _ in range(10000)]

Expand Down Expand Up @@ -38,3 +47,9 @@ async def consume() -> None:
):
task_group.create_task(consume())
task_group.create_task(produce())


async def test_raises_connection_lost_error(server: stompman.ConnectionParameters) -> None:
with pytest.raises(ConnectionLostError):
async with stompman.Client(servers=[server], read_timeout=10, connection_confirmation_timeout=10) as consumer:
await consumer._connection.close()
2 changes: 2 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
)
from stompman.client import ConnectionParameters, ErrorEvent, HeartbeatEvent, MessageEvent

pytestmark = pytest.mark.anyio


@dataclass
class BaseMockConnection(AbstractConnection):
Expand Down
77 changes: 37 additions & 40 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,27 @@

import pytest

from stompman import (
AnyServerFrame,
ConnectedFrame,
Connection,
ConnectionLostError,
HeartbeatFrame,
)
from stompman import AnyServerFrame, ConnectedFrame, Connection, ConnectionLostError, HeartbeatFrame
from stompman.frames import BeginFrame, CommitFrame

pytestmark = pytest.mark.anyio


async def make_connection() -> Connection | None:
return await Connection.connect(host="localhost", port=12345, timeout=2)


async def make_mocked_connection(
monkeypatch: pytest.MonkeyPatch,
reader: Any, # noqa: ANN401
writer: Any, # noqa: ANN401
) -> Connection:
monkeypatch.setattr("asyncio.open_connection", mock.AsyncMock(return_value=(reader, writer)))
connection = await make_connection()
assert connection
return connection


def mock_wait_for(monkeypatch: pytest.MonkeyPatch) -> None:
async def mock_impl(future: Awaitable[Any], timeout: int) -> Any: # noqa: ANN401, ARG001
return await original_wait_for(future, timeout=0)
Expand Down Expand Up @@ -57,19 +64,21 @@ class MockWriter:
b"som",
b"e server\nversion:1.2\n\n\x00",
]
expected_frames = [
HeartbeatFrame(),
HeartbeatFrame(),
HeartbeatFrame(),
ConnectedFrame(headers={"heart-beat": "0,0", "version": "1.2", "server": "some server"}),
]
max_chunk_size = 1024

class MockReader:
read = mock.AsyncMock(side_effect=read_bytes)

monkeypatch.setattr("asyncio.open_connection", mock.AsyncMock(return_value=(MockReader(), MockWriter())))
connection = await make_connection()
assert connection

connection = await make_mocked_connection(monkeypatch, MockReader(), MockWriter())
connection.write_heartbeat()
await connection.write_frame(CommitFrame(headers={"transaction": "transaction"}))

max_chunk_size = 1024

async def take_frames(count: int) -> list[AnyServerFrame]:
frames = []
async for frame in connection.read_frames(max_chunk_size=max_chunk_size, timeout=1):
Expand All @@ -79,12 +88,6 @@ async def take_frames(count: int) -> list[AnyServerFrame]:

return frames

expected_frames = [
HeartbeatFrame(),
HeartbeatFrame(),
HeartbeatFrame(),
ConnectedFrame(headers={"heart-beat": "0,0", "version": "1.2", "server": "some server"}),
]
assert await take_frames(len(expected_frames)) == expected_frames
await connection.close()

Expand All @@ -100,10 +103,7 @@ class MockWriter:
close = mock.Mock()
wait_closed = mock.AsyncMock(side_effect=ConnectionError)

monkeypatch.setattr("asyncio.open_connection", mock.AsyncMock(return_value=(mock.Mock(), MockWriter())))
connection = await make_connection()
assert connection

connection = await make_mocked_connection(monkeypatch, mock.Mock(), MockWriter())
with pytest.raises(ConnectionLostError):
await connection.close()

Expand All @@ -113,10 +113,17 @@ class MockWriter:
write = mock.Mock()
drain = mock.AsyncMock(side_effect=ConnectionError)

monkeypatch.setattr("asyncio.open_connection", mock.AsyncMock(return_value=(mock.Mock(), MockWriter())))
connection = await make_connection()
assert connection
connection = await make_mocked_connection(monkeypatch, mock.Mock(), MockWriter())
with pytest.raises(ConnectionLostError):
await connection.write_frame(BeginFrame(headers={"transaction": ""}))


async def test_connection_write_frame_runtime_error(monkeypatch: pytest.MonkeyPatch) -> None:
class MockWriter:
write = mock.Mock(side_effect=RuntimeError)
drain = mock.AsyncMock()

connection = await make_mocked_connection(monkeypatch, mock.Mock(), MockWriter())
with pytest.raises(ConnectionLostError):
await connection.write_frame(BeginFrame(headers={"transaction": ""}))

Expand All @@ -133,27 +140,17 @@ async def test_connection_connect_connection_error(monkeypatch: pytest.MonkeyPat


async def test_read_frames_timeout_error(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"asyncio.open_connection",
mock.AsyncMock(return_value=(mock.AsyncMock(read=partial(asyncio.sleep, 5)), mock.AsyncMock())),
connection = await make_mocked_connection(
monkeypatch, mock.AsyncMock(read=partial(asyncio.sleep, 5)), mock.AsyncMock()
)
connection = await make_connection()
assert connection

mock_wait_for(monkeypatch)
with pytest.raises(ConnectionLostError):
[frame async for frame in connection.read_frames(1024, 1)]


async def test_read_frames_connection_error(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"asyncio.open_connection",
mock.AsyncMock(
return_value=(mock.AsyncMock(read=mock.AsyncMock(side_effect=BrokenPipeError)), mock.AsyncMock())
),
connection = await make_mocked_connection(
monkeypatch, mock.AsyncMock(read=mock.AsyncMock(side_effect=BrokenPipeError)), mock.AsyncMock()
)
connection = await make_connection()
assert connection

with pytest.raises(ConnectionLostError):
[frame async for frame in connection.read_frames(1024, 1)]

0 comments on commit db05208

Please sign in to comment.