Skip to content

Commit

Permalink
Fix gRPC conversion for sparse search batch (#484)
Browse files Browse the repository at this point in the history
* Fix gRPC conversion for sparse search batch

* fix conversion & add test
  • Loading branch information
agourlay authored and joein committed Feb 8, 2024
1 parent 5e564e6 commit 6df4bd0
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 4 deletions.
11 changes: 7 additions & 4 deletions qdrant_client/conversions/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1955,23 +1955,26 @@ def convert_batch_vector_struct(
@classmethod
def convert_named_vector_struct(
cls, model: rest.NamedVectorStruct
) -> Tuple[List[float], Optional[str]]:
) -> Tuple[List[float], Optional[grpc.SparseIndices], Optional[str]]:
if isinstance(model, list):
return model, None
return model, None, None
elif isinstance(model, rest.NamedVector):
return model.vector, model.name
return model.vector, None, model.name
elif isinstance(model, rest.NamedSparseVector):
return model.vector.values, grpc.SparseIndices(data=model.vector.indices), model.name
else:
raise ValueError(f"invalid NamedVectorStruct model: {model}") # pragma: no cover

@classmethod
def convert_search_request(
cls, model: rest.SearchRequest, collection_name: str
) -> grpc.SearchPoints:
vector, name = cls.convert_named_vector_struct(model.vector)
vector, sparse_indices, name = cls.convert_named_vector_struct(model.vector)

return grpc.SearchPoints(
collection_name=collection_name,
vector=vector,
sparse_indices=sparse_indices,
filter=cls.convert_filter(model.filter) if model.filter is not None else None,
limit=model.limit,
with_payload=cls.convert_with_payload_interface(model.with_payload)
Expand Down
86 changes: 86 additions & 0 deletions tests/test_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,6 +1263,92 @@ def test_sparse_vectors(prefer_grpc):
assert result[1].vector["text"].values == [1.0, 2.0, 3.0]


@pytest.mark.parametrize("prefer_grpc", [False, True])
def test_sparse_vectors_batch(prefer_grpc):
    """Check that ``search_batch`` handles a named sparse query vector.

    Runs once over REST and once over gRPC (``prefer_grpc``). Three points
    with sparse "text" vectors are stored, a single sparse search request is
    sent through ``search_batch``, and the returned ids, scores and vectors
    are verified.
    """
    version = os.getenv("QDRANT_VERSION")
    if version is not None:
        # Compare versions numerically. A plain string comparison is
        # lexicographic, so "v1.10.0" < "v1.7.0" would be True and the test
        # would be skipped on servers that do support sparse vectors.
        try:
            parsed = tuple(int(part) for part in version.lstrip("v").split(".")[:3])
        except ValueError:
            # Non-numeric tags (e.g. "dev") cannot be ordered; run the test.
            parsed = None
        if parsed is not None:
            # Pad short versions such as "v1.7" so they compare as (1, 7, 0).
            parsed += (0,) * (3 - len(parsed))
            if parsed < (1, 7, 0):
                pytest.skip("Sparse vectors are supported since v1.7.0")

    client = QdrantClient(prefer_grpc=prefer_grpc, timeout=TIMEOUT)

    # Collection with no dense vectors and a single sparse vector named "text".
    client.recreate_collection(
        collection_name=COLLECTION_NAME,
        vectors_config={},
        sparse_vectors_config={
            "text": models.SparseVectorParams(
                index=models.SparseIndexParams(
                    on_disk=False,
                    full_scan_threshold=100,
                )
            )
        },
    )

    client.upsert(
        collection_name=COLLECTION_NAME,
        points=[
            models.PointStruct(
                id=1,
                vector={
                    "text": models.SparseVector(
                        indices=[1, 2, 3],
                        values=[1.0, 2.0, 3.0],
                    )
                },
            ),
            models.PointStruct(
                id=2,
                vector={
                    "text": models.SparseVector(
                        indices=[3, 4, 5],
                        values=[1.0, 2.0, 3.0],
                    )
                },
            ),
            models.PointStruct(
                id=3,
                vector={
                    "text": models.SparseVector(
                        indices=[5, 6, 7],
                        values=[1.0, 2.0, 3.0],
                    )
                },
            ),
        ],
    )

    # The query shares index 1 with point 1 and index 7 with point 3;
    # it has no index in common with point 2.
    request = models.SearchRequest(
        vector=models.NamedSparseVector(
            name="text",
            vector=models.SparseVector(
                indices=[1, 7],
                values=[2.0, 1.0],
            ),
        ),
        limit=3,
        with_vector=["text"],
    )

    results = client.search_batch(
        collection_name=COLLECTION_NAME,
        requests=[request],
    )

    # One request was sent, so only the first result list is inspected.
    result = results[0]

    # limit=3, yet only the two overlapping points are expected back.
    assert len(result) == 2
    assert result[0].id == 3
    assert result[1].id == 1

    # Point 3: 1.0 * 3.0 on index 7; point 1: 2.0 * 1.0 on index 1.
    assert result[0].score == 3.0
    assert result[1].score == 2.0

    assert result[0].vector["text"].indices == [5, 6, 7]
    assert result[0].vector["text"].values == [1.0, 2.0, 3.0]
    assert result[1].vector["text"].indices == [1, 2, 3]
    assert result[1].vector["text"].values == [1.0, 2.0, 3.0]


@pytest.mark.parametrize("prefer_grpc", [False, True])
def test_vector_update(prefer_grpc):
client = QdrantClient(prefer_grpc=prefer_grpc, timeout=TIMEOUT)
Expand Down

0 comments on commit 6df4bd0

Please sign in to comment.