From 60be299feab3f0dd0ca6d05e00bf5df88c5d561f Mon Sep 17 00:00:00 2001 From: MZauchner Date: Thu, 5 Dec 2024 21:04:20 +0100 Subject: [PATCH] Close connections in case of an exception When an exception occurs while the connection is in use, the connection may become unsuable. Such connections need to be closed, such that the connection to the broker can be reestablished at a later point. --- kombu/connection.py | 33 +++++++++++++++++- kombu/pools.py | 28 +++++++++++++++ t/integration/test_py_amqp.py | 64 +++++++++++++++++++++++++++++++++-- t/unit/test_connection.py | 27 ++++++++++++++- t/unit/test_pools.py | 10 ++++++ 5 files changed, 158 insertions(+), 4 deletions(-) diff --git a/kombu/connection.py b/kombu/connection.py index 32c416b00..38817a30e 100644 --- a/kombu/connection.py +++ b/kombu/connection.py @@ -1035,6 +1035,29 @@ def is_evented(self): BrokerConnection = Connection +class PooledConnection: + """Wraps :class:`kombu.Connection`. + + This wrapper modifies :meth:`kombu.Connection.__exit__` to close the connection + in case any exception occurred while the context was active. + """ + + def __init__(self, connection, pool): + self._connection = connection + self._pool = pool + + def __getattr__(self, attr): + return getattr(self._connection, attr) + + def __enter__(self): + return self._connection.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None and self._pool.limit: + self._pool.replace(self._connection) + return self._connection.__exit__(exc_type, exc_val, exc_tb) + + class ConnectionPool(Resource): """Pool of connections.""" @@ -1061,10 +1084,18 @@ def collect_resource(self, resource, socket_timeout=0.1): if not isinstance(resource, lazy): return resource.collect(socket_timeout) + def acquire(self, block=False, timeout=None): + return PooledConnection(super().acquire(block=block, timeout=timeout), self) + @contextmanager def acquire_channel(self, block=False): with self.acquire(block=block) as connection: - yield connection, connection.default_channel + try: + yield connection, connection.default_channel + except BaseException: + if self.limit: + self.replace(connection) + raise def setup(self): if self.limit: diff --git a/kombu/pools.py b/kombu/pools.py index 106be1838..68b384a8a 100644 --- a/kombu/pools.py +++ b/kombu/pools.py @@ -23,6 +23,31 @@ def _after_fork_cleanup_group(group): group.clear() +class PooledProducer: + """Wraps :class:`kombu.Producer`. + + This wrapper modifies :meth:`kombu.Produer.__exit__` to close the connection + associated with the producer in case any exception occurred while the context + was active. + """ + + def __init__(self, producer, pool): + self._producer = producer + self._pool = pool + + def __getattr__(self, attr): + return getattr(self._connection, attr) + + def __enter__(self): + return self._producer.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None: + if self._pool.limit and self._producer.__connection__ is not None: + self._pool.connections.replace(self._producer.__connection__) + return self._producer.__exit__(exc_type, exc_val, exc_tb) + + class ProducerPool(Resource): """Pool of :class:`kombu.Producer` instances.""" @@ -37,6 +62,9 @@ def __init__(self, connections, *args, **kwargs): def _acquire_connection(self): return self.connections.acquire(block=True) + def acquire(self, block=False, timeout=None): + return PooledProducer(super().acquire(block=block, timeout=timeout), self) + def create_producer(self): conn = self._acquire_connection() try: diff --git a/t/integration/test_py_amqp.py b/t/integration/test_py_amqp.py index 260f164d8..52682ce84 100644 --- a/t/integration/test_py_amqp.py +++ b/t/integration/test_py_amqp.py @@ -1,13 +1,23 @@ from __future__ import annotations import os +import uuid import pytest import kombu +from amqp.exceptions import NotFound -from .common import (BaseExchangeTypes, BaseFailover, BaseMessage, - BasePriority, BaseTimeToLive, BasicFunctionality) +from kombu.connection import ConnectionPool + +from .common import ( + BaseExchangeTypes, + BaseFailover, + BaseMessage, + BasePriority, + BaseTimeToLive, + BasicFunctionality, +) def get_connection(hostname, port, vhost): @@ -20,6 +30,12 @@ def get_failover_connection(hostname, port, vhost): ) +def get_confirm_connection(hostname, port): + return kombu.Connection( + f"pyamqp://{hostname}:{port}", transport_options={"confirm_publish": True} + ) + + @pytest.fixture() def invalid_connection(): return kombu.Connection('pyamqp://localhost:12345') @@ -47,6 +63,14 @@ def failover_connection(request): ) +@pytest.fixture() +def confirm_publish_connection(): + return get_confirm_connection( + hostname=os.environ.get("RABBITMQ_HOST", "localhost"), + port=os.environ.get("RABBITMQ_5672_TCP", "5672"), + ) + + @pytest.mark.env('py-amqp') @pytest.mark.flaky(reruns=5, reruns_delay=2) class test_PyAMQPBasicFunctionality(BasicFunctionality): @@ -81,3 +105,39 @@ class test_PyAMQPFailover(BaseFailover): @pytest.mark.flaky(reruns=5, reruns_delay=2) class test_PyAMQPMessage(BaseMessage): pass + + +@pytest.mark.env("py-amqp") +@pytest.mark.flaky(reruns=5, reruns_delay=2) +class test_PyAMQPConnectionPool: + def test_publish_confirm_does_not_block(self, confirm_publish_connection): + pool = ConnectionPool(connection=confirm_publish_connection, limit=2) + + try: + with pool.acquire(block=True) as connection: + producer = kombu.Producer(connection) + queue = kombu.Queue( + "test-queue-{}".format(uuid.uuid4()), channel=connection + ) + queue.declare() + producer.publish( + {"foo": "bar"}, routing_key=str(uuid.uuid4()), retry=False + ) + queue.delete() + try: + queue.get() + except NotFound: + raise + except NotFound: + pass + + with pool.acquire(block=True) as connection: + producer = kombu.Producer(connection) + queue = kombu.Queue( + "test-queue-{}".format(uuid.uuid4()), channel=connection + ) + queue.declare() + # In case the connection is broken, we should get a Timeout here + producer.publish( + {"foo": "bar"}, routing_key=str(uuid.uuid4()), retry=False, timeout=3 + ) diff --git a/t/unit/test_connection.py b/t/unit/test_connection.py index 9d1c699b6..aaad57ade 100644 --- a/t/unit/test_connection.py +++ b/t/unit/test_connection.py @@ -1022,7 +1022,7 @@ def test_acquire_raises_evaluated(self): with pytest.raises(MemoryError): with P.acquire(): assert False - P.release.assert_called_with(r) + P.release.assert_called_with(r._connection) def test_release_no__debug(self): P = self.create_resource(10) @@ -1045,6 +1045,31 @@ def test_acquire_channel(self): with P.acquire_channel() as (conn, channel): assert channel is conn.default_channel + def test_exception_during_connection_use(self): + P = self.create_resource(1) + P.replace = Mock() + + with pytest.raises(IOError): + with P.acquire() as connection: + connection.heartbeat_check = Mock() + connection.heartbeat_check.side_effect = IOError() + _ = connection.heartbeat_check() + + with P.acquire() as connection: + assert not connection.connected + P.replace.assert_called_once_with(connection) + + P = self.create_resource(1) + P.replace = Mock() + + with pytest.raises(IOError): + with P.acquire_channel() as (connection, _): + connection.heartbeat_check = Mock() + connection.heartbeat_check.side_effect = IOError() + _ = connection.heartbeat_check() + + P.replace.assert_called_with(connection) + class test_ChannelPool(ResourceCase): diff --git a/t/unit/test_pools.py b/t/unit/test_pools.py index f4a3964f7..e5593b5db 100644 --- a/t/unit/test_pools.py +++ b/t/unit/test_pools.py @@ -37,6 +37,16 @@ def test_releases_connection_when_Producer_raises(self): self.pool.create_producer() conn.release.assert_called_with() + def test_exception_during_connection_use(self): + with pytest.raises(IOError): + with self.pool.acquire() as producer: + producer.__connection__ = Mock() + producer.publish = Mock() + producer.publish.side_effect = IOError() + producer.publish("test data") + + self.connections.replace.assert_called_once() + def test_prepare_release_connection_on_error(self): pp = Mock() p = pp.return_value = Mock()