Skip to content

Commit

Permalink
[2.4] support the report value in the dml and dql request
Browse files Browse the repository at this point in the history
Signed-off-by: SimFG <[email protected]>
  • Loading branch information
SimFG committed May 15, 2024
1 parent 5b59272 commit 1177c7c
Show file tree
Hide file tree
Showing 8 changed files with 344 additions and 15 deletions.
187 changes: 187 additions & 0 deletions examples/hello_cost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
# hello_milvus.py demonstrates the basic operations of PyMilvus, a Python SDK of Milvus.
# 1. connect to Milvus
# 2. create collection
# 3. insert data
# 4. create index
# 5. search, query, and hybrid search on entities
# 6. delete entities by PK
# 7. drop collection
import time

import numpy as np
from pymilvus import (
connections,
utility,
FieldSchema, CollectionSchema, DataType,
Collection,
)

fmt = "\n=== {:30} ===\n"
search_latency_fmt = "search latency = {:.4f}s"
num_entities, dim = 10, 8

#################################################################################
# 1. connect to Milvus
# Add a new connection alias `default` for Milvus server in `localhost:19530`
# Actually the "default" alias is a buildin in PyMilvus.
# If the address of Milvus is the same as `localhost:19530`, you can omit all
# parameters and call the method as: `connections.connect()`.
#
# Note: the `using` parameter of the following methods is default to "default".
print(fmt.format("start connecting to Milvus"))
connections.connect("default", host="localhost", port="19530")

collection_name = "hello_cost"
has = utility.has_collection(collection_name)
print(f"Does collection {collection_name} exist in Milvus: {has}")

#################################################################################
# 2. create collection
# We're going to create a collection with 3 fields.
# +-+------------+------------+------------------+------------------------------+
# | | field name | field type | other attributes | field description |
# +-+------------+------------+------------------+------------------------------+
# |1| "pk" | VarChar | is_primary=True | "primary field" |
# | | | | auto_id=False | |
# +-+------------+------------+------------------+------------------------------+
# |2| "random" | Double | | "a double field" |
# +-+------------+------------+------------------+------------------------------+
# |3|"embeddings"| FloatVector| dim=8 | "float vector with dim 8" |
# +-+------------+------------+------------------+------------------------------+
fields = [
FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100),
FieldSchema(name="random", dtype=DataType.DOUBLE),
FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=dim)
]

schema = CollectionSchema(fields, f"{collection_name} is the simplest demo to introduce the APIs")

print(fmt.format(f"Create collection `{collection_name}`"))
hello_milvus = Collection(collection_name, schema, consistency_level="Strong")

################################################################################
# 3. insert data
# We are going to insert 3000 rows of data into `hello_milvus`
# Data to be inserted must be organized in fields.
#
# The insert() method returns:
# - either automatically generated primary keys by Milvus if auto_id=True in the schema;
# - or the existing primary key field from the entities if auto_id=False in the schema.

print(fmt.format("Start inserting entities"))
rng = np.random.default_rng(seed=19530)
entities = [
# provide the pk field because `auto_id` is set to False
[str(i) for i in range(num_entities)],
rng.random(num_entities).tolist(), # field random, only supports list
rng.random((num_entities, dim)), # field embeddings, supports numpy.ndarray and list
]

insert_result = hello_milvus.insert(entities)
# OUTPUT:
# insert result: (insert count: 10, delete count: 0, upsert count: 0, timestamp: 449296288881311748, success count: 10, err count: 0, cost: 1);
# insert cost: 1
print(f"insert result: {insert_result};\ninsert cost: {insert_result.cost}")

hello_milvus.flush()
print(f"Number of entities in Milvus: {hello_milvus.num_entities}") # check the num_entities

################################################################################
# 4. create index
# We are going to create an IVF_FLAT index for hello_milvus collection.
# create_index() can only be applied to `FloatVector` and `BinaryVector` fields.
print(fmt.format("Start Creating index IVF_FLAT"))
index = {
"index_type": "IVF_FLAT",
"metric_type": "L2",
"params": {"nlist": 128},
}

hello_milvus.create_index("embeddings", index)

################################################################################
# 5. search, query, and hybrid search
# After data were inserted into Milvus and indexed, you can perform:
# - search based on vector similarity
# - query based on scalar filtering(boolean, int, etc.)
# - hybrid search based on vector similarity and scalar filtering.
#

