diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index 8e59249bef..0137539d66 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -182,7 +182,7 @@ async def can_read_destructive(self) -> bool: return True try: async with async_timeout(0): - return await self._stream.read(1) + return self._stream.at_eof() except TimeoutError: return False diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index 88c8d5e52b..7afa43a0c2 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -261,7 +261,9 @@ async def _read_response( ) for _ in range(int(response)) ] - await self.handle_push_response(response, disable_decoding, push_request) + response = await self.handle_push_response( + response, disable_decoding, push_request + ) else: raise InvalidResponse(f"Protocol Error: {raw!r}") diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index a8942c9160..e5c99644a8 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -625,25 +625,27 @@ async def execute_command(self, *args, **options): pool = self.connection_pool conn = self.connection or await pool.get_connection(command_name, **options) response_from_cache = await conn._get_from_local_cache(args) - if response_from_cache is not None: - return response_from_cache - else: - if self.single_connection_client: - await self._single_conn_lock.acquire() - try: - response = await conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), - lambda error: self._disconnect_raise(conn, error), - ) - conn._add_to_local_cache(args, response, keys) - return response - finally: - if self.single_connection_client: - self._single_conn_lock.release() - if not self.connection: - await pool.release(conn) + try: + if response_from_cache is not None: + return response_from_cache + else: + try: + if self.single_connection_client: + await self._single_conn_lock.acquire() + response = await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) + conn._add_to_local_cache(args, response, keys) + return response + finally: + if self.single_connection_client: + self._single_conn_lock.release() + finally: + if not self.connection: + await pool.release(conn) async def parse_response( self, connection: Connection, command_name: Union[str, bytes], **options diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 266e428ebb..31ce7db4df 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -684,7 +684,7 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes] def _socket_is_empty(self): """Check if the socket is empty""" - return not self._reader.at_eof() + return len(self._reader._buffer) == 0 def _cache_invalidation_process( self, data: List[Union[str, Optional[List[str]]]] @@ -1191,12 +1191,18 @@ def make_connection(self): async def ensure_connection(self, connection: AbstractConnection): """Ensure that the connection object is connected and valid""" await connection.connect() - # connections that the pool provides should be ready to send - # a command. if not, the connection was either returned to the + # if client caching is not enabled connections that the pool + # provides should be ready to send a command. + # if not, the connection was either returned to the # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. + # (if caching enabled the connection will not always be ready + # to send a command because it may contain invalidation messages) try: - if await connection.can_read_destructive(): + if ( + await connection.can_read_destructive() + and connection.client_cache is None + ): raise ConnectionError("Connection has data") from None except (ConnectionError, OSError): await connection.disconnect() diff --git a/tests/test_asyncio/test_cache.py b/tests/test_asyncio/test_cache.py index bf20337dfb..4762bb7c05 100644 --- a/tests/test_asyncio/test_cache.py +++ b/tests/test_asyncio/test_cache.py @@ -142,6 +142,48 @@ async def test_cache_return_copy(self, r): check = cache.get(("LRANGE", "mylist", 0, -1)) assert check == [b"baz", b"bar", b"foo"] + @pytest.mark.onlynoncluster + @pytest.mark.parametrize( + "r", + [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], + indirect=True, + ) + async def test_csc_not_cause_disconnects(self, r): + r, cache = r + id1 = await r.client_id() + await r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1}) + assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"] + id2 = await r.client_id() + + # client should get value from client cache + assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"] + assert cache.get(("MGET", "a", "b", "c", "d", "e")) == [ + "1", + "1", + "1", + "1", + "1", + ] + + await r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2}) + id3 = await r.client_id() + # client should get value from redis server post invalidate messages + assert await r.mget("a", "b", "c", "d", "e") == ["2", "2", "2", "2", "2"] + + await r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3}) + # need to check that we get correct value 3 and not 2 + assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"] + # client should get value from client cache + assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"] + + await r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4}) + # need to check that we get correct value 4 and not 3 + assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"] + # client should get value from client cache + assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"] + id4 = await r.client_id() + assert id1 == id2 == id3 == id4 + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlycluster