From 2e28576b777b7e8988be8b462fc50ac074b357d6 Mon Sep 17 00:00:00 2001
From: Brian Sam-Bodden
Date: Tue, 28 May 2024 17:54:35 -0700
Subject: [PATCH] Fix vector type fields should not be encoded as strings
 (#2772)

---
 CHANGES                         |  3 ++-
 dev_requirements.txt            |  1 +
 redis/_parsers/encoders.py      |  6 ++++-
 redis/_parsers/hiredis.py       |  5 +++-
 redis/commands/search/result.py | 17 ++++---------
 tests/test_search.py            | 45 +++++++++++++++++++++++++++++++++
 6 files changed, 62 insertions(+), 15 deletions(-)

diff --git a/CHANGES b/CHANGES
index 82c5b6db2a..2025165f65 100644
--- a/CHANGES
+++ b/CHANGES
@@ -1,3 +1,4 @@
+ * Fix #2772, Fix vector type fields should not be encoded as strings
  * Update `ResponseT` type hint
  * Allow to control the minimum SSL version
  * Add an optional lock_name attribute to LockError.
@@ -59,7 +60,7 @@
  * Fix Sentinel.execute_command doesn't execute across the entire sentinel cluster bug (#2458)
  * Added a replacement for the default cluster node in the event of failure (#2463)
  * Fix for Unhandled exception related to self.host with unix socket (#2496)
- * Improve error output for master discovery 
+ * Improve error output for master discovery
  * Make `ClusterCommandsProtocol` an actual Protocol
  * Add `sum` to DUPLICATE_POLICY documentation of `TS.CREATE`, `TS.ADD` and `TS.ALTER`
  * Prevent async ClusterPipeline instances from becoming "false-y" in case of empty command stack (#3061)
diff --git a/dev_requirements.txt b/dev_requirements.txt
index 48ec278d83..2357115934 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -16,3 +16,4 @@ ujson>=4.2.0
 wheel>=0.30.0
 urllib3<2
 uvloop
+numpy>=1.24.4
diff --git a/redis/_parsers/encoders.py b/redis/_parsers/encoders.py
index 6fdf0ad882..b77ba7feff 100644
--- a/redis/_parsers/encoders.py
+++ b/redis/_parsers/encoders.py
@@ -40,5 +40,9 @@ def decode(self, value, force=False):
             if isinstance(value, memoryview):
                 value = value.tobytes()
             if isinstance(value, bytes):
-                value = value.decode(self.encoding, self.encoding_errors)
+                try:
+                    value = value.decode(self.encoding, self.encoding_errors)
+                except UnicodeDecodeError:
+                    # Return the bytes unmodified
+                    return value
         return value
diff --git a/redis/_parsers/hiredis.py b/redis/_parsers/hiredis.py
index a52dbbd013..3b5a847427 100644
--- a/redis/_parsers/hiredis.py
+++ b/redis/_parsers/hiredis.py
@@ -123,7 +123,10 @@ def read_response(self, disable_decoding=False):
         if disable_decoding:
             response = self._reader.gets(False)
         else:
-            response = self._reader.gets()
+            try:
+                response = self._reader.gets()
+            except UnicodeDecodeError:
+                response = self._reader.gets(False)
         # if the response is a ConnectionError or the response is a list and
         # the first item is a ConnectionError, raise it as something bad
         # happened
diff --git a/redis/commands/search/result.py b/redis/commands/search/result.py
index 5b19e6faa4..29d63e86f7 100644
--- a/redis/commands/search/result.py
+++ b/redis/commands/search/result.py
@@ -39,18 +39,11 @@ def __init__(
 
             fields = {}
             if hascontent and res[i + fields_offset] is not None:
-                fields = (
-                    dict(
-                        dict(
-                            zip(
-                                map(to_string, res[i + fields_offset][::2]),
-                                map(to_string, res[i + fields_offset][1::2]),
-                            )
-                        )
-                    )
-                    if hascontent
-                    else {}
-                )
+                for j in range(0, len(res[i + fields_offset]), 2):
+                    key = to_string(res[i + fields_offset][j])
+                    value = res[i + fields_offset][j + 1]
+                    fields[key] = value
+
             try:
                 del fields["id"]
             except KeyError:
diff --git a/tests/test_search.py b/tests/test_search.py
index f19c193891..bde10b4eef 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -4,6 +4,7 @@
 import time
 from io import TextIOWrapper
 
+import numpy as np
 import pytest
 import redis
 import redis.commands.search
@@ -2282,3 +2283,47 @@ def test_geoshape(client: redis.Redis):
     assert result.docs[0]["id"] == "small"
     result = client.ft().search(q2, query_params=qp2)
     assert len(result.docs) == 2
+
+
+@pytest.mark.redismod
+def test_vector_storage_and_retrieval(client):
+    # Constants
+    INDEX_NAME = "vector_index"
+    DOC_PREFIX = "doc:"
+    VECTOR_DIMENSIONS = 4
+    VECTOR_FIELD_NAME = "my_vector"
+
+    # Create index
+    client.ft(INDEX_NAME).create_index(
+        (
+            VectorField(
+                VECTOR_FIELD_NAME,
+                "FLAT",
+                {
+                    "TYPE": "FLOAT32",
+                    "DIM": VECTOR_DIMENSIONS,
+                    "DISTANCE_METRIC": "COSINE",
+                },
+            ),
+        ),
+        definition=IndexDefinition(prefix=[DOC_PREFIX], index_type=IndexType.HASH),
+    )
+
+    # Add a document with a vector value
+    vector_data = [0.1, 0.2, 0.3, 0.4]
+    client.hset(
+        f"{DOC_PREFIX}1",
+        mapping={VECTOR_FIELD_NAME: np.array(vector_data, dtype=np.float32).tobytes()},
+    )
+
+    # Perform a search to retrieve the document
+    query = Query("*").return_fields(VECTOR_FIELD_NAME).dialect(2)
+    res = client.ft(INDEX_NAME).search(query)
+
+    # Assert that the document is retrieved and the vector matches the original data
+    assert res.total == 1
+    assert res.docs[0].id == f"{DOC_PREFIX}1"
+    retrieved_vector_data = np.frombuffer(
+        res.docs[0].__dict__[VECTOR_FIELD_NAME], dtype=np.float32
+    )
+    assert np.allclose(retrieved_vector_data, vector_data)
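
Illustrative note, not part of the patch: the behaviour change is easiest to see on the
Encoder class that the encoders.py hunk modifies. The sketch below exercises the internal
redis._parsers.encoders.Encoder directly (an internal API, so it may differ between
redis-py versions) and assumes numpy is installed. With decode_responses=True, valid
UTF-8 replies still decode to str, while a binary vector payload is now returned as raw
bytes instead of raising UnicodeDecodeError under the default "strict" error handling.

    import numpy as np
    from redis._parsers.encoders import Encoder

    enc = Encoder(encoding="utf-8", encoding_errors="strict", decode_responses=True)

    # Text replies behave as before: bytes decode to str.
    assert enc.decode(b"hello") == "hello"

    # A FLOAT32 vector blob is not valid UTF-8; with this patch the raw bytes
    # are returned unmodified rather than raising UnicodeDecodeError.
    blob = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32).tobytes()
    assert enc.decode(blob) == blob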