Skip to content

Commit

Permalink
fix: apply review suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
Girish Chandrashekar committed Nov 3, 2022
1 parent 30fd6d5 commit 834bffe
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 23 deletions.
2 changes: 1 addition & 1 deletion jina/serve/networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,7 @@ def _send_requests(
metadata: Optional[Dict[str, str]] = None,
timeout: Optional[float] = None,
retries: Optional[int] = -1,
) -> asyncio.Task:
) -> asyncio.Task[Union[Tuple, AioRpcError, InternalNetworkError]]:
# this wraps the awaitable object from grpc as a coroutine so it can be used as a task
# the grpc call function is not a coroutine but some _AioCall

Expand Down
54 changes: 32 additions & 22 deletions jina/serve/runtimes/head/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,32 @@ async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto:

return response

async def _gather_worker_tasks(self, requests, endpoint):
worker_send_tasks = self.connection_pool.send_requests(
requests=requests,
deployment=self._deployment_name,
polling_type=self._polling[endpoint],
timeout=self.timeout_send,
retries=self._retries,
)

all_worker_results = await asyncio.gather(*worker_send_tasks)
worker_results = list(
filter(lambda x: isinstance(x, Tuple), all_worker_results)
)
exceptions = list(
filter(
lambda x: isinstance(x, (AioRpcError, InternalNetworkError)),
all_worker_results,
)
)
total_shards = len(worker_send_tasks)
failed_shards = len(exceptions)
if failed_shards:
self.logger.warning(f'{failed_shards} shards out of {total_shards} failed.')

return worker_results, exceptions, total_shards, failed_shards

async def _handle_data_request(
self, requests: List[DataRequest], endpoint: Optional[str]
) -> Tuple[DataRequest, Dict]:
Expand All @@ -325,28 +351,12 @@ async def _handle_data_request(
(response, uses_before_metadata) = uses_before_result
requests = [response]

worker_send_tasks = self.connection_pool.send_requests(
requests=requests,
deployment=self._deployment_name,
polling_type=self._polling[endpoint],
timeout=self.timeout_send,
retries=self._retries,
)
total_shards = len(worker_send_tasks)

all_worker_results = await asyncio.gather(*worker_send_tasks)
worker_results = list(
filter(lambda x: isinstance(x, Tuple), all_worker_results)
)
exceptions = list(
filter(
lambda x: isinstance(x, (AioRpcError, InternalNetworkError)),
all_worker_results,
)
)
failed_shards = len(exceptions)
if failed_shards:
self.logger.warning(f'{failed_shards} shards out of {total_shards} failed.')
(
worker_results,
exceptions,
total_shards,
failed_shards,
) = await self._gather_worker_tasks(requests, endpoint)

if len(worker_results) == 0:
if exceptions:
Expand Down

0 comments on commit 834bffe

Please sign in to comment.