Fix bug: client side caching causes unexpected disconnections (async version) #3165

Merged · 10 commits · Feb 29, 2024
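In short: with RESP3 client-side caching enabled, Redis pushes invalidation messages onto the client's socket whenever a tracked key changes. The async connection pool's health checks treated any readable bytes as evidence of a stale or broken connection, so a caching client was torn down and reconnected after ordinary writes. The hunks below teach those checks to tell a closed socket apart from pending push data, capture the reply that follows a push message, and restructure `execute_command` so pooled connections are always released.

For orientation, this is the shape of the invalidation frame involved (standard RESP3 behavior, not code from this PR); bytes like these sitting unread in the buffer are what tripped the old checks:

```python
# RESP3 push frame (">") with two elements: the literal "invalidate"
# and the array of invalidated keys -- here, key "a".
raw = b">2\r\n$10\r\ninvalidate\r\n*1\r\n$1\r\na\r\n"
decoded = ["invalidate", ["a"]]  # what the parser hands to the push handler
```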
redis/_parsers/base.py (1 addition, 1 deletion)

@@ -182,7 +182,7 @@ async def can_read_destructive(self) -> bool:
             return True
         try:
             async with async_timeout(0):
-                return await self._stream.read(1)
+                return self._stream.at_eof()
         except TimeoutError:
             return False

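Worth noting: `asyncio.StreamReader.at_eof()` is True only once the peer has closed the stream and the internal buffer has been drained, so buffered-but-unparsed bytes (such as invalidation pushes) leave it False. A standalone sketch of those semantics, in plain asyncio and independent of redis-py:

```python
import asyncio

async def main():
    reader = asyncio.StreamReader()
    assert not reader.at_eof()                  # open stream, empty buffer
    reader.feed_data(b"+PONG\r\n")
    assert not reader.at_eof()                  # open stream, data buffered
    assert await reader.read(7) == b"+PONG\r\n"
    reader.feed_eof()                           # peer closed the connection
    assert reader.at_eof()                      # closed and fully drained

asyncio.run(main())
```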
redis/_parsers/resp3.py (3 additions, 1 deletion)

@@ -261,7 +261,9 @@ async def _read_response(
                 )
                 for _ in range(int(response))
             ]
-            await self.handle_push_response(response, disable_decoding, push_request)
+            response = await self.handle_push_response(
+                response, disable_decoding, push_request
+            )
         else:
             raise InvalidResponse(f"Protocol Error: {raw!r}")

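The fix here is simply to keep the return value: `handle_push_response` consumes pending push messages and hands back the next real reply, which the old code silently dropped. A simplified, hypothetical sketch of that pattern (`is_push` and `read_next` are stand-ins, not redis-py APIs):

```python
import asyncio

def is_push(resp):                    # hypothetical push-frame predicate
    return isinstance(resp, list) and resp[:1] == ["invalidate"]

async def read_next():                # stand-in for another parser read
    return "OK"                       # the command's actual reply

async def handle_push_response(response):
    while is_push(response):          # drain invalidation pushes...
        response = await read_next()  # ...until a real reply arrives
    return response                   # callers must keep this value

assert asyncio.run(handle_push_response(["invalidate", ["a"]])) == "OK"
```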
redis/asyncio/client.py (21 additions, 19 deletions)

@@ -629,25 +629,27 @@ async def execute_command(self, *args, **options):
         pool = self.connection_pool
         conn = self.connection or await pool.get_connection(command_name, **options)
         response_from_cache = await conn._get_from_local_cache(args)
-        if response_from_cache is not None:
-            return response_from_cache
-        else:
-            if self.single_connection_client:
-                await self._single_conn_lock.acquire()
-            try:
-                response = await conn.retry.call_with_retry(
-                    lambda: self._send_command_parse_response(
-                        conn, command_name, *args, **options
-                    ),
-                    lambda error: self._disconnect_raise(conn, error),
-                )
-                conn._add_to_local_cache(args, response, keys)
-                return response
-            finally:
-                if self.single_connection_client:
-                    self._single_conn_lock.release()
-                if not self.connection:
-                    await pool.release(conn)
+        try:
+            if response_from_cache is not None:
+                return response_from_cache
+            else:
+                try:
+                    if self.single_connection_client:
+                        await self._single_conn_lock.acquire()
+                    response = await conn.retry.call_with_retry(
+                        lambda: self._send_command_parse_response(
+                            conn, command_name, *args, **options
+                        ),
+                        lambda error: self._disconnect_raise(conn, error),
+                    )
+                    conn._add_to_local_cache(args, response, keys)
+                    return response
+                finally:
+                    if self.single_connection_client:
+                        self._single_conn_lock.release()
+        finally:
+            if not self.connection:
+                await pool.release(conn)

     async def parse_response(
         self, connection: Connection, command_name: Union[str, bytes], **options
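The restructuring leans on a plain language guarantee: a `return` inside `try` still runs every enclosing `finally`. The cache-hit early return previously bypassed the release bookkeeping entirely; with the outer `try`/`finally`, the pooled connection is released on every path. A minimal illustration of that control flow:

```python
def execute(cache_hit: bool) -> str:
    conn = "conn"                     # stand-in for pool.get_connection()
    try:
        if cache_hit:
            return "cached-response"  # early return: finally still runs
        return "server-response"
    finally:
        print(f"release {conn}")      # stand-in for pool.release(conn)

execute(cache_hit=True)               # prints "release conn" either way
```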
redis/asyncio/connection.py (10 additions, 4 deletions)

@@ -685,7 +685,7 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]
 
     def _socket_is_empty(self):
         """Check if the socket is empty"""
-        return not self._reader.at_eof()
+        return len(self._reader._buffer) == 0

     def _cache_invalidation_process(
         self, data: List[Union[str, Optional[List[str]]]]
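`_buffer` here is a private `bytearray` on `asyncio.StreamReader` (a CPython implementation detail rather than public API), so the check now asks "are there bytes waiting to be parsed?" instead of "has the peer closed the stream?". A standalone sketch of the difference:

```python
import asyncio

async def main():
    reader = asyncio.StreamReader()
    assert len(reader._buffer) == 0   # nothing pending: socket is "empty"
    reader.feed_data(b">2\r\n")       # start of a RESP3 push frame
    assert len(reader._buffer) > 0    # pending bytes: not empty
    assert not reader.at_eof()        # yet at_eof() is False in both cases

asyncio.run(main())
```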
@@ -1192,12 +1192,18 @@ def make_connection(self):
     async def ensure_connection(self, connection: AbstractConnection):
         """Ensure that the connection object is connected and valid"""
         await connection.connect()
-        # connections that the pool provides should be ready to send
-        # a command. if not, the connection was either returned to the
+        # if client caching is not enabled connections that the pool
+        # provides should be ready to send a command.
+        # if not, the connection was either returned to the
         # pool before all data has been read or the socket has been
         # closed. either way, reconnect and verify everything is good.
+        # (if caching enabled the connection will not always be ready
+        # to send a command because it may contain invalidation messages)
         try:
-            if await connection.can_read_destructive():
+            if (
+                await connection.can_read_destructive()
+                and connection.client_cache is None
+            ):
                 raise ConnectionError("Connection has data") from None
         except (ConnectionError, OSError):
             await connection.disconnect()
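Condensed, the new health check treats readable bytes as a fault only when caching is off; with a cache attached, pending data is assumed to be invalidation pushes that the parser will drain. A hypothetical standalone restatement with a stubbed connection:

```python
import asyncio
from types import SimpleNamespace

async def healthy(connection) -> bool:
    # Readable data on an idle pooled connection normally means stale
    # reply bytes (a broken connection) -- unless client caching is on,
    # in which case it may just be invalidation push messages.
    has_data = await connection.can_read_destructive()
    return not (has_data and connection.client_cache is None)

def fake_conn(has_data, cache):       # hypothetical stand-in connection
    async def can_read_destructive():
        return has_data
    return SimpleNamespace(can_read_destructive=can_read_destructive,
                           client_cache=cache)

assert asyncio.run(healthy(fake_conn(True, None))) is False    # reconnect
assert asyncio.run(healthy(fake_conn(True, object()))) is True  # keep it
```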
tests/test_asyncio/test_cache.py (42 additions, 0 deletions)

@@ -142,6 +142,48 @@ async def test_cache_return_copy(self, r):
         check = cache.get(("LRANGE", "mylist", 0, -1))
         assert check == [b"baz", b"bar", b"foo"]
 
+    @pytest.mark.onlynoncluster
+    @pytest.mark.parametrize(
+        "r",
+        [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
+        indirect=True,
+    )
+    async def test_csc_not_cause_disconnects(self, r):
+        r, cache = r
+        id1 = await r.client_id()
+        await r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1})
+        assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"]
+        id2 = await r.client_id()
+
+        # client should get value from client cache
+        assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"]
+        assert cache.get(("MGET", "a", "b", "c", "d", "e")) == [
+            "1",
+            "1",
+            "1",
+            "1",
+            "1",
+        ]
+
+        await r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2})
+        id3 = await r.client_id()
+        # client should get value from redis server post invalidate messages
+        assert await r.mget("a", "b", "c", "d", "e") == ["2", "2", "2", "2", "2"]
+
+        await r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3})
+        # need to check that we get correct value 3 and not 2
+        assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"]
+        # client should get value from client cache
+        assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"]
+
+        await r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4})
+        # need to check that we get correct value 4 and not 3
+        assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"]
+        # client should get value from client cache
+        assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"]
+        id4 = await r.client_id()
+        assert id1 == id2 == id3 == id4


 @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
 @pytest.mark.onlycluster
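The repeated `client_id()` calls are the heart of the test: CLIENT ID is stable for the lifetime of a connection, so equal ids across cache hits, cache misses, and invalidation rounds prove the client never reconnected. The parametrized `r` fixture is defined elsewhere in this module and yields a `(client, cache)` pair; presumably it builds something along these lines (the constructor keyword and the `_LocalCache` import path are assumptions, not verified against this revision):

```python
import redis.asyncio as redis
from redis._cache import _LocalCache  # import path is an assumption

cache = _LocalCache()
# RESP3 is required for server-assisted client-side caching pushes, and
# decode_responses=True matches the string (not bytes) assertions above.
r = redis.Redis(protocol=3, client_cache=cache, decode_responses=True)
```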