From 834bffe6b72d2802241e800b5b6a4396696b6012 Mon Sep 17 00:00:00 2001 From: Girish Chandrashekar Date: Thu, 3 Nov 2022 15:37:19 +0100 Subject: [PATCH] fix: apply review suggestions --- jina/serve/networking.py | 2 +- jina/serve/runtimes/head/__init__.py | 54 ++++++++++++++++------------ 2 files changed, 33 insertions(+), 23 deletions(-) diff --git a/jina/serve/networking.py b/jina/serve/networking.py index 717e79fa33efc..a1496974d2e02 100644 --- a/jina/serve/networking.py +++ b/jina/serve/networking.py @@ -972,7 +972,7 @@ def _send_requests( metadata: Optional[Dict[str, str]] = None, timeout: Optional[float] = None, retries: Optional[int] = -1, - ) -> asyncio.Task: + ) -> asyncio.Task[Union[Tuple, AioRpcError, InternalNetworkError]]: # this wraps the awaitable object from grpc as a coroutine so it can be used as a task # the grpc call function is not a coroutine but some _AioCall diff --git a/jina/serve/runtimes/head/__init__.py b/jina/serve/runtimes/head/__init__.py index c95b9c277ca79..224c58336cb52 100644 --- a/jina/serve/runtimes/head/__init__.py +++ b/jina/serve/runtimes/head/__init__.py @@ -303,6 +303,32 @@ async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto: return response + async def _gather_worker_tasks(self, requests, endpoint): + worker_send_tasks = self.connection_pool.send_requests( + requests=requests, + deployment=self._deployment_name, + polling_type=self._polling[endpoint], + timeout=self.timeout_send, + retries=self._retries, + ) + + all_worker_results = await asyncio.gather(*worker_send_tasks) + worker_results = list( + filter(lambda x: isinstance(x, Tuple), all_worker_results) + ) + exceptions = list( + filter( + lambda x: isinstance(x, (AioRpcError, InternalNetworkError)), + all_worker_results, + ) + ) + total_shards = len(worker_send_tasks) + failed_shards = len(exceptions) + if failed_shards: + self.logger.warning(f'{failed_shards} shards out of {total_shards} failed.') + + return worker_results, exceptions, total_shards, failed_shards + async def _handle_data_request( self, requests: List[DataRequest], endpoint: Optional[str] ) -> Tuple[DataRequest, Dict]: @@ -325,28 +351,12 @@ async def _handle_data_request( (response, uses_before_metadata) = uses_before_result requests = [response] - worker_send_tasks = self.connection_pool.send_requests( - requests=requests, - deployment=self._deployment_name, - polling_type=self._polling[endpoint], - timeout=self.timeout_send, - retries=self._retries, - ) - total_shards = len(worker_send_tasks) - - all_worker_results = await asyncio.gather(*worker_send_tasks) - worker_results = list( - filter(lambda x: isinstance(x, Tuple), all_worker_results) - ) - exceptions = list( - filter( - lambda x: isinstance(x, (AioRpcError, InternalNetworkError)), - all_worker_results, - ) - ) - failed_shards = len(exceptions) - if failed_shards: - self.logger.warning(f'{failed_shards} shards out of {total_shards} failed.') + ( + worker_results, + exceptions, + total_shards, + failed_shards, + ) = await self._gather_worker_tasks(requests, endpoint) if len(worker_results) == 0: if exceptions: