From 012f7cf7d2950e8241eb5705e7cdd83b09370ba4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 18 Sep 2023 23:09:41 +0000 Subject: [PATCH] Add `Redis.from_pool()` class method, for explicitly owning and closing a ConnectionPool (#2913) * Add the `from_pool` argument to asyncio.Redis * Add tests for the `from_pool` argument * Add a "from_pool" argument for the sync client too * Add automatic connection pool close for redis.sentinel * use from_pool() class method instead * re-add the auto_close_connection_pool arg to Connection.from_url() * Deprecate the "auto_close_connection_pool" argument to Redis() and Redis.from_url() --- docs/examples/asyncio_examples.ipynb | 33 ++++++-- redis/asyncio/client.py | 65 ++++++++++++--- redis/asyncio/connection.py | 4 +- redis/asyncio/sentinel.py | 14 +--- redis/client.py | 29 ++++++- redis/sentinel.py | 12 +-- tests/test_asyncio/test_connection.py | 112 ++++++++++++++++++++++++-- tests/test_connection.py | 89 +++++++++++++++++++- tests/test_sentinel.py | 26 ++++++ 9 files changed, 334 insertions(+), 50 deletions(-) diff --git a/docs/examples/asyncio_examples.ipynb b/docs/examples/asyncio_examples.ipynb index f7e67e2ca7..a8ee8faf8b 100644 --- a/docs/examples/asyncio_examples.ipynb +++ b/docs/examples/asyncio_examples.ipynb @@ -44,6 +44,26 @@ "await connection.close()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you create custom `ConnectionPool` for the `Redis` instance to use alone, use the `from_pool` class method to create it. This will cause the pool to be disconnected along with the Redis instance. Disconnecting the connection pool simply disconnects all connections hosted in the pool." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import redis.asyncio as redis\n", + "\n", + "pool = redis.ConnectionPool.from_url(\"redis://localhost\")\n", + "connection = redis.Redis.from_pool(pool)\n", + "await connection.close()" + ] + }, { "cell_type": "markdown", "metadata": { @@ -53,7 +73,8 @@ } }, "source": [ - "If you supply a custom `ConnectionPool` that is supplied to several `Redis` instances, you may want to disconnect the connection pool explicitly. Disconnecting the connection pool simply disconnects all connections hosted in the pool." + "\n", + "However, If you supply a `ConnectionPool` that is shared several `Redis` instances, you may want to disconnect the connection pool explicitly. use the `connection_pool` argument in that case." ] }, { @@ -69,10 +90,12 @@ "source": [ "import redis.asyncio as redis\n", "\n", - "connection = redis.Redis(auto_close_connection_pool=False)\n", - "await connection.close()\n", - "# Or: await connection.close(close_connection_pool=False)\n", - "await connection.connection_pool.disconnect()" + "pool = redis.ConnectionPool.from_url(\"redis://localhost\")\n", + "connection1 = redis.Redis(connection_pool=pool)\n", + "connection2 = redis.Redis(connection_pool=pool)\n", + "await connection1.close()\n", + "await connection2.close()\n", + "await pool.disconnect()" ] }, { diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index f0c1ab7536..cce1b5815d 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -114,7 +114,7 @@ def from_url( cls, url: str, single_connection_client: bool = False, - auto_close_connection_pool: bool = True, + auto_close_connection_pool: Optional[bool] = None, **kwargs, ): """ @@ -160,12 +160,39 @@ class initializer. In the case of conflicting arguments, querystring """ connection_pool = ConnectionPool.from_url(url, **kwargs) - redis = cls( + client = cls( connection_pool=connection_pool, single_connection_client=single_connection_client, ) - redis.auto_close_connection_pool = auto_close_connection_pool - return redis + if auto_close_connection_pool is not None: + warnings.warn( + DeprecationWarning( + '"auto_close_connection_pool" is deprecated ' + "since version 5.0.0. " + "Please create a ConnectionPool explicitly and " + "provide to the Redis() constructor instead." + ) + ) + else: + auto_close_connection_pool = True + client.auto_close_connection_pool = auto_close_connection_pool + return client + + @classmethod + def from_pool( + cls: Type["Redis"], + connection_pool: ConnectionPool, + ) -> "Redis": + """ + Return a Redis client from the given connection pool. + The Redis client will take ownership of the connection pool and + close it when the Redis client is closed. + """ + client = cls( + connection_pool=connection_pool, + ) + client.auto_close_connection_pool = True + return client def __init__( self, @@ -200,7 +227,8 @@ def __init__( lib_version: Optional[str] = get_lib_version(), username: Optional[str] = None, retry: Optional[Retry] = None, - auto_close_connection_pool: bool = True, + # deprecated. create a pool and use connection_pool instead + auto_close_connection_pool: Optional[bool] = None, redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, @@ -213,14 +241,21 @@ def __init__( To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. """ kwargs: Dict[str, Any] - # auto_close_connection_pool only has an effect if connection_pool is - # None. This is a similar feature to the missing __del__ to resolve #1103, - # but it accounts for whether a user wants to manually close the connection - # pool, as a similar feature to ConnectionPool's __del__. - self.auto_close_connection_pool = ( - auto_close_connection_pool if connection_pool is None else False - ) + + if auto_close_connection_pool is not None: + warnings.warn( + DeprecationWarning( + '"auto_close_connection_pool" is deprecated ' + "since version 5.0.0. " + "Please create a ConnectionPool explicitly and " + "provide to the Redis() constructor instead." + ) + ) + else: + auto_close_connection_pool = True + if not connection_pool: + # Create internal connection pool, expected to be closed by Redis instance if not retry_on_error: retry_on_error = [] if retry_on_timeout is True: @@ -277,7 +312,13 @@ def __init__( "ssl_check_hostname": ssl_check_hostname, } ) + # This arg only used if no pool is passed in + self.auto_close_connection_pool = auto_close_connection_pool connection_pool = ConnectionPool(**kwargs) + else: + # If a pool is passed in, do not close it + self.auto_close_connection_pool = False + self.connection_pool = connection_pool self.single_connection_client = single_connection_client self.connection: Optional[Connection] = None diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 00865515ab..4f20e6ec1d 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1106,8 +1106,8 @@ class BlockingConnectionPool(ConnectionPool): """ A blocking connection pool:: - >>> from redis.asyncio.client import Redis - >>> client = Redis(connection_pool=BlockingConnectionPool()) + >>> from redis.asyncio import Redis, BlockingConnectionPool + >>> client = Redis.from_pool(BlockingConnectionPool()) It performs the same function as the default :py:class:`~redis.asyncio.ConnectionPool` implementation, in that, diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index 20c0a1fdd5..6834fb194f 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -340,12 +340,7 @@ def master_for( connection_pool = connection_pool_class(service_name, self, **connection_kwargs) # The Redis object "owns" the pool - auto_close_connection_pool = True - client = redis_class( - connection_pool=connection_pool, - ) - client.auto_close_connection_pool = auto_close_connection_pool - return client + return redis_class.from_pool(connection_pool) def slave_for( self, @@ -377,9 +372,4 @@ def slave_for( connection_pool = connection_pool_class(service_name, self, **connection_kwargs) # The Redis object "owns" the pool - auto_close_connection_pool = True - client = redis_class( - connection_pool=connection_pool, - ) - client.auto_close_connection_pool = auto_close_connection_pool - return client + return redis_class.from_pool(connection_pool) diff --git a/redis/client.py b/redis/client.py index f695cef534..017f8ef11d 100755 --- a/redis/client.py +++ b/redis/client.py @@ -4,7 +4,7 @@ import time import warnings from itertools import chain -from typing import Optional +from typing import Optional, Type from redis._parsers.helpers import ( _RedisCallbacks, @@ -136,10 +136,28 @@ class initializer. In the case of conflicting arguments, querystring """ single_connection_client = kwargs.pop("single_connection_client", False) connection_pool = ConnectionPool.from_url(url, **kwargs) - return cls( + client = cls( connection_pool=connection_pool, single_connection_client=single_connection_client, ) + client.auto_close_connection_pool = True + return client + + @classmethod + def from_pool( + cls: Type["Redis"], + connection_pool: ConnectionPool, + ) -> "Redis": + """ + Return a Redis client from the given connection pool. + The Redis client will take ownership of the connection pool and + close it when the Redis client is closed. + """ + client = cls( + connection_pool=connection_pool, + ) + client.auto_close_connection_pool = True + return client def __init__( self, @@ -275,6 +293,10 @@ def __init__( } ) connection_pool = ConnectionPool(**kwargs) + self.auto_close_connection_pool = True + else: + self.auto_close_connection_pool = False + self.connection_pool = connection_pool self.connection = None if single_connection_client: @@ -477,6 +499,9 @@ def close(self): self.connection = None self.connection_pool.release(conn) + if self.auto_close_connection_pool: + self.connection_pool.disconnect() + def _send_command_parse_response(self, conn, command_name, *args, **options): """ Send a command and parse the response diff --git a/redis/sentinel.py b/redis/sentinel.py index 836e781e7f..41f308d1ee 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -353,10 +353,8 @@ def master_for( kwargs["is_master"] = True connection_kwargs = dict(self.connection_kwargs) connection_kwargs.update(kwargs) - return redis_class( - connection_pool=connection_pool_class( - service_name, self, **connection_kwargs - ) + return redis_class.from_pool( + connection_pool_class(service_name, self, **connection_kwargs) ) def slave_for( @@ -386,8 +384,6 @@ def slave_for( kwargs["is_master"] = False connection_kwargs = dict(self.connection_kwargs) connection_kwargs.update(kwargs) - return redis_class( - connection_pool=connection_pool_class( - service_name, self, **connection_kwargs - ) + return redis_class.from_pool( + connection_pool_class(service_name, self, **connection_kwargs) ) diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 9ee6db1566..a954a40dbf 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -11,7 +11,7 @@ _AsyncRESP3Parser, _AsyncRESPBase, ) -from redis.asyncio import Redis +from redis.asyncio import ConnectionPool, Redis from redis.asyncio.connection import Connection, UnixDomainSocketConnection, parse_url from redis.asyncio.retry import Retry from redis.backoff import NoBackoff @@ -288,7 +288,7 @@ def test_create_single_connection_client_from_url(): assert client.single_connection_client is True -@pytest.mark.parametrize("from_url", (True, False)) +@pytest.mark.parametrize("from_url", (True, False), ids=("from_url", "from_args")) async def test_pool_auto_close(request, from_url): """Verify that basic Redis instances have auto_close_connection_pool set to True""" @@ -305,20 +305,116 @@ async def get_redis_connection(): await r1.close() -@pytest.mark.parametrize("from_url", (True, False)) -async def test_pool_auto_close_disable(request, from_url): - """Verify that auto_close_connection_pool can be disabled""" +async def test_pool_from_url_deprecation(request): + url: str = request.config.getoption("--redis-url") + + with pytest.deprecated_call(): + return Redis.from_url(url, auto_close_connection_pool=False) + + +async def test_pool_auto_close_disable(request): + """Verify that auto_close_connection_pool can be disabled (deprecated)""" url: str = request.config.getoption("--redis-url") url_args = parse_url(url) async def get_redis_connection(): - if from_url: - return Redis.from_url(url, auto_close_connection_pool=False) url_args["auto_close_connection_pool"] = False - return Redis(**url_args) + with pytest.deprecated_call(): + return Redis(**url_args) r1 = await get_redis_connection() assert r1.auto_close_connection_pool is False await r1.connection_pool.disconnect() await r1.close() + + +@pytest.mark.parametrize("from_url", (True, False), ids=("from_url", "from_args")) +async def test_redis_connection_pool(request, from_url): + """Verify that basic Redis instances using `connection_pool` + have auto_close_connection_pool set to False""" + + url: str = request.config.getoption("--redis-url") + url_args = parse_url(url) + + pool = None + + async def get_redis_connection(): + nonlocal pool + if from_url: + pool = ConnectionPool.from_url(url) + else: + pool = ConnectionPool(**url_args) + return Redis(connection_pool=pool) + + called = 0 + + async def mock_disconnect(_): + nonlocal called + called += 1 + + with patch.object(ConnectionPool, "disconnect", mock_disconnect): + async with await get_redis_connection() as r1: + assert r1.auto_close_connection_pool is False + + assert called == 0 + await pool.disconnect() + + +@pytest.mark.parametrize("from_url", (True, False), ids=("from_url", "from_args")) +async def test_redis_from_pool(request, from_url): + """Verify that basic Redis instances created using `from_pool()` + have auto_close_connection_pool set to True""" + + url: str = request.config.getoption("--redis-url") + url_args = parse_url(url) + + pool = None + + async def get_redis_connection(): + nonlocal pool + if from_url: + pool = ConnectionPool.from_url(url) + else: + pool = ConnectionPool(**url_args) + return Redis.from_pool(pool) + + called = 0 + + async def mock_disconnect(_): + nonlocal called + called += 1 + + with patch.object(ConnectionPool, "disconnect", mock_disconnect): + async with await get_redis_connection() as r1: + assert r1.auto_close_connection_pool is True + + assert called == 1 + await pool.disconnect() + + +@pytest.mark.parametrize("auto_close", (True, False)) +async def test_redis_pool_auto_close_arg(request, auto_close): + """test that redis instance where pool is provided have + auto_close_connection_pool set to False, regardless of arg""" + + url: str = request.config.getoption("--redis-url") + pool = ConnectionPool.from_url(url) + + async def get_redis_connection(): + with pytest.deprecated_call(): + client = Redis(connection_pool=pool, auto_close_connection_pool=auto_close) + return client + + called = 0 + + async def mock_disconnect(_): + nonlocal called + called += 1 + + with patch.object(ConnectionPool, "disconnect", mock_disconnect): + async with await get_redis_connection() as r1: + assert r1.auto_close_connection_pool is False + + assert called == 0 + await pool.disconnect() diff --git a/tests/test_connection.py b/tests/test_connection.py index 760b23c9c1..bff249559e 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -5,9 +5,15 @@ import pytest import redis +from redis import ConnectionPool, Redis from redis._parsers import _HiredisParser, _RESP2Parser, _RESP3Parser from redis.backoff import NoBackoff -from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection +from redis.connection import ( + Connection, + SSLConnection, + UnixDomainSocketConnection, + parse_url, +) from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE @@ -209,3 +215,84 @@ def test_create_single_connection_client_from_url(): "redis://localhost:6379/0?", single_connection_client=True ) assert client.connection is not None + + +@pytest.mark.parametrize("from_url", (True, False), ids=("from_url", "from_args")) +def test_pool_auto_close(request, from_url): + """Verify that basic Redis instances have auto_close_connection_pool set to True""" + + url: str = request.config.getoption("--redis-url") + url_args = parse_url(url) + + def get_redis_connection(): + if from_url: + return Redis.from_url(url) + return Redis(**url_args) + + r1 = get_redis_connection() + assert r1.auto_close_connection_pool is True + r1.close() + + +@pytest.mark.parametrize("from_url", (True, False), ids=("from_url", "from_args")) +def test_redis_connection_pool(request, from_url): + """Verify that basic Redis instances using `connection_pool` + have auto_close_connection_pool set to False""" + + url: str = request.config.getoption("--redis-url") + url_args = parse_url(url) + + pool = None + + def get_redis_connection(): + nonlocal pool + if from_url: + pool = ConnectionPool.from_url(url) + else: + pool = ConnectionPool(**url_args) + return Redis(connection_pool=pool) + + called = 0 + + def mock_disconnect(_): + nonlocal called + called += 1 + + with patch.object(ConnectionPool, "disconnect", mock_disconnect): + with get_redis_connection() as r1: + assert r1.auto_close_connection_pool is False + + assert called == 0 + pool.disconnect() + + +@pytest.mark.parametrize("from_url", (True, False), ids=("from_url", "from_args")) +def test_redis_from_pool(request, from_url): + """Verify that basic Redis instances created using `from_pool()` + have auto_close_connection_pool set to True""" + + url: str = request.config.getoption("--redis-url") + url_args = parse_url(url) + + pool = None + + def get_redis_connection(): + nonlocal pool + if from_url: + pool = ConnectionPool.from_url(url) + else: + pool = ConnectionPool(**url_args) + return Redis.from_pool(pool) + + called = 0 + + def mock_disconnect(_): + nonlocal called + called += 1 + + with patch.object(ConnectionPool, "disconnect", mock_disconnect): + with get_redis_connection() as r1: + assert r1.auto_close_connection_pool is True + + assert called == 1 + pool.disconnect() diff --git a/tests/test_sentinel.py b/tests/test_sentinel.py index b7bcc27de2..54b9647098 100644 --- a/tests/test_sentinel.py +++ b/tests/test_sentinel.py @@ -1,4 +1,5 @@ import socket +from unittest import mock import pytest import redis.sentinel @@ -240,3 +241,28 @@ def test_flushconfig(cluster, sentinel): def test_reset(cluster, sentinel): cluster.master["is_odown"] = True assert sentinel.sentinel_reset("mymaster") + + +@pytest.mark.onlynoncluster +@pytest.mark.parametrize("method_name", ["master_for", "slave_for"]) +def test_auto_close_pool(cluster, sentinel, method_name): + """ + Check that the connection pool created by the sentinel client is + automatically closed + """ + + method = getattr(sentinel, method_name) + client = method("mymaster", db=9) + pool = client.connection_pool + assert client.auto_close_connection_pool is True + calls = 0 + + def mock_disconnect(): + nonlocal calls + calls += 1 + + with mock.patch.object(pool, "disconnect", mock_disconnect): + client.close() + + assert calls == 1 + pool.disconnect()