Skip to content

Commit

Permalink
Fix bug: client side caching causes unexpected disconnections (#3160)
Browse files Browse the repository at this point in the history
* fix disconnects

* skip test in cluster

---------

Co-authored-by: Chayim <[email protected]>
  • Loading branch information
2 people authored and vladvildanov committed Sep 27, 2024
1 parent c32ac68 commit 075fba7
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 19 deletions.
4 changes: 3 additions & 1 deletion redis/_parsers/resp3.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ def _read_response(self, disable_decoding=False, push_request=False):
)
for _ in range(int(response))
]
self.handle_push_response(response, disable_decoding, push_request)
response = self.handle_push_response(
response, disable_decoding, push_request
)
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")

Expand Down
14 changes: 7 additions & 7 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,10 +560,10 @@ def execute_command(self, *args, **options):
pool = self.connection_pool
conn = self.connection or pool.get_connection(command_name, **options)
response_from_cache = conn._get_from_local_cache(args)
if response_from_cache is not None:
return response_from_cache
else:
try:
try:
if response_from_cache is not None:
return response_from_cache
else:
response = conn.retry.call_with_retry(
lambda: self._send_command_parse_response(
conn, command_name, *args, **options
Expand All @@ -572,9 +572,9 @@ def execute_command(self, *args, **options):
)
conn._add_to_local_cache(args, response, keys)
return response
finally:
if not self.connection:
pool.release(conn)
finally:
if not self.connection:
pool.release(conn)

def parse_response(self, connection, command_name, **options):
"""Parses a response from the Redis server"""
Expand Down
2 changes: 1 addition & 1 deletion redis/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2011,7 +2011,7 @@ def mget(self, keys: KeysT, *args: EncodableT) -> ResponseT:
options = {}
if not args:
options[EMPTY_RESPONSE] = []
options["keys"] = keys
options["keys"] = args
return self.execute_command("MGET", *args, **options)

def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT:
Expand Down
17 changes: 7 additions & 10 deletions redis/connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import copy
import os
import select
import socket
import ssl
import sys
Expand Down Expand Up @@ -609,11 +608,6 @@ def pack_commands(self, commands):
output.append(SYM_EMPTY.join(pieces))
return output

def _socket_is_empty(self):
"""Check if the socket is empty"""
r, _, _ = select.select([self._sock], [], [], 0)
return not bool(r)

def _cache_invalidation_process(
self, data: List[Union[str, Optional[List[str]]]]
) -> None:
Expand All @@ -639,7 +633,7 @@ def _get_from_local_cache(self, command: str):
or command[0] not in self.cache_whitelist
):
return None
while not self._socket_is_empty():
while self.can_read():
self.read_response(push_request=True)
return self.client_cache.get(command)

Expand Down Expand Up @@ -1187,12 +1181,15 @@ def get_connection(self, command_name: str, *keys, **options) -> "Connection":
try:
# ensure this connection is connected to Redis
connection.connect()
# connections that the pool provides should be ready to send
# a command. if not, the connection was either returned to the
# if client caching is not enabled connections that the pool
# provides should be ready to send a command.
# if not, the connection was either returned to the
# pool before all data has been read or the socket has been
# closed. either way, reconnect and verify everything is good.
# (if caching enabled the connection will not always be ready
# to send a command because it may contain invalidation messages)
try:
if connection.can_read():
if connection.can_read() and connection.client_cache is None:
raise ConnectionError("Connection has data")
except (ConnectionError, OSError):
connection.disconnect()
Expand Down
43 changes: 43 additions & 0 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,49 @@ def test_cache_return_copy(self, r):
check = cache.get(("LRANGE", "mylist", 0, -1))
assert check == [b"baz", b"bar", b"foo"]

@pytest.mark.onlynoncluster
@pytest.mark.parametrize(
"r",
[{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
indirect=True,
)
def test_csc_not_cause_disconnects(self, r):
r, cache = r
id1 = r.client_id()
r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1, "f": 1})
assert r.mget("a", "b", "c", "d", "e", "f") == ["1", "1", "1", "1", "1", "1"]
id2 = r.client_id()

# client should get value from client cache
assert r.mget("a", "b", "c", "d", "e", "f") == ["1", "1", "1", "1", "1", "1"]
assert cache.get(("MGET", "a", "b", "c", "d", "e", "f")) == [
"1",
"1",
"1",
"1",
"1",
"1",
]

r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2, "f": 2})
id3 = r.client_id()
# client should get value from redis server post invalidate messages
assert r.mget("a", "b", "c", "d", "e", "f") == ["2", "2", "2", "2", "2", "2"]

r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3, "f": 3})
# need to check that we get correct value 3 and not 2
assert r.mget("a", "b", "c", "d", "e", "f") == ["3", "3", "3", "3", "3", "3"]
# client should get value from client cache
assert r.mget("a", "b", "c", "d", "e", "f") == ["3", "3", "3", "3", "3", "3"]

r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4, "f": 4})
# need to check that we get correct value 4 and not 3
assert r.mget("a", "b", "c", "d", "e", "f") == ["4", "4", "4", "4", "4", "4"]
# client should get value from client cache
assert r.mget("a", "b", "c", "d", "e", "f") == ["4", "4", "4", "4", "4", "4"]
id4 = r.client_id()
assert id1 == id2 == id3 == id4


@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
@pytest.mark.onlycluster
Expand Down

0 comments on commit 075fba7

Please sign in to comment.