# Before conducting a search or a query, you need to load the data in `hello_milvus` into memory.
print(fmt.format("Start loading"))
hello_milvus.load()

# -----------------------------------------------------------------------------
# search based on vector similarity
print(fmt.format("Start searching based on vector similarity"))
vectors_to_search = entities[-1][-2:]
search_params = {
"metric_type": "L2",
"params": {"nprobe": 10},
}

start_time = time.time()
result = hello_milvus.search(vectors_to_search, "embeddings", search_params, limit=3, output_fields=["random"])
end_time = time.time()

# OUTPUT:
# search result: data: ['["id: 8, distance: 0.0, entity: {\'random\': 0.9007387227368949}", "id: 0, distance: 0.49515748023986816, entity: {\'random\': 0.6378742006852851}", "id: 2, distance: 0.5305156707763672, entity: {\'random\': 0.1321158395732429}"]', '["id: 9, distance: 0.0, entity: {\'random\': 0.4494463384561439}", "id: 8, distance: 0.558194100856781, entity: {\'random\': 0.9007387227368949}", "id: 2, distance: 0.7718868255615234, entity: {\'random\': 0.1321158395732429}"]'], cost: 21;
# search cost: 21
print(f"search result: {result};\nsearch cost: {result.cost}")
print(search_latency_fmt.format(end_time - start_time))

# -----------------------------------------------------------------------------
# query based on scalar filtering(boolean, int, etc.)
print(fmt.format("Start querying with `random > 0.5`"))

start_time = time.time()
result = hello_milvus.query(expr="random > 0.5", output_fields=["random", "embeddings"])
end_time = time.time()

# OUTPUT:
# query result: data: ["{'random': 0.6378742006852851, 'embeddings': [0.18477614, 0.42930314, 0.40345728, 0.3957196, 0.6963897, 0.24356908, 0.42512414, 0.5724385], 'pk': '0'}", "{'random': 0.744296470467782, 'embeddings': [0.8349225, 0.6614872, 0.98359716, 0.15854438, 0.30939594, 0.23553558, 0.1950739, 0.80361205], 'pk': '4'}", "{'random': 0.6025374094941409, 'embeddings': [0.36677808, 0.218786, 0.25240582, 0.82230526, 0.21011819, 0.16813536, 0.8129038, 0.74800706], 'pk': '7'}", "{'random': 0.9007387227368949, 'embeddings': [0.27464902, 0.07500089, 0.57728964, 0.6654878, 0.8698446, 0.3814792, 0.8825416, 0.58730817], 'pk': '8'}"], extra_info: {'cost': '21'};
# query cost: 21
print(f"query result: {result};\nquery cost: {result.extra['cost']}")
print(search_latency_fmt.format(end_time - start_time))


# -----------------------------------------------------------------------------
# hybrid search
print(fmt.format("Start hybrid searching with `random > 0.5`"))

start_time = time.time()
result = hello_milvus.search(vectors_to_search, "embeddings", search_params, limit=3, expr="random > 0.5", output_fields=["random"])
end_time = time.time()

# OUTPUT:
# search result: data: ['["id: 8, distance: 0.0, entity: {\'random\': 0.9007387227368949}", "id: 0, distance: 0.49515748023986816, entity: {\'random\': 0.6378742006852851}", "id: 7, distance: 0.670731246471405, entity: {\'random\': 0.6025374094941409}"]', '["id: 8, distance: 0.558194100856781, entity: {\'random\': 0.9007387227368949}", "id: 0, distance: 1.0780366659164429, entity: {\'random\': 0.6378742006852851}", "id: 7, distance: 1.1083570718765259, entity: {\'random\': 0.6025374094941409}"]'], cost: 21;
# search cost: 21
print(f"search result: {result};\nsearch cost: {result.cost}")
print(search_latency_fmt.format(end_time - start_time))

###############################################################################
# 6. delete entities by PK
# You can delete entities by their PK values using boolean expressions.
ids = insert_result.primary_keys

expr = f'pk in ["{ids[0]}" , "{ids[1]}"]'
print(fmt.format(f"Start deleting with expr `{expr}`"))

result = hello_milvus.query(expr=expr, output_fields=["random", "embeddings"])
print(f"query before delete by expr=`{expr}` -> result: \n-{result[0]}\n-{result[1]}\n")

