From 0be67bfdf609943e6ce8ecdb026c3234753bbfec Mon Sep 17 00:00:00 2001 From: Gabriel Erzse Date: Thu, 4 Jul 2024 18:00:48 +0300 Subject: [PATCH] Format connection errors in the same way everywhere (#3305) Connection errors are formatted in four places, sync and async, network socket and unix socket. Each place has some small differences compared to the others, while they could be, and should be, formatted in an uniform way. Factor out the logic in a helper method and call that method in all four places. Arguably we lose some specificity, e.g. the words "unix socket" won't be there anymore, but it is more valuable to not have code duplication. --- redis/asyncio/connection.py | 40 ++------------------- redis/connection.py | 39 ++------------------ redis/utils.py | 12 +++++++ tests/test_asyncio/test_connection.py | 51 +++++++++++++++++++++++---- tests/test_connection.py | 50 ++++++++++++++++++++++++++ 5 files changed, 111 insertions(+), 81 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index b3dae2a27b..ec1ce5a915 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -27,6 +27,8 @@ ) from urllib.parse import ParseResult, parse_qs, unquote, urlparse +from ..utils import format_error_message + # the functionality is available in 3.11.x but has a major issue before # 3.11.3. See https://github.com/redis/redis-py/issues/2633 if sys.version_info >= (3, 11, 3): @@ -345,9 +347,8 @@ async def _connect(self): def _host_error(self) -> str: pass - @abstractmethod def _error_message(self, exception: BaseException) -> str: - pass + return format_error_message(self._host_error(), exception) async def on_connect(self) -> None: """Initialize the connection, authenticate and select a database""" @@ -799,27 +800,6 @@ async def _connect(self): def _host_error(self) -> str: return f"{self.host}:{self.port}" - def _error_message(self, exception: BaseException) -> str: - # args for socket.error can either be (errno, "message") - # or just "message" - - host_error = self._host_error() - - if not exception.args: - # asyncio has a bug where on Connection reset by peer, the - # exception is not instanciated, so args is empty. This is the - # workaround. - # See: https://github.com/redis/redis-py/issues/2237 - # See: https://github.com/python/cpython/issues/94061 - return f"Error connecting to {host_error}. Connection reset by peer" - elif len(exception.args) == 1: - return f"Error connecting to {host_error}. {exception.args[0]}." - else: - return ( - f"Error {exception.args[0]} connecting to {host_error}. " - f"{exception}." - ) - class SSLConnection(Connection): """Manages SSL connections to and from the Redis server(s). @@ -971,20 +951,6 @@ async def _connect(self): def _host_error(self) -> str: return self.path - def _error_message(self, exception: BaseException) -> str: - # args for socket.error can either be (errno, "message") - # or just "message" - host_error = self._host_error() - if len(exception.args) == 1: - return ( - f"Error connecting to unix socket: {host_error}. {exception.args[0]}." - ) - else: - return ( - f"Error {exception.args[0]} connecting to unix socket: " - f"{host_error}. {exception.args[1]}." - ) - FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") diff --git a/redis/connection.py b/redis/connection.py index 728c221257..6e3b3ab081 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -39,6 +39,7 @@ HIREDIS_AVAILABLE, HIREDIS_PACK_AVAILABLE, SSL_AVAILABLE, + format_error_message, get_lib_version, str_if_bytes, ) @@ -338,9 +339,8 @@ def _connect(self): def _host_error(self): pass - @abstractmethod def _error_message(self, exception): - pass + return format_error_message(self._host_error(), exception) def on_connect(self): "Initialize the connection, authenticate and select a database" @@ -733,27 +733,6 @@ def _connect(self): def _host_error(self): return f"{self.host}:{self.port}" - def _error_message(self, exception): - # args for socket.error can either be (errno, "message") - # or just "message" - - host_error = self._host_error() - - if len(exception.args) == 1: - try: - return f"Error connecting to {host_error}. \ - {exception.args[0]}." - except AttributeError: - return f"Connection Error: {exception.args[0]}" - else: - try: - return ( - f"Error {exception.args[0]} connecting to " - f"{host_error}. {exception.args[1]}." - ) - except AttributeError: - return f"Connection Error: {exception.args[0]}" - class SSLConnection(Connection): """Manages SSL connections to and from the Redis server(s). @@ -930,20 +909,6 @@ def _connect(self): def _host_error(self): return self.path - def _error_message(self, exception): - # args for socket.error can either be (errno, "message") - # or just "message" - host_error = self._host_error() - if len(exception.args) == 1: - return ( - f"Error connecting to unix socket: {host_error}. {exception.args[0]}." - ) - else: - return ( - f"Error {exception.args[0]} connecting to unix socket: " - f"{host_error}. {exception.args[1]}." - ) - FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") diff --git a/redis/utils.py b/redis/utils.py index ea2eac149e..360ee54b8c 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -141,3 +141,15 @@ def get_lib_version(): except metadata.PackageNotFoundError: libver = "99.99.99" return libver + + +def format_error_message(host_error: str, exception: BaseException) -> str: + if not exception.args: + return f"Error connecting to {host_error}." + elif len(exception.args) == 1: + return f"Error {exception.args[0]} connecting to {host_error}." + else: + return ( + f"Error {exception.args[0]} connecting to {host_error}. " + f"{exception.args[1]}." + ) diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 6255ae7d6d..8f79f7d947 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -12,7 +12,12 @@ _AsyncRESPBase, ) from redis.asyncio import ConnectionPool, Redis -from redis.asyncio.connection import Connection, UnixDomainSocketConnection, parse_url +from redis.asyncio.connection import ( + Connection, + SSLConnection, + UnixDomainSocketConnection, + parse_url, +) from redis.asyncio.retry import Retry from redis.backoff import NoBackoff from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError @@ -494,18 +499,50 @@ async def test_connection_garbage_collection(request): @pytest.mark.parametrize( - "error, expected_message", + "conn, error, expected_message", [ - (OSError(), "Error connecting to localhost:6379. Connection reset by peer"), - (OSError(12), "Error connecting to localhost:6379. 12."), + (SSLConnection(), OSError(), "Error connecting to localhost:6379."), + (SSLConnection(), OSError(12), "Error 12 connecting to localhost:6379."), ( + SSLConnection(), OSError(12, "Some Error"), - "Error 12 connecting to localhost:6379. [Errno 12] Some Error.", + "Error 12 connecting to localhost:6379. Some Error.", + ), + ( + UnixDomainSocketConnection(path="unix:///tmp/redis.sock"), + OSError(), + "Error connecting to unix:///tmp/redis.sock.", + ), + ( + UnixDomainSocketConnection(path="unix:///tmp/redis.sock"), + OSError(12), + "Error 12 connecting to unix:///tmp/redis.sock.", + ), + ( + UnixDomainSocketConnection(path="unix:///tmp/redis.sock"), + OSError(12, "Some Error"), + "Error 12 connecting to unix:///tmp/redis.sock. Some Error.", ), ], ) -async def test_connect_error_message(error, expected_message): +async def test_format_error_message(conn, error, expected_message): """Test that the _error_message function formats errors correctly""" - conn = Connection() error_message = conn._error_message(error) assert error_message == expected_message + + +async def test_network_connection_failure(): + with pytest.raises(ConnectionError) as e: + redis = Redis(host="127.0.0.1", port=9999) + await redis.set("a", "b") + assert str(e.value).startswith("Error 111 connecting to 127.0.0.1:9999. Connect") + + +async def test_unix_socket_connection_failure(): + with pytest.raises(ConnectionError) as e: + redis = Redis(unix_socket_path="unix:///tmp/a.sock") + await redis.set("a", "b") + assert ( + str(e.value) + == "Error 2 connecting to unix:///tmp/a.sock. No such file or directory." + ) diff --git a/tests/test_connection.py b/tests/test_connection.py index bff249559e..69275d58c0 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -296,3 +296,53 @@ def mock_disconnect(_): assert called == 1 pool.disconnect() + + +@pytest.mark.parametrize( + "conn, error, expected_message", + [ + (SSLConnection(), OSError(), "Error connecting to localhost:6379."), + (SSLConnection(), OSError(12), "Error 12 connecting to localhost:6379."), + ( + SSLConnection(), + OSError(12, "Some Error"), + "Error 12 connecting to localhost:6379. Some Error.", + ), + ( + UnixDomainSocketConnection(path="unix:///tmp/redis.sock"), + OSError(), + "Error connecting to unix:///tmp/redis.sock.", + ), + ( + UnixDomainSocketConnection(path="unix:///tmp/redis.sock"), + OSError(12), + "Error 12 connecting to unix:///tmp/redis.sock.", + ), + ( + UnixDomainSocketConnection(path="unix:///tmp/redis.sock"), + OSError(12, "Some Error"), + "Error 12 connecting to unix:///tmp/redis.sock. Some Error.", + ), + ], +) +def test_format_error_message(conn, error, expected_message): + """Test that the _error_message function formats errors correctly""" + error_message = conn._error_message(error) + assert error_message == expected_message + + +def test_network_connection_failure(): + with pytest.raises(ConnectionError) as e: + redis = Redis(port=9999) + redis.set("a", "b") + assert str(e.value) == "Error 111 connecting to localhost:9999. Connection refused." + + +def test_unix_socket_connection_failure(): + with pytest.raises(ConnectionError) as e: + redis = Redis(unix_socket_path="unix:///tmp/a.sock") + redis.set("a", "b") + assert ( + str(e.value) + == "Error 2 connecting to unix:///tmp/a.sock. No such file or directory." + )