Skip to content

Commit

Permalink
Fix vector type fields should not be encoded as strings (redis#2772)
Browse files Browse the repository at this point in the history
  • Loading branch information
bsbodden committed May 29, 2024
1 parent 0d47d65 commit 48609a8
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 13 deletions.
1 change: 1 addition & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ ujson>=4.2.0
wheel>=0.30.0
urllib3<2
uvloop
numpy>=1.24.4
6 changes: 5 additions & 1 deletion redis/_parsers/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,9 @@ def decode(self, value, force=False):
if isinstance(value, memoryview):
value = value.tobytes()
if isinstance(value, bytes):
value = value.decode(self.encoding, self.encoding_errors)
try:
value = value.decode(self.encoding, self.encoding_errors)
except UnicodeDecodeError as e:
# Return the bytes unmodified
return value
return value
17 changes: 5 additions & 12 deletions redis/commands/search/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,11 @@ def __init__(

fields = {}
if hascontent and res[i + fields_offset] is not None:
fields = (
dict(
dict(
zip(
map(to_string, res[i + fields_offset][::2]),
map(to_string, res[i + fields_offset][1::2]),
)
)
)
if hascontent
else {}
)
for j in range(0, len(res[i + fields_offset]), 2):
key = to_string(res[i + fields_offset][j])
value = res[i + fields_offset][j + 1]
fields[key] = value

try:
del fields["id"]
except KeyError:
Expand Down
39 changes: 39 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import time
from io import TextIOWrapper

import numpy as np

import pytest
import redis
import redis.commands.search
Expand Down Expand Up @@ -2282,3 +2284,40 @@ def test_geoshape(client: redis.Redis):
assert result.docs[0]["id"] == "small"
result = client.ft().search(q2, query_params=qp2)
assert len(result.docs) == 2

@pytest.mark.redismod
def test_vector_storage_and_retrieval(client):
# Constants
INDEX_NAME = "vector_index"
DOC_PREFIX = "doc:"
VECTOR_DIMENSIONS = 4
VECTOR_FIELD_NAME = "my_vector"

# Create index
client.ft(INDEX_NAME).create_index((
VectorField(VECTOR_FIELD_NAME,
"FLAT", {
"TYPE": "FLOAT32",
"DIM": VECTOR_DIMENSIONS,
"DISTANCE_METRIC": "COSINE",
}
),
), definition=IndexDefinition(prefix=[DOC_PREFIX], index_type=IndexType.HASH))

# Add a document with a vector value
vector_data = [0.1, 0.2, 0.3, 0.4]
client.hset(f"{DOC_PREFIX}1", mapping={
VECTOR_FIELD_NAME: np.array(vector_data, dtype=np.float32).tobytes()
})

# Perform a search to retrieve the document
query = (
Query("*").return_fields(VECTOR_FIELD_NAME).dialect(2)
)
res = client.ft(INDEX_NAME).search(query)

# Assert that the document is retrieved and the vector matches the original data
assert res.total == 1
assert res.docs[0].id == f"{DOC_PREFIX}1"
retrieved_vector_data = np.frombuffer(res.docs[0].__dict__[VECTOR_FIELD_NAME], dtype=np.float32)
assert np.allclose(retrieved_vector_data, vector_data)

0 comments on commit 48609a8

Please sign in to comment.