delete_result = hello_milvus.delete(expr)
# OUTPUT:
# delete result: (insert count: 0, delete count: 2, upsert count: 0, timestamp: 0, success count: 0, err count: 0, cost: 2);
# delete cost: 2
print(f"delete result: {delete_result};\ndelete cost: {delete_result.cost}")

result = hello_milvus.query(expr=expr, output_fields=["random", "embeddings"])
print(f"query after delete by expr=`{expr}` -> result: {result}\n")


###############################################################################
# 7. drop collection
# Finally, drop the hello_milvus collection
print(fmt.format(f"Drop collection `{collection_name}`"))
utility.drop_collection(collection_name)
74 changes: 74 additions & 0 deletions examples/milvus_client/simple_cost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import time
import numpy as np
from pymilvus import (
MilvusClient,
)

fmt = "\n=== {:30} ===\n"
dim = 8
collection_name = "hello_client_cost"
# milvus_client = MilvusClient("http://localhost:19530")
milvus_client = MilvusClient(uri="https://in01-20fa6a32462c074.aws-us-west-2.vectordb-uat3.zillizcloud.com:19541",
token="root:j6|y3/g$5Lq,a[TJ^ckphSMs{-F[&Jl)")

has_collection = milvus_client.has_collection(collection_name, timeout=5)
if has_collection:
milvus_client.drop_collection(collection_name)
milvus_client.create_collection(collection_name, dim, consistency_level="Strong", metric_type="L2")

print(fmt.format(" all collections "))
print(milvus_client.list_collections())

print(fmt.format(f"schema of collection {collection_name}"))
print(milvus_client.describe_collection(collection_name))

rng = np.random.default_rng(seed=19530)
rows = [
{"id": 1, "vector": rng.random((1, dim))[0], "a": 100},
{"id": 2, "vector": rng.random((1, dim))[0], "b": 200},
{"id": 3, "vector": rng.random((1, dim))[0], "c": 300},
{"id": 4, "vector": rng.random((1, dim))[0], "d": 400},
{"id": 5, "vector": rng.random((1, dim))[0], "e": 500},
{"id": 6, "vector": rng.random((1, dim))[0], "f": 600},
]

print(fmt.format("Start inserting entities"))
insert_result = milvus_client.insert(collection_name, rows, progress_bar=True)
print(fmt.format("Inserting entities done"))
# OUTPUT:
# insert result: {'insert_count': 6, 'ids': [1, 2, 3, 4, 5, 6], 'cost': '1'};
# insert cost: 1
print(f"insert result: {insert_result};\ninsert cost: {insert_result['cost']}")

print(fmt.format("Start query by specifying primary keys"))
query_results = milvus_client.query(collection_name, ids=[2])
# OUTPUT:
# query result: data: ["{'id': 2, 'vector': [0.9007387, 0.44944635, 0.18477614, 0.42930314, 0.40345728, 0.3957196, 0.6963897, 0.24356908], 'b': 200}"], extra_info: {'cost': '21'}
# query cost: 21
print(f"query result: {query_results}\nquery cost: {query_results.extra['cost']}")

upsert_ret = milvus_client.upsert(collection_name, {"id": 2 , "vector": rng.random((1, dim))[0], "g": 100})
# OUTPUT:
# upsert result: {'upsert_count': 1, 'cost': '2'}
# upsert cost: 2
print(f"upsert result: {upsert_ret}\nupsert cost: {upsert_ret['cost']}")

print(fmt.format("Start query by specifying primary keys"))
query_results = milvus_client.query(collection_name, ids=[2])
print(f"query result: {query_results}\nquery cost: {query_results.extra['cost']}")

print(f"start to delete by specifying filter in collection {collection_name}")
delete_result = milvus_client.delete(collection_name, ids=[6])
# OUTPUT:
# delete result: {'delete_count': 1, 'cost': '1'}
# delete cost: 1
print(f"delete result: {delete_result}\ndelete cost: {delete_result['cost']}")

rng = np.random.default_rng(seed=19530)
vectors_to_search = rng.random((1, dim))

print(fmt.format(f"Start search with retrieve serveral fields."))
result = milvus_client.search(collection_name, vectors_to_search, limit=3, output_fields=["pk", "a", "b"])
print(f"search result: {result}\nsearch cost: {result.extra['cost']}")

milvus_client.drop_collection(collection_name)
25 changes: 21 additions & 4 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import ujson

from pymilvus.exceptions import DataTypeNotMatchException, ExceptionsMessage, MilvusException
from pymilvus.grpc_gen import schema_pb2
from pymilvus.grpc_gen import common_pb2, schema_pb2
from pymilvus.settings import Config

from . import entity_helper, utils
Expand Down Expand Up @@ -195,6 +195,7 @@ def __init__(self, raw: Any):
self._timestamp = 0
self._succ_index = []
self._err_index = []
self._cost = 0

self._pack(raw)

Expand Down Expand Up @@ -234,10 +235,16 @@ def succ_index(self):
def err_index(self):
return self._err_index

# The unit of this cost is vcu, similar to token
@property
def cost(self):
return self._cost

def __str__(self):
return (
f"(insert count: {self._insert_cnt}, delete count: {self._delete_cnt}, upsert count: {self._upsert_cnt}, "
f"timestamp: {self._timestamp}, success count: {self.succ_count}, err count: {self.err_count})"
f"timestamp: {self._timestamp}, success count: {self.succ_count}, err count: {self.err_count}, "
f"cost: {self._cost})"
)

__repr__ = __str__
Expand All @@ -262,6 +269,9 @@ def _pack(self, raw: Any):
self._timestamp = raw.timestamp
self._succ_index = raw.succ_index
self._err_index = raw.err_index
self._cost = int(
raw.status.extra_info["report_value"] if raw.status and raw.status.extra_info else "0"
)


class SequenceIterator:
Expand Down Expand Up @@ -374,10 +384,17 @@ def __str__(self):
class SearchResult(list):
"""nq results: List[Hits]"""

def __init__(self, res: schema_pb2.SearchResultData, round_decimal: Optional[int] = None):
def __init__(
self,
res: schema_pb2.SearchResultData,
round_decimal: Optional[int] = None,
status: Optional[common_pb2.Status] = None,
):
self._nq = res.num_queries
all_topks = res.topks

self.cost = int(status.extra_info["report_value"] if status and status.extra_info else "0")

output_fields = res.output_fields
fields_data = res.fields_data

Expand Down Expand Up @@ -497,7 +514,7 @@ def __iter__(self) -> SequenceIterator:

def __str__(self) -> str:
"""Only print at most 10 query results"""
return str(list(map(str, self[:10])))
return f"data: {list(map(str, self[:10]))} {'...' if len(self) > 10 else ''}, cost: {self.cost}"

__repr__ = __str__

Expand Down
2 changes: 1 addition & 1 deletion pymilvus/client/asynch.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def exception(self):
class SearchFuture(Future):
def on_response(self, response: milvus_pb2.SearchResults):
check_status(response.status)
return SearchResult(response.results)
return SearchResult(response.results, status=response.status)


class MutationFuture(Future):
Expand Down
8 changes: 5 additions & 3 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
CompactionState,
DatabaseInfo,
DataType,
ExtraList,
GrantInfo,
Group,
IndexState,
Expand All @@ -57,6 +58,7 @@
State,
Status,
UserInfo,
get_cost_extra,
)
from .utils import (
check_invalid_binary_vector,
Expand Down Expand Up @@ -732,7 +734,7 @@ def _execute_search(
response = self._stub.Search(request, timeout=timeout)
check_status(response.status)
round_decimal = kwargs.get("round_decimal", -1)
return SearchResult(response.results, round_decimal)
return SearchResult(response.results, round_decimal, status=response.status)

except Exception as e:
if kwargs.get("_async", False):
Expand All @@ -751,7 +753,7 @@ def _execute_hybrid_search(
response = self._stub.HybridSearch(request, timeout=timeout)
check_status(response.status)
round_decimal = kwargs.get("round_decimal", -1)
return SearchResult(response.results, round_decimal)
return SearchResult(response.results, round_decimal, status=response.status)

except Exception as e:
if kwargs.get("_async", False):
Expand Down Expand Up @@ -1519,7 +1521,7 @@ def query(
response.fields_data, index, dynamic_fields
)
results.append(entity_row_data)
return results
return ExtraList(results, extra=get_cost_extra(response.status))

@retry_on_rpc_failure()
def load_balance(
Expand Down
Loading

0 comments on commit 1177c7c

Please sign in to comment.