Skip to content

Commit

Permalink
enhance: support milvus-client iterator (#2461)
Browse files Browse the repository at this point in the history
related: #2464

Signed-off-by: MrPresent-Han <[email protected]>
Co-authored-by: MrPresent-Han <[email protected]>
  • Loading branch information
MrPresent-Han and MrPresent-Han authored Dec 18, 2024
1 parent ca9d40d commit cd4aab7
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 0 deletions.
99 changes: 99 additions & 0 deletions examples/iterator/iterator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from pymilvus.milvus_client.milvus_client import MilvusClient
from pymilvus import (
FieldSchema, CollectionSchema, DataType,
)
import numpy as np

# --- Target collection and run-mode switches --------------------------------
collection_name = "test_milvus_client_iterator"
# When True, main() drops an already existing collection of the same name.
clean_exist = True
# When True, main() inserts NUM_ENTITIES freshly generated rows.
prepare_new_data = True

# --- Field names used by the demo schema ------------------------------------
USER_ID, AGE, DEPOSIT, PICTURE = "id", "age", "deposit", "picture"

# --- Data shape --------------------------------------------------------------
DIM = 8  # dimension of the PICTURE float vector field
NUM_ENTITIES = 10000

# Fixed seed so every run draws the same sequence of random vectors.
rng = np.random.default_rng(seed=19530)


def test_query_iterator(milvus_client: MilvusClient):
    """Page through a filtered query with ``query_iterator`` and print each row.

    Rows whose AGE field lies in ``[10, 25]`` are requested in batches of 50,
    returning only the USER_ID and AGE fields. Iteration stops at the first
    empty page.

    Args:
        milvus_client: client used to open the query iterator on
            ``collection_name``.
    """
    expr = f"10 <= {AGE} <= 25"
    output_fields = [USER_ID, AGE]
    query_it = milvus_client.query_iterator(
        collection_name, filter=expr, batch_size=50, output_fields=output_fields
    )
    page_idx = 0
    try:
        while True:
            res = query_it.next()
            if len(res) == 0:
                print("query iteration finished, close")
                break
            for row in res:
                print(row)
            page_idx += 1
            print(f"page{page_idx}-------------------------")
    finally:
        # Close on every exit path so the iterator is not left open when
        # next() raises part-way through the iteration.
        query_it.close()

def test_search_iterator(milvus_client: MilvusClient):
    """Page through a vector search with ``search_iterator`` and print each hit.

    A single random ``DIM``-dimensional float32 vector is searched against the
    PICTURE field in batches of 100. Iteration stops at the first empty page.

    Args:
        milvus_client: client used to open the search iterator on
            ``collection_name``.
    """
    vector_to_search = rng.random((1, DIM), np.float32)
    search_iterator = milvus_client.search_iterator(
        collection_name, data=vector_to_search, batch_size=100, anns_field=PICTURE
    )

    page_idx = 0
    try:
        while True:
            res = search_iterator.next()
            if len(res) == 0:
                # The original message said "query iteration", a copy-paste
                # slip from the query demo.
                print("search iteration finished, close")
                break
            for hit in res:
                print(hit)
            page_idx += 1
            print(f"page{page_idx}-------------------------")
    finally:
        # Close on every exit path so the iterator is not left open when
        # next() raises part-way through the iteration.
        search_iterator.close()


def main():
    """Build the demo collection, fill and index it, then run both iterator demos.

    Steps, in order: optionally drop an existing collection, create the
    collection if it is missing, optionally insert generated rows and flush,
    create an IVF_FLAT/L2 index on the vector field, load the collection, and
    finally call the query- and search-iterator demos.
    """
    client = MilvusClient("http://localhost:19530")

    # has_collection is evaluated first, exactly as in the combined condition.
    if client.has_collection(collection_name) and clean_exist:
        client.drop_collection(collection_name)
        print(f"dropped existed collection{collection_name}")

    if not client.has_collection(collection_name):
        schema = CollectionSchema(
            [
                FieldSchema(name=USER_ID, dtype=DataType.INT64, is_primary=True, auto_id=False),
                FieldSchema(name=AGE, dtype=DataType.INT64),
                FieldSchema(name=DEPOSIT, dtype=DataType.DOUBLE),
                FieldSchema(name=PICTURE, dtype=DataType.FLOAT_VECTOR, dim=DIM),
            ]
        )
        client.create_collection(collection_name, dimension=DIM, schema=schema)

    if prepare_new_data:
        # One rng draw per row, in row order.
        rows = [
            {
                USER_ID: row_id,
                AGE: (row_id % 100),
                DEPOSIT: float(row_id),
                PICTURE: rng.random((1, DIM))[0],
            }
            for row_id in range(NUM_ENTITIES)
        ]
        client.insert(collection_name, rows)
        client.flush(collection_name)
        print(f"Finish flush collections:{collection_name}")

    picture_index = client.prepare_index_params()
    picture_index.add_index(
        field_name=PICTURE,
        index_type='IVF_FLAT',
        metric_type='L2',
        params={"nlist": 1024},
    )
    client.create_index(collection_name, picture_index)
    client.load_collection(collection_name)

    test_query_iterator(milvus_client=client)
    test_search_iterator(milvus_client=client)


# Run the demo only when this file is executed as a script, not on import.
if "__main__" == __name__:
    main()
23 changes: 23 additions & 0 deletions pymilvus/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,3 +375,26 @@ def is_scipy_sparse(cls, data: Any):
"csr_array",
"spmatrix",
]


def is_sparse_vector_type(data_type: DataType) -> bool:
    """Return whether *data_type* equals ``DataType.SPARSE_FLOAT_VECTOR``."""
    # Look the member up on the DataType class, not on the argument: the
    # argument form raises AttributeError for values that are not DataType
    # members (for example a plain int) and differs from the sibling helpers.
    return data_type == DataType.SPARSE_FLOAT_VECTOR


# Data types that is_dense_vector_type() treats as dense vectors.
dense_vector_type_set = {
    DataType.FLOAT_VECTOR,
    DataType.FLOAT16_VECTOR,
    DataType.BFLOAT16_VECTOR,
}


def is_dense_vector_type(data_type: DataType) -> bool:
    """Return whether *data_type* is a member of ``dense_vector_type_set``."""
    is_dense = data_type in dense_vector_type_set
    return is_dense


def is_float_vector_type(data_type: DataType):
    """Return a truthy value when *data_type* is a sparse or dense float vector."""
    sparse = is_sparse_vector_type(data_type)
    # The dense check runs only when the sparse check is falsy.
    return sparse if sparse else is_dense_vector_type(data_type)


def is_binary_vector_type(data_type: DataType):
    """Return whether *data_type* equals ``DataType.BINARY_VECTOR``."""
    matches_binary = data_type == DataType.BINARY_VECTOR
    return matches_binary


def is_vector_type(data_type: DataType):
    """Return a truthy value when *data_type* is a float or binary vector type."""
    float_vector = is_float_vector_type(data_type)
    # The binary check runs only when the float-vector check is falsy.
    return float_vector if float_vector else is_binary_vector_type(data_type)
124 changes: 124 additions & 0 deletions pymilvus/milvus_client/milvus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,19 @@
OmitZeroDict,
construct_cost_extra,
)
from pymilvus.client.utils import is_vector_type
from pymilvus.exceptions import (
DataTypeNotMatchException,
ErrorCode,
MilvusException,
ParamError,
PrimaryKeyException,
)
from pymilvus.orm import utility
from pymilvus.orm.collection import CollectionSchema
from pymilvus.orm.connections import connections
from pymilvus.orm.constants import FIELDS, METRIC_TYPE, TYPE, UNLIMITED
from pymilvus.orm.iterator import QueryIterator, SearchIterator
from pymilvus.orm.types import DataType

from .index import IndexParams
Expand Down Expand Up @@ -480,6 +484,126 @@ def query(

return res

def query_iterator(
    self,
    collection_name: str,
    batch_size: Optional[int] = 1000,
    limit: Optional[int] = UNLIMITED,
    filter: Optional[str] = "",
    output_fields: Optional[List[str]] = None,
    partition_names: Optional[List[str]] = None,
    timeout: Optional[float] = None,
    **kwargs,
):
    """Create a ``QueryIterator`` over *collection_name*.

    Args:
        collection_name: collection to iterate over.
        batch_size: forwarded to ``QueryIterator`` as ``batch_size``.
        limit: forwarded to ``QueryIterator`` as ``limit``.
        filter: filter expression, forwarded as ``expr``; must be a ``str``
            or ``None``.
        output_fields: forwarded to ``QueryIterator``.
        partition_names: forwarded to ``QueryIterator``.
        timeout: used for the ``describe_collection`` call and forwarded to
            ``QueryIterator``.
        **kwargs: forwarded to ``describe_collection`` and ``QueryIterator``.

    Returns:
        The constructed ``QueryIterator``.

    Raises:
        DataTypeNotMatchException: if *filter* is neither ``None`` nor a
            ``str``.
        Exception: whatever ``describe_collection`` raises, re-raised
            unchanged after being logged.
    """
    if filter is not None and not isinstance(filter, str):
        raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(filter))

    conn = self._get_connection()
    # The iterator is constructed with the collection schema, so fetch it first.
    try:
        schema_dict = conn.describe_collection(collection_name, timeout=timeout, **kwargs)
    except Exception:
        logger.error("Failed to describe collection: %s", collection_name)
        # Bare raise keeps the original exception and traceback; the previous
        # `raise ex from ex` made the exception its own __cause__.
        raise

    return QueryIterator(
        connection=conn,
        collection_name=collection_name,
        batch_size=batch_size,
        limit=limit,
        expr=filter,
        output_fields=output_fields,
        partition_names=partition_names,
        schema=schema_dict,
        timeout=timeout,
        **kwargs,
    )

def search_iterator(
    self,
    collection_name: str,
    data: Union[List[list], list],
    batch_size: Optional[int] = 1000,
    filter: Optional[str] = None,
    limit: Optional[int] = UNLIMITED,
    output_fields: Optional[List[str]] = None,
    search_params: Optional[dict] = None,
    timeout: Optional[float] = None,
    partition_names: Optional[List[str]] = None,
    anns_field: Optional[str] = None,
    round_decimal: int = -1,
    **kwargs,
):
    """Create a ``SearchIterator`` over *collection_name*.

    When *anns_field* is not given, it is taken from the collection schema,
    which must then contain exactly one vector field. When *search_params*
    has no metric type, it is copied from the index built on *anns_field*.

    Args:
        collection_name: collection to search.
        data: query vector(s), forwarded to ``SearchIterator``.
        batch_size: forwarded to ``SearchIterator`` as ``batch_size``.
        filter: filter expression, forwarded as ``expr``; must be a ``str``
            or ``None``.
        limit: forwarded to ``SearchIterator`` as ``limit``.
        output_fields: forwarded to ``SearchIterator``.
        search_params: search parameters, forwarded as ``param``. A missing
            metric type entry is filled in on this dict.
        timeout: used for the ``describe_collection`` call and forwarded to
            ``SearchIterator``.
        partition_names: forwarded to ``SearchIterator``.
        anns_field: name of the vector field to search; ``None`` or ``""``
            triggers the schema lookup described above.
        round_decimal: forwarded to ``SearchIterator``.
        **kwargs: forwarded to ``describe_collection`` and ``SearchIterator``.

    Returns:
        The constructed ``SearchIterator``.

    Raises:
        DataTypeNotMatchException: if *filter* is neither ``None`` nor a
            ``str``.
        MilvusException: if *anns_field* cannot be determined from the schema.
        ParamError: if no metric type can be found for *anns_field*.
        Exception: whatever ``describe_collection`` raises, re-raised
            unchanged after being logged.
    """
    if filter is not None and not isinstance(filter, str):
        raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(filter))

    conn = self._get_connection()
    # The iterator is constructed with the collection schema, so fetch it first.
    try:
        schema_dict = conn.describe_collection(collection_name, timeout=timeout, **kwargs)
    except Exception:
        logger.error("Failed to describe collection: %s", collection_name)
        # Bare raise keeps the original exception and traceback; the previous
        # `raise ex from ex` made the exception its own __cause__.
        raise

    # If anns_field is not provided: use the only vector field of the
    # collection, and abort when there is none or more than one.
    if anns_field is None or anns_field == "":
        vector_fields = [
            field for field in schema_dict[FIELDS] if is_vector_type(field[TYPE])
        ]
        if not vector_fields:
            raise MilvusException(
                code=ErrorCode.UNEXPECTED_ERROR,
                message="there should be at least one vector field in milvus collection",
            )
        if len(vector_fields) > 1:
            raise MilvusException(
                code=ErrorCode.UNEXPECTED_ERROR,
                message="must specify anns_field when there are more than one vector field",
            )
        anns_field = vector_fields[0]["name"]
    if anns_field is None or anns_field == "":
        raise MilvusException(
            code=ErrorCode.UNEXPECTED_ERROR,
            message=f"cannot get anns_field name for search iterator, got:{anns_field}",
        )

    # The metric type is mandatory for the search iterator; when the caller
    # did not supply one, read it from the index built on anns_field.
    if search_params is None:
        search_params = {}
    if METRIC_TYPE not in search_params:
        for index in conn.list_indexes(collection_name):
            # Match on the indexed field name so a custom-named index is still
            # found; the index-name comparison is kept for backward
            # compatibility with the previous behaviour.
            if anns_field in (getattr(index, "field_name", None), index.index_name):
                for param in index.params:
                    if param.key == METRIC_TYPE:
                        search_params[METRIC_TYPE] = param.value
        if METRIC_TYPE not in search_params:
            # Previously the ParamError class was passed as the `code`
            # argument of MilvusException; raise ParamError itself instead.
            raise ParamError(
                message=f"Cannot set up metrics type for anns_field:{anns_field}"
            )

    return SearchIterator(
        connection=conn,
        collection_name=collection_name,
        data=data,
        ann_field=anns_field,
        param=search_params,
        batch_size=batch_size,
        limit=limit,
        expr=filter,
        partition_names=partition_names,
        output_fields=output_fields,
        timeout=timeout,
        round_decimal=round_decimal,
        schema=schema_dict,
        **kwargs,
    )

def get(
self,
collection_name: str,
Expand Down
1 change: 1 addition & 0 deletions pymilvus/orm/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
MILVUS_LIMIT = "limit"
BATCH_SIZE = "batch_size"
ID = "id"
TYPE = "type"
METRIC_TYPE = "metric_type"
PARAMS = "params"
DISTANCE = "distance"
Expand Down

0 comments on commit cd4aab7

Please sign in to comment.