Skip to content

Commit

Permalink
Close connections in case of an exception
Browse files Browse the repository at this point in the history
When an exception occurs while the connection is in use, the connection
may become unsuable. Such connections need to be closed, such that the
connection to the broker can be reestablished at a later point.
  • Loading branch information
MZauchner committed Dec 5, 2024
1 parent 1a5eb30 commit 60be299
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 4 deletions.
33 changes: 32 additions & 1 deletion kombu/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,29 @@ def is_evented(self):
BrokerConnection = Connection


class PooledConnection:
"""Wraps :class:`kombu.Connection`.
This wrapper modifies :meth:`kombu.Connection.__exit__` to close the connection
in case any exception occurred while the context was active.
"""

def __init__(self, connection, pool):
self._connection = connection
self._pool = pool

def __getattr__(self, attr):
return getattr(self._connection, attr)

def __enter__(self):
return self._connection.__enter__()

def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None and self._pool.limit:
self._pool.replace(self._connection)
return self._connection.__exit__(exc_type, exc_val, exc_tb)


class ConnectionPool(Resource):
"""Pool of connections."""

Expand All @@ -1061,10 +1084,18 @@ def collect_resource(self, resource, socket_timeout=0.1):
if not isinstance(resource, lazy):
return resource.collect(socket_timeout)

def acquire(self, block=False, timeout=None):
return PooledConnection(super().acquire(block=block, timeout=timeout), self)

@contextmanager
def acquire_channel(self, block=False):
with self.acquire(block=block) as connection:
yield connection, connection.default_channel
try:
yield connection, connection.default_channel
except BaseException:
if self.limit:
self.replace(connection)
raise

def setup(self):
if self.limit:
Expand Down
28 changes: 28 additions & 0 deletions kombu/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,31 @@ def _after_fork_cleanup_group(group):
group.clear()


class PooledProducer:
"""Wraps :class:`kombu.Producer`.
This wrapper modifies :meth:`kombu.Produer.__exit__` to close the connection
associated with the producer in case any exception occurred while the context
was active.
"""

def __init__(self, producer, pool):
self._producer = producer
self._pool = pool

def __getattr__(self, attr):
return getattr(self._connection, attr)

def __enter__(self):
return self._producer.__enter__()

def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None:
if self._pool.limit and self._producer.__connection__ is not None:
self._pool.connections.replace(self._producer.__connection__)
return self._producer.__exit__(exc_type, exc_val, exc_tb)


class ProducerPool(Resource):
"""Pool of :class:`kombu.Producer` instances."""

Expand All @@ -37,6 +62,9 @@ def __init__(self, connections, *args, **kwargs):
def _acquire_connection(self):
return self.connections.acquire(block=True)

def acquire(self, block=False, timeout=None):
return PooledProducer(super().acquire(block=block, timeout=timeout), self)

def create_producer(self):
conn = self._acquire_connection()
try:
Expand Down
64 changes: 62 additions & 2 deletions t/integration/test_py_amqp.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
from __future__ import annotations

import os
import uuid

import pytest

import kombu
from amqp.exceptions import NotFound

from .common import (BaseExchangeTypes, BaseFailover, BaseMessage,
BasePriority, BaseTimeToLive, BasicFunctionality)
from kombu.connection import ConnectionPool

from .common import (
BaseExchangeTypes,
BaseFailover,
BaseMessage,
BasePriority,
BaseTimeToLive,
BasicFunctionality,
)


def get_connection(hostname, port, vhost):
Expand All @@ -20,6 +30,12 @@ def get_failover_connection(hostname, port, vhost):
)


def get_confirm_connection(hostname, port):
return kombu.Connection(
f"pyamqp://{hostname}:{port}", transport_options={"confirm_publish": True}
)


@pytest.fixture()
def invalid_connection():
return kombu.Connection('pyamqp://localhost:12345')
Expand Down Expand Up @@ -47,6 +63,14 @@ def failover_connection(request):
)


@pytest.fixture()
def confirm_publish_connection():
return get_confirm_connection(
hostname=os.environ.get("RABBITMQ_HOST", "localhost"),
port=os.environ.get("RABBITMQ_5672_TCP", "5672"),
)


@pytest.mark.env('py-amqp')
@pytest.mark.flaky(reruns=5, reruns_delay=2)
class test_PyAMQPBasicFunctionality(BasicFunctionality):
Expand Down Expand Up @@ -81,3 +105,39 @@ class test_PyAMQPFailover(BaseFailover):
@pytest.mark.flaky(reruns=5, reruns_delay=2)
class test_PyAMQPMessage(BaseMessage):
pass


@pytest.mark.env("py-amqp")
@pytest.mark.flaky(reruns=5, reruns_delay=2)
class test_PyAMQPConnectionPool:
def test_publish_confirm_does_not_block(self, confirm_publish_connection):
pool = ConnectionPool(connection=confirm_publish_connection, limit=2)

try:
with pool.acquire(block=True) as connection:
producer = kombu.Producer(connection)
queue = kombu.Queue(
"test-queue-{}".format(uuid.uuid4()), channel=connection
)
queue.declare()
producer.publish(
{"foo": "bar"}, routing_key=str(uuid.uuid4()), retry=False
)
queue.delete()
try:
queue.get()
except NotFound:
raise
except NotFound:
pass

with pool.acquire(block=True) as connection:
producer = kombu.Producer(connection)
queue = kombu.Queue(
"test-queue-{}".format(uuid.uuid4()), channel=connection
)
queue.declare()
# In case the connection is broken, we should get a Timeout here
producer.publish(
{"foo": "bar"}, routing_key=str(uuid.uuid4()), retry=False, timeout=3
)
27 changes: 26 additions & 1 deletion t/unit/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,7 +1022,7 @@ def test_acquire_raises_evaluated(self):
with pytest.raises(MemoryError):
with P.acquire():
assert False
P.release.assert_called_with(r)
P.release.assert_called_with(r._connection)

def test_release_no__debug(self):
P = self.create_resource(10)
Expand All @@ -1045,6 +1045,31 @@ def test_acquire_channel(self):
with P.acquire_channel() as (conn, channel):
assert channel is conn.default_channel

def test_exception_during_connection_use(self):
P = self.create_resource(1)
P.replace = Mock()

with pytest.raises(IOError):
with P.acquire() as connection:
connection.heartbeat_check = Mock()
connection.heartbeat_check.side_effect = IOError()
_ = connection.heartbeat_check()

with P.acquire() as connection:
assert not connection.connected
P.replace.assert_called_once_with(connection)

P = self.create_resource(1)
P.replace = Mock()

with pytest.raises(IOError):
with P.acquire_channel() as (connection, _):
connection.heartbeat_check = Mock()
connection.heartbeat_check.side_effect = IOError()
_ = connection.heartbeat_check()

P.replace.assert_called_with(connection)


class test_ChannelPool(ResourceCase):

Expand Down
10 changes: 10 additions & 0 deletions t/unit/test_pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ def test_releases_connection_when_Producer_raises(self):
self.pool.create_producer()
conn.release.assert_called_with()

def test_exception_during_connection_use(self):
with pytest.raises(IOError):
with self.pool.acquire() as producer:
producer.__connection__ = Mock()
producer.publish = Mock()
producer.publish.side_effect = IOError()
producer.publish("test data")

self.connections.replace.assert_called_once()

def test_prepare_release_connection_on_error(self):
pp = Mock()
p = pp.return_value = Mock()
Expand Down

0 comments on commit 60be299

Please sign in to comment.