Skip to content

Commit

Permalink
Update pymilvus for memory replica (#942)
Browse files Browse the repository at this point in the history
Signed-off-by: XuanYang-cn <[email protected]>
  • Loading branch information
XuanYang-cn authored Apr 3, 2022
1 parent a1fbdfe commit dcb5224
Show file tree
Hide file tree
Showing 11 changed files with 725 additions and 242 deletions.
65 changes: 27 additions & 38 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,9 +663,9 @@ def wait_for_creating_index(self, collection_name, field_name, timeout=None, **k

@retry_on_rpc_failure(retry_times=10, wait=1)
@error_handler
def load_collection(self, collection_name, timeout=None, **kwargs):
def load_collection(self, collection_name, timeout=None, replica_number=1, **kwargs):
check_pass_param(collection_name=collection_name)
request = Prepare.load_collection("", collection_name)
request = Prepare.load_collection("", collection_name, replica_number)
rf = self._stub.LoadCollection.future(request, wait_for_ready=True, timeout=timeout)
response = rf.result()
if response.error_code != 0:
Expand All @@ -679,56 +679,45 @@ def load_collection(self, collection_name, timeout=None, **kwargs):
def load_collection_progress(self, collection_name, timeout=None):
""" Return loading progress of collection """

loaded_segments_nums = sum(info.num_rows for info in
self.get_query_segment_info(collection_name, timeout))

total_segments_nums = sum(info.num_rows for info in
self.get_persistent_segment_infos(collection_name, timeout))

progress = (loaded_segments_nums / total_segments_nums) * 100 if loaded_segments_nums < total_segments_nums else 100
progress = self.get_collection_loading_progress(collection_name, timeout)

return {'loading_progress': f"{progress:.0f}%"}

@retry_on_rpc_failure(retry_times=10, wait=1)
@error_handler
def wait_for_loading_collection(self, collection_name, timeout=None):
return self._wait_for_loading_collection_v2(collection_name, timeout)
return self._wait_for_loading_collection(collection_name, timeout)

# TODO seems not in use
def _wait_for_loading_collection_v1(self, collection_name, timeout=None):
""" Block until load collection complete. """
unloaded_segments = {info.segmentID: info.num_rows for info in
self.get_persistent_segment_infos(collection_name, timeout)}

while len(unloaded_segments) > 0:
time.sleep(0.5)

for info in self.get_query_segment_info(collection_name, timeout):
if 0 <= unloaded_segments.get(info.segmentID, -1) <= info.num_rows:
unloaded_segments.pop(info.segmentID)

def _wait_for_loading_collection_v2(self, collection_name, timeout=None):
""" Block until load collection complete. """
def get_collection_loading_progress(self, collection_name: str, timeout=None) -> int:
request = Prepare.show_collections_request([collection_name])
future = self._stub.ShowCollections.future(request, wait_for_ready=True, timeout=timeout)
response = future.result()

while True:
future = self._stub.ShowCollections.future(request, wait_for_ready=True, timeout=timeout)
response = future.result()
if response.status.error_code != 0:
raise BaseException(response.status.error_code, response.status.reason)

if response.status.error_code != 0:
raise BaseException(response.status.error_code, response.status.reason)
ol = len(response.collection_names)
pl = len(response.inMemory_percentages)

ol = len(response.collection_names)
pl = len(response.inMemory_percentages)
if ol != pl:
raise BaseException(ErrorCode.UnexpectedError,
f"len(collection_names) ({ol}) != len(inMemory_percentages) ({pl})")

if ol != pl:
raise BaseException(ErrorCode.UnexpectedError,
f"len(collection_names) ({ol}) != len(inMemory_percentages) ({pl})")
for i, coll_name in enumerate(response.collection_names):
if coll_name == collection_name:
return response.inMemory_percentages[i]

for i, coll_name in enumerate(response.collection_names):
if coll_name == collection_name and response.inMemory_percentages[i] == 100:
return
def _wait_for_loading_collection(self, collection_name, timeout=None):
    """Block until the collection is fully loaded or the timeout elapses.

    Repeatedly queries ``get_collection_loading_progress`` and returns as
    soon as the reported progress reaches 100, sleeping
    ``DefaultConfigs.WaitTimeDurationWhenLoad`` between polls.

    :param collection_name: Name of the collection being loaded.
    :type  collection_name: str

    :param timeout: Overall polling budget in seconds; ``None`` polls
                    without a deadline. The same value is also forwarded
                    to each progress query.
    :type  timeout: float

    :return: None
    :rtype: NoneType
    """
    start = time.time()

    def can_loop(t) -> bool:
        # Keep polling while the deadline has NOT passed. Using `>` here
        # would be false on the very first check, so the loop body would
        # never run and the method would return without waiting at all.
        return True if timeout is None else t <= (start + timeout)

    while can_loop(time.time()):
        progress = self.get_collection_loading_progress(collection_name, timeout)
        if progress >= 100:
            return
        time.sleep(DefaultConfigs.WaitTimeDurationWhenLoad)
    # NOTE(review): reaching this point means the deadline passed before the
    # progress hit 100; the method still returns None, so callers cannot
    # tell a timeout from success - consider raising here.

@retry_on_rpc_failure(retry_times=10, wait=1)
Expand Down
5 changes: 3 additions & 2 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,8 +644,9 @@ def get_index_state_request(cls, collection_name, field_name, **kwargs):
index_name=kwargs.get("index_name", DefaultConfigs.IndexName))

@classmethod
def load_collection(cls, db_name, collection_name):
return milvus_types.LoadCollectionRequest(db_name=db_name, collection_name=collection_name)
def load_collection(cls, db_name, collection_name, replica_number):
return milvus_types.LoadCollectionRequest(db_name=db_name, collection_name=collection_name,
replica_number=replica_number)

@classmethod
def release_collection(cls, db_name, collection_name):
Expand Down
5 changes: 4 additions & 1 deletion pymilvus/client/stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def describe_collection(self, collection_name, timeout=None):
with self._connection() as handler:
return handler.describe_collection(collection_name, timeout)

def load_collection(self, collection_name, timeout=None, **kwargs):
def load_collection(self, collection_name, timeout=None, replica_number=1, **kwargs):
"""
Loads a specified collection from disk to memory.
Expand All @@ -175,6 +175,9 @@ def load_collection(self, collection_name, timeout=None, **kwargs):
is set to None, client waits until server response or error occur.
:type timeout: float
:param replica_number: Number of replication in memory to load
:type replica_number: int
:return: None
:rtype: NoneType
Expand Down
Loading

0 comments on commit dcb5224

Please sign in to comment.