Skip to content

Commit

Permalink
Decode search results at field level
Browse files Browse the repository at this point in the history
  • Loading branch information
uglide committed Jul 9, 2024
1 parent 0be67bf commit 7ca2f29
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 23 deletions.
1 change: 1 addition & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ urllib3<2
uvloop
vulture>=2.3.0
wheel>=0.30.0
numpy>=1.24.0
5 changes: 3 additions & 2 deletions docs/examples/search_vector_similarity_examples.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions redis/commands/search/_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
def to_string(s):
def to_string(s, encoding: str = "utf-8"):
if isinstance(s, str):
return s
elif isinstance(s, bytes):
return s.decode("utf-8", "ignore")
return s.decode(encoding, "ignore")
else:
return s # Not a string we care about
1 change: 1 addition & 0 deletions redis/commands/search/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def _parse_search(self, res, **kwargs):
duration=kwargs["duration"],
has_payload=kwargs["query"]._with_payloads,
with_scores=kwargs["query"]._with_scores,
field_encodings=kwargs["query"]._return_fields_decode_as,
)

def _parse_aggregate(self, res, **kwargs):
Expand Down
23 changes: 19 additions & 4 deletions redis/commands/search/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self, query_string: str) -> None:
self._in_order: bool = False
self._sortby: Optional[SortbyField] = None
self._return_fields: List = []
self._return_fields_decode_as: dict = {}
self._summarize_fields: List = []
self._highlight_fields: List = []
self._language: Optional[str] = None
Expand All @@ -53,13 +54,27 @@ def limit_ids(self, *ids) -> "Query":

def return_fields(self, *fields) -> "Query":
"""Add fields to return fields."""
self._return_fields += fields
for field in fields:
self.return_field(field)
return self

def return_field(self, field: str, as_field: Optional[str] = None) -> "Query":
"""Add field to return fields (Optional: add 'AS' name
to the field)."""
def return_field(
self,
field: str,
as_field: Optional[str] = None,
decode_field: Optional[bool] = True,
encoding: Optional[str] = "utf8",
) -> "Query":
"""
Add a field to the list of fields to return.
- **field**: The field to include in query results
- **as_field**: The alias for the field
- **decode_field**: Whether to decode the field from bytes to string
- **encoding**: The encoding to use when decoding the field
"""
self._return_fields.append(field)
self._return_fields_decode_as[field] = encoding if decode_field else None
if as_field is not None:
self._return_fields += ("AS", as_field)
return self
Expand Down
44 changes: 29 additions & 15 deletions redis/commands/search/result.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from ._util import to_string
from .document import Document

Expand All @@ -9,11 +11,19 @@ class Result:
"""

def __init__(
self, res, hascontent, duration=0, has_payload=False, with_scores=False
self,
res,
hascontent,
duration=0,
has_payload=False,
with_scores=False,
field_encodings: Optional[dict] = None,
):
"""
- **snippets**: An optional dictionary of the form
{field: snippet_size} for snippet formatting
- duration: the execution time of the query
- has_payload: whether the query has payloads
- with_scores: whether the query has scores
- field_encodings: a dictionary of field encodings if any is provided
"""

self.total = res[0]
Expand All @@ -39,18 +49,22 @@ def __init__(

fields = {}
if hascontent and res[i + fields_offset] is not None:
fields = (
dict(
dict(
zip(
map(to_string, res[i + fields_offset][::2]),
map(to_string, res[i + fields_offset][1::2]),
)
)
)
if hascontent
else {}
)
keys = map(to_string, res[i + fields_offset][::2])
values = res[i + fields_offset][1::2]

for key, value in zip(keys, values):
if field_encodings is None or key not in field_encodings:
fields[key] = to_string(value)
continue

encoding = field_encodings[key]

# If the encoding is None, we don't need to decode the value
if encoding is None:
fields[key] = value
else:
fields[key] = to_string(value, encoding=encoding)

try:
del fields["id"]
except KeyError:
Expand Down
63 changes: 63 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
from io import TextIOWrapper

import numpy as np
import pytest
import redis
import redis.commands.search
Expand Down Expand Up @@ -113,6 +114,13 @@ def client(request, stack_url):
return r


@pytest.fixture
def binary_client(request, stack_url):
r = _get_client(redis.Redis, request, decode_responses=False, from_url=stack_url)
r.flushdb()
return r


@pytest.mark.redismod
def test_client(client):
num_docs = 500
Expand Down Expand Up @@ -1705,6 +1713,61 @@ def test_search_return_fields(client):
assert "telmatosaurus" == total["results"][0]["extra_attributes"]["txt"]


@pytest.mark.redismod
def test_binary_and_text_fields(binary_client):
assert (
binary_client.get_connection_kwargs()["decode_responses"] is False
), "This feature is only available when decode_responses is False"

fake_vec = np.array([0.1, 0.2, 0.3, 0.4], dtype=np.float32)

index_name = "mixed_index"
mixed_data = {"first_name": "🐍python", "vector_emb": fake_vec.tobytes()}
binary_client.hset(f"{index_name}:1", mapping=mixed_data)

schema = (
TagField("first_name"),
VectorField(
"embeddings_bio",
algorithm="HNSW",
attributes={
"TYPE": "FLOAT32",
"DIM": 4,
"DISTANCE_METRIC": "COSINE",
},
),
)

binary_client.ft(index_name).create_index(
fields=schema,
definition=IndexDefinition(
prefix=[f"{index_name}:"], index_type=IndexType.HASH
),
)

bytes_person_1 = binary_client.hget(f"{index_name}:1", "vector_emb")
decoded_vec_from_hash = np.frombuffer(bytes_person_1, dtype=np.float32)
assert np.array_equal(decoded_vec_from_hash, fake_vec), "The vectors are not equal"

query = (
Query("*")
.return_field("vector_emb", decode_field=False)
.return_field("first_name", decode_field=True)
)
docs = binary_client.ft(index_name).search(query=query, query_params={}).docs
decoded_vec_from_search_results = np.frombuffer(
docs[0]["vector_emb"], dtype=np.float32
)

assert np.array_equal(
decoded_vec_from_search_results, fake_vec
), "The vectors are not equal"

assert (
docs[0]["first_name"] == mixed_data["first_name"]
), "The first is not decoded correctly"


@pytest.mark.redismod
def test_synupdate(client):
definition = IndexDefinition(index_type=IndexType.HASH)
Expand Down

0 comments on commit 7ca2f29

Please sign in to comment.