Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Close connections in case of an exception #2201

Merged
merged 2 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion kombu/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,26 @@ def is_evented(self):
BrokerConnection = Connection


class PooledConnection(Connection):
"""Wraps :class:`kombu.Connection`.

This wrapper modifies :meth:`kombu.Connection.__exit__` to close the connection
in case any exception occurred while the context was active.
"""

def __init__(self, pool, **kwargs):
self._pool = pool
super().__init__(**kwargs)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None and self._pool.limit:
self._pool.replace(self)
return super().__exit__(exc_type, exc_val, exc_tb)


class ConnectionPool(Resource):
"""Pool of connections."""

Expand All @@ -1046,7 +1066,7 @@ def __init__(self, connection, limit=None, **kwargs):
super().__init__(limit=limit)

def new(self):
return self.connection.clone()
return PooledConnection(self, **dict(self.connection._info(resolve=False)))

def release_resource(self, resource):
try:
Expand Down
8 changes: 7 additions & 1 deletion kombu/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .common import maybe_declare
from .compression import compress
from .connection import is_connection, maybe_channel
from .connection import PooledConnection, is_connection, maybe_channel
from .entity import Exchange, Queue, maybe_delivery_mode
from .exceptions import ContentDisallowed
from .serialization import dumps, prepare_accept_content
Expand Down Expand Up @@ -257,6 +257,12 @@ def __exit__(
exc_val: BaseException | None,
exc_tb: TracebackType | None
) -> None:
# In case the connection is part of a pool it needs to be
# replaced in case of an exception
if self.__connection__ is not None and exc_type is not None:
if isinstance(self.__connection__, PooledConnection):
self.__connection__._pool.replace(self.__connection__)

self.release()

def release(self):
Expand Down
63 changes: 63 additions & 0 deletions t/integration/test_py_amqp.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import os
import uuid

import pytest
from amqp.exceptions import NotFound

import kombu
from kombu.connection import ConnectionPool

from .common import (BaseExchangeTypes, BaseFailover, BaseMessage,
BasePriority, BaseTimeToLive, BasicFunctionality)
Expand All @@ -20,6 +23,12 @@ def get_failover_connection(hostname, port, vhost):
)


def get_confirm_connection(hostname, port):
return kombu.Connection(
f"pyamqp://{hostname}:{port}", transport_options={"confirm_publish": True}
)


@pytest.fixture()
def invalid_connection():
return kombu.Connection('pyamqp://localhost:12345')
Expand Down Expand Up @@ -47,6 +56,14 @@ def failover_connection(request):
)


@pytest.fixture()
def confirm_publish_connection():
return get_confirm_connection(
hostname=os.environ.get("RABBITMQ_HOST", "localhost"),
port=os.environ.get("RABBITMQ_5672_TCP", "5672"),
)


@pytest.mark.env('py-amqp')
@pytest.mark.flaky(reruns=5, reruns_delay=2)
class test_PyAMQPBasicFunctionality(BasicFunctionality):
Expand Down Expand Up @@ -81,3 +98,49 @@ class test_PyAMQPFailover(BaseFailover):
@pytest.mark.flaky(reruns=5, reruns_delay=2)
class test_PyAMQPMessage(BaseMessage):
pass


@pytest.mark.env("py-amqp")
@pytest.mark.flaky(reruns=5, reruns_delay=2)
class test_PyAMQPConnectionPool:
def test_publish_confirm_does_not_block(self, confirm_publish_connection):
"""Tests that the connection pool closes connections in case of an exception.

In case an exception occurs while the connection is in use, the pool should
close the exception. In case the connection is not closed before releasing it
back to the pool, the connection would remain in an unusable state, causing
causing the next publish call to time out or block forever in case no
timeout is specified.
"""
pool = ConnectionPool(connection=confirm_publish_connection, limit=1)

try:
with pool.acquire(block=True) as connection:
producer = kombu.Producer(connection)
queue = kombu.Queue(
f"test-queue-{uuid.uuid4()}", channel=connection
)
queue.declare()
producer.publish(
{"foo": "bar"}, routing_key=str(uuid.uuid4()), retry=False
)
assert connection.connected
queue.delete()
try:
queue.get()
except NotFound:
raise
except NotFound:
pass

with pool.acquire(block=True) as connection:
assert not connection.connected
producer = kombu.Producer(connection)
queue = kombu.Queue(
f"test-queue-{uuid.uuid4()}", channel=connection
)
queue.declare()
# In case the connection is broken, we should get a Timeout here
producer.publish(
{"foo": "bar"}, routing_key=str(uuid.uuid4()), retry=False, timeout=3
)
37 changes: 37 additions & 0 deletions t/unit/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,43 @@ def test_acquire_channel(self):
with P.acquire_channel() as (conn, channel):
assert channel is conn.default_channel

def test_exception_during_connection_use(self):
MZauchner marked this conversation as resolved.
Show resolved Hide resolved
"""Tests that connections retrieved from a pool are replaced.

In case of an exception during usage of an exception, it is required that the
connection is 'replaced' (effectively closing the connection) before releasing
it back into the pool. This ensures that reconnecting to the broker is required
before the next usage.
"""
P = self.create_resource(1)

# Raising an exception during a network call should cause the cause the
# connection to be replaced.
with pytest.raises(IOError):
with P.acquire() as connection:
connection.connect()
connection.heartbeat_check = Mock()
connection.heartbeat_check.side_effect = IOError()
_ = connection.heartbeat_check()

# Acquiring the same connection from the pool yields a disconnected Connection
# object.
with P.acquire() as connection:
assert not connection.connected

# acquire_channel automatically reconnects
with pytest.raises(IOError):
with P.acquire_channel() as (connection, _):
# The Connection object should still be connected
assert connection.connected
connection.heartbeat_check = Mock()
connection.heartbeat_check.side_effect = IOError()
_ = connection.heartbeat_check()

with P.acquire() as connection:
# The connection should be closed
assert not connection.connected


class test_ChannelPool(ResourceCase):

Expand Down
14 changes: 13 additions & 1 deletion t/unit/test_pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from kombu import Connection, Producer, pools
from kombu.connection import ConnectionPool
from kombu.connection import ConnectionPool, PooledConnection
from kombu.utils.collections import eqhash


Expand Down Expand Up @@ -37,6 +37,18 @@ def test_releases_connection_when_Producer_raises(self):
self.pool.create_producer()
conn.release.assert_called_with()

def test_exception_during_connection_use(self):
"""Tests that the connection is closed in case of an exception."""
with pytest.raises(IOError):
with self.pool.acquire() as producer:
producer.__connection__ = Mock(spec=PooledConnection)
producer.__connection__._pool = self.connections
producer.publish = Mock()
producer.publish.side_effect = IOError()
producer.publish("test data")

self.connections.replace.assert_called_once()

def test_prepare_release_connection_on_error(self):
pp = Mock()
p = pp.return_value = Mock()
Expand Down