Skip to content

Commit

Permalink
fix: check for result types after the task has completed its execution
Browse files Browse the repository at this point in the history
  • Loading branch information
Girish Chandrashekar committed Nov 7, 2022
1 parent 7500de0 commit 068ebb9
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 15 deletions.
7 changes: 3 additions & 4 deletions jina/serve/networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from grpc_health.v1 import health_pb2, health_pb2_grpc
from grpc_reflection.v1alpha.reflection_pb2 import ServerReflectionRequest
from grpc_reflection.v1alpha.reflection_pb2_grpc import ServerReflectionStub

from jina import __default_endpoint__
from jina.enums import PollingType
from jina.excepts import EstablishGrpcConnectionError, InternalNetworkError
Expand Down Expand Up @@ -852,10 +853,7 @@ def send_requests_once(
timeout=timeout,
retries=retries,
)
if isinstance(result, (AioRpcError, InternalNetworkError)):
raise result
else:
return result
return result
else:
self._logger.debug(
f'no available connections for deployment {deployment} and shard {shard_id}'
Expand Down Expand Up @@ -1015,6 +1013,7 @@ async def task_wrapper():
connection_list=connections,
)
if error:
print(f'--->returning error: {type(error)}')
return error

return asyncio.create_task(task_wrapper())
Expand Down
7 changes: 6 additions & 1 deletion jina/serve/runtimes/gateway/graph/topology_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import grpc.aio
from grpc.aio import AioRpcError

from jina import __default_endpoint__
from jina.excepts import InternalNetworkError
from jina.serve.networking import GrpcConnectionPool
Expand Down Expand Up @@ -159,7 +160,7 @@ async def _wait_previous_and_send(
return request, metadata
# otherwise, send to executor and get response
try:
resp, metadata = await connection_pool.send_requests_once(
result = await connection_pool.send_requests_once(
requests=self.parts_to_send,
deployment=self.name,
metadata=self._metadata,
Expand All @@ -168,6 +169,10 @@ async def _wait_previous_and_send(
timeout=self._timeout_send,
retries=self._retries,
)
if issubclass(type(result), BaseException):
raise result
else:
resp, metadata = result
if WorkerRequestHandler._KEY_RESULT in resp.parameters:
# Accumulate results from each Node and then add them to the original
self.result_in_params_returned = resp.parameters[
Expand Down
23 changes: 13 additions & 10 deletions jina/serve/runtimes/head/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from grpc.aio import AioRpcError
from grpc_health.v1 import health, health_pb2, health_pb2_grpc
from grpc_reflection.v1alpha import reflection

from jina.enums import PollingType
from jina.excepts import InternalNetworkError
from jina.helper import get_full_version
Expand Down Expand Up @@ -318,7 +319,7 @@ async def _gather_worker_tasks(self, requests, endpoint):
)
exceptions = list(
filter(
lambda x: isinstance(x, (AioRpcError, InternalNetworkError)),
lambda x: issubclass(type(x), BaseException),
all_worker_results,
)
)
Expand All @@ -338,16 +339,17 @@ async def _handle_data_request(

uses_before_metadata = None
if self.uses_before_address:
(
response,
uses_before_metadata,
) = await self.connection_pool.send_requests_once(
result = await self.connection_pool.send_requests_once(
requests,
deployment='uses_before',
timeout=self.timeout_send,
retries=self._retries,
)
requests = [response]
if issubclass(type(result), BaseException):
raise result
else:
response, uses_before_metadata = result
requests = [response]

(
worker_results,
Expand All @@ -369,15 +371,16 @@ async def _handle_data_request(
response_request = worker_results[0]
uses_after_metadata = None
if self.uses_after_address:
(
response_request,
uses_after_metadata,
) = await self.connection_pool.send_requests_once(
result = await self.connection_pool.send_requests_once(
worker_results,
deployment='uses_after',
timeout=self.timeout_send,
retries=self._retries,
)
if issubclass(type(result), BaseException):
raise result
else:
response_request, uses_after_metadata = result
elif len(worker_results) > 1 and self._reduce:
response_request = WorkerRequestHandler.reduce_requests(worker_results)
elif len(worker_results) > 1 and not self._reduce:
Expand Down

0 comments on commit 068ebb9

Please sign in to comment.