From 03316b7726df244db3f20bccff106ce9012ebbff Mon Sep 17 00:00:00 2001 From: MZauchner Date: Thu, 5 Dec 2024 21:04:20 +0100 Subject: [PATCH 1/2] Close connections in case of an exception When an exception occurs while the connection is in use, the connection may become unsuable. Such connections need to be closed, such that the connection to the broker can be reestablished at a later point. --- kombu/connection.py | 22 ++++++++++- kombu/messaging.py | 8 +++- t/integration/test_py_amqp.py | 74 ++++++++++++++++++++++++++++++++++- t/unit/test_connection.py | 37 ++++++++++++++++++ t/unit/test_pools.py | 14 ++++++- 5 files changed, 150 insertions(+), 5 deletions(-) diff --git a/kombu/connection.py b/kombu/connection.py index 32c416b00..d30b58364 100644 --- a/kombu/connection.py +++ b/kombu/connection.py @@ -1035,6 +1035,26 @@ def is_evented(self): BrokerConnection = Connection +class PooledConnection(Connection): + """Wraps :class:`kombu.Connection`. + + This wrapper modifies :meth:`kombu.Connection.__exit__` to close the connection + in case any exception occurred while the context was active. + """ + + def __init__(self, pool, **kwargs): + self._pool = pool + super().__init__(**kwargs) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None and self._pool.limit: + self._pool.replace(self) + return super().__exit__(exc_type, exc_val, exc_tb) + + class ConnectionPool(Resource): """Pool of connections.""" @@ -1046,7 +1066,7 @@ def __init__(self, connection, limit=None, **kwargs): super().__init__(limit=limit) def new(self): - return self.connection.clone() + return PooledConnection(self, **dict(self.connection._info(resolve=False))) def release_resource(self, resource): try: diff --git a/kombu/messaging.py b/kombu/messaging.py index dcbc0c227..dc9f10b21 100644 --- a/kombu/messaging.py +++ b/kombu/messaging.py @@ -7,7 +7,7 @@ from .common import maybe_declare from .compression import compress -from .connection import is_connection, maybe_channel +from .connection import PooledConnection, is_connection, maybe_channel from .entity import Exchange, Queue, maybe_delivery_mode from .exceptions import ContentDisallowed from .serialization import dumps, prepare_accept_content @@ -257,6 +257,12 @@ def __exit__( exc_val: BaseException | None, exc_tb: TracebackType | None ) -> None: + # In case the connection is part of a pool it needs to be + # replaced in case of an exception + if self.__connection__ is not None and exc_type is not None: + if isinstance(self.__connection__, PooledConnection): + self.__connection__._pool.replace(self.__connection__) + self.release() def release(self): diff --git a/t/integration/test_py_amqp.py b/t/integration/test_py_amqp.py index 260f164d8..f66d3cd55 100644 --- a/t/integration/test_py_amqp.py +++ b/t/integration/test_py_amqp.py @@ -1,13 +1,23 @@ from __future__ import annotations import os +import uuid import pytest import kombu +from amqp.exceptions import NotFound -from .common import (BaseExchangeTypes, BaseFailover, BaseMessage, - BasePriority, BaseTimeToLive, BasicFunctionality) +from kombu.connection import ConnectionPool + +from .common import ( + BaseExchangeTypes, + BaseFailover, + BaseMessage, + BasePriority, + BaseTimeToLive, + BasicFunctionality, +) def get_connection(hostname, port, vhost): @@ -20,6 +30,12 @@ def get_failover_connection(hostname, port, vhost): ) +def get_confirm_connection(hostname, port): + return kombu.Connection( + f"pyamqp://{hostname}:{port}", transport_options={"confirm_publish": True} + ) + + @pytest.fixture() def invalid_connection(): return kombu.Connection('pyamqp://localhost:12345') @@ -47,6 +63,14 @@ def failover_connection(request): ) +@pytest.fixture() +def confirm_publish_connection(): + return get_confirm_connection( + hostname=os.environ.get("RABBITMQ_HOST", "localhost"), + port=os.environ.get("RABBITMQ_5672_TCP", "5672"), + ) + + @pytest.mark.env('py-amqp') @pytest.mark.flaky(reruns=5, reruns_delay=2) class test_PyAMQPBasicFunctionality(BasicFunctionality): @@ -81,3 +105,49 @@ class test_PyAMQPFailover(BaseFailover): @pytest.mark.flaky(reruns=5, reruns_delay=2) class test_PyAMQPMessage(BaseMessage): pass + + +@pytest.mark.env("py-amqp") +@pytest.mark.flaky(reruns=5, reruns_delay=2) +class test_PyAMQPConnectionPool: + def test_publish_confirm_does_not_block(self, confirm_publish_connection): + """Tests that the connection pool closes connections in case of an exception. + + In case an exception occurs while the connection is in use, the pool should + close the exception. In case the connection is not closed before releasing it + back to the pool, the connection would remain in an unsuable state, causing + causing the next publish call to time out or block forever in case no + timeout is specified. + """ + pool = ConnectionPool(connection=confirm_publish_connection, limit=1) + + try: + with pool.acquire(block=True) as connection: + producer = kombu.Producer(connection) + queue = kombu.Queue( + "test-queue-{}".format(uuid.uuid4()), channel=connection + ) + queue.declare() + producer.publish( + {"foo": "bar"}, routing_key=str(uuid.uuid4()), retry=False + ) + assert connection.connected + queue.delete() + try: + queue.get() + except NotFound: + raise + except NotFound: + pass + + with pool.acquire(block=True) as connection: + assert not connection.connected + producer = kombu.Producer(connection) + queue = kombu.Queue( + "test-queue-{}".format(uuid.uuid4()), channel=connection + ) + queue.declare() + # In case the connection is broken, we should get a Timeout here + producer.publish( + {"foo": "bar"}, routing_key=str(uuid.uuid4()), retry=False, timeout=3 + ) diff --git a/t/unit/test_connection.py b/t/unit/test_connection.py index 9d1c699b6..4c13bb346 100644 --- a/t/unit/test_connection.py +++ b/t/unit/test_connection.py @@ -1045,6 +1045,43 @@ def test_acquire_channel(self): with P.acquire_channel() as (conn, channel): assert channel is conn.default_channel + def test_exception_during_connection_use(self): + """Tests that connections retrieved from a pool are replaced. + + In case of an exception during usage of an exception, it is required that the + connection is 'replaced' (effectively closing the connection) before releasing + it back into the pool. This ensures that reconnecting to the broker is required + before the next usage. + """ + P = self.create_resource(1) + + # Raising an exception during a network call should cause the cause the + # connection to be replaced. + with pytest.raises(IOError): + with P.acquire() as connection: + connection.connect() + connection.heartbeat_check = Mock() + connection.heartbeat_check.side_effect = IOError() + _ = connection.heartbeat_check() + + # Acquiring the same connection from the pool yields a disconnected Connection + # object. + with P.acquire() as connection: + assert not connection.connected + + # acquire_channel automatically reconnects + with pytest.raises(IOError): + with P.acquire_channel() as (connection, _): + # The Connection object should still be connected + assert connection.connected + connection.heartbeat_check = Mock() + connection.heartbeat_check.side_effect = IOError() + _ = connection.heartbeat_check() + + with P.acquire() as connection: + # The connection should be closed + assert not connection.connected + class test_ChannelPool(ResourceCase): diff --git a/t/unit/test_pools.py b/t/unit/test_pools.py index f4a3964f7..017f280da 100644 --- a/t/unit/test_pools.py +++ b/t/unit/test_pools.py @@ -5,7 +5,7 @@ import pytest from kombu import Connection, Producer, pools -from kombu.connection import ConnectionPool +from kombu.connection import ConnectionPool, PooledConnection from kombu.utils.collections import eqhash @@ -37,6 +37,18 @@ def test_releases_connection_when_Producer_raises(self): self.pool.create_producer() conn.release.assert_called_with() + def test_exception_during_connection_use(self): + """Tests that the connection is closed in case of an exception.""" + with pytest.raises(IOError): + with self.pool.acquire() as producer: + producer.__connection__ = Mock(spec=PooledConnection) + producer.__connection__._pool = self.connections + producer.publish = Mock() + producer.publish.side_effect = IOError() + producer.publish("test data") + + self.connections.replace.assert_called_once() + def test_prepare_release_connection_on_error(self): pp = Mock() p = pp.return_value = Mock() From c56800738dc6977bf4555883888905590e875515 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 7 Dec 2024 01:32:18 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- t/integration/test_py_amqp.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/t/integration/test_py_amqp.py b/t/integration/test_py_amqp.py index f66d3cd55..87ba6e3bc 100644 --- a/t/integration/test_py_amqp.py +++ b/t/integration/test_py_amqp.py @@ -4,20 +4,13 @@ import uuid import pytest - -import kombu from amqp.exceptions import NotFound +import kombu from kombu.connection import ConnectionPool -from .common import ( - BaseExchangeTypes, - BaseFailover, - BaseMessage, - BasePriority, - BaseTimeToLive, - BasicFunctionality, -) +from .common import (BaseExchangeTypes, BaseFailover, BaseMessage, + BasePriority, BaseTimeToLive, BasicFunctionality) def get_connection(hostname, port, vhost): @@ -115,7 +108,7 @@ def test_publish_confirm_does_not_block(self, confirm_publish_connection): In case an exception occurs while the connection is in use, the pool should close the exception. In case the connection is not closed before releasing it - back to the pool, the connection would remain in an unsuable state, causing + back to the pool, the connection would remain in an unusable state, causing causing the next publish call to time out or block forever in case no timeout is specified. """ @@ -125,7 +118,7 @@ def test_publish_confirm_does_not_block(self, confirm_publish_connection): with pool.acquire(block=True) as connection: producer = kombu.Producer(connection) queue = kombu.Queue( - "test-queue-{}".format(uuid.uuid4()), channel=connection + f"test-queue-{uuid.uuid4()}", channel=connection ) queue.declare() producer.publish( @@ -144,7 +137,7 @@ def test_publish_confirm_does_not_block(self, confirm_publish_connection): assert not connection.connected producer = kombu.Producer(connection) queue = kombu.Queue( - "test-queue-{}".format(uuid.uuid4()), channel=connection + f"test-queue-{uuid.uuid4()}", channel=connection ) queue.declare() # In case the connection is broken, we should get a Timeout here