Skip to content

Commit

Permalink
Merge branch 'master' into feat-RR-UUID
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM authored Nov 8, 2022
2 parents 11ebf3c + db1c406 commit 38bec69
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 112 deletions.
4 changes: 3 additions & 1 deletion jina/orchestrate/deployments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,9 @@ def _set_pod_args(args: Namespace) -> Dict[int, List[Namespace]]:
for replica_id in range(replicas):
_args = copy.deepcopy(args)
_args.shard_id = shard_id
_args.pod_role = PodRoleType.WORKER
# for gateway pods, the pod role shouldn't be changed
if _args.pod_role != PodRoleType.GATEWAY:
_args.pod_role = PodRoleType.WORKER

if cuda_device_map:
_args.env['CUDA_VISIBLE_DEVICES'] = str(cuda_device_map[replica_id])
Expand Down
23 changes: 17 additions & 6 deletions jina/orchestrate/pods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

__all__ = ['BasePod', 'Pod']

from jina.serve.runtimes.gateway import GatewayRuntime


def run(
args: 'argparse.Namespace',
Expand Down Expand Up @@ -188,12 +190,21 @@ def _wait_for_ready_or_shutdown(self, timeout: Optional[float]):
:param timeout: The time to wait before readiness or failure is determined
.. # noqa: DAR201
"""
return AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
timeout=timeout,
ready_or_shutdown_event=self.ready_or_shutdown.event,
ctrl_address=self.runtime_ctrl_address,
timeout_ctrl=self._timeout_ctrl,
)
if self.args.pod_role == PodRoleType.GATEWAY:
return GatewayRuntime.wait_for_ready_or_shutdown(
timeout=timeout,
ready_or_shutdown_event=self.ready_or_shutdown.event,
ctrl_address=self.runtime_ctrl_address,
timeout_ctrl=self._timeout_ctrl,
protocol=self.args.protocol,
)
else:
return AsyncNewLoopRuntime.wait_for_ready_or_shutdown(
timeout=timeout,
ready_or_shutdown_event=self.ready_or_shutdown.event,
ctrl_address=self.runtime_ctrl_address,
timeout_ctrl=self._timeout_ctrl,
)

def _fail_start_timeout(self, timeout):
"""
Expand Down
90 changes: 15 additions & 75 deletions jina/serve/networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from jina import __default_endpoint__
from jina.enums import PollingType
from jina.excepts import EstablishGrpcConnectionError
from jina.excepts import EstablishGrpcConnectionError, InternalNetworkError
from jina.importer import ImportExtensions
from jina.logging.logger import JinaLogger
from jina.proto import jina_pb2, jina_pb2_grpc
Expand All @@ -26,7 +26,7 @@

from typing import TYPE_CHECKING

if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING: # pragma: no cover
from grpc.aio._interceptor import ClientInterceptor
from opentelemetry.instrumentation.grpc._client import (
OpenTelemetryClientInterceptor,
Expand Down Expand Up @@ -755,42 +755,6 @@ def __init__(
)
self._deployment_address_map = {}

def send_request(
    self,
    request: Request,
    deployment: str,
    head: bool = False,
    shard_id: Optional[int] = None,
    polling_type: PollingType = PollingType.ANY,
    endpoint: Optional[str] = None,
    metadata: Optional[Dict[str, str]] = None,
    timeout: Optional[float] = None,
    retries: Optional[int] = -1,
) -> List[asyncio.Task]:
    """Send a single request to the target deployment.

    Convenience wrapper around ``send_requests``: the request is wrapped in a
    one-element list and every other argument is forwarded unchanged.

    :param request: a single request to send
    :param deployment: name of the Jina deployment to send the message to
    :param head: If True it is sent to the head, otherwise to the worker pods
    :param shard_id: Send to a specific shard of the deployment, ignored for polling ALL
    :param polling_type: defines if the message should be sent to any or all pooled connections for the target
    :param endpoint: endpoint to target with the request
    :param metadata: metadata to send with the request
    :param timeout: timeout for sending the requests
    :param retries: number of retries per gRPC call. If <0 it defaults to max(3, num_replicas)
    :return: list of asyncio.Task items for each send call
    """
    # Routing and delivery options are passed through untouched; only the
    # request itself needs adapting to the list-based API.
    forwarded_options = dict(
        deployment=deployment,
        head=head,
        shard_id=shard_id,
        polling_type=polling_type,
        endpoint=endpoint,
        metadata=metadata,
        timeout=timeout,
        retries=retries,
    )
    return self.send_requests(requests=[request], **forwarded_options)

def send_requests(
self,
requests: List[Request],
Expand Down Expand Up @@ -872,36 +836,6 @@ def send_discover_endpoint(
)
return None

def send_request_once(
    self,
    request: Request,
    deployment: str,
    metadata: Optional[Dict[str, str]] = None,
    head: bool = False,
    shard_id: Optional[int] = None,
    timeout: Optional[float] = None,
    retries: Optional[int] = -1,
) -> asyncio.Task:
    """Send one request to the target via only one of the pooled connections.

    Thin wrapper around ``send_requests_once``: the request is wrapped in a
    one-element list and the remaining arguments are forwarded as keywords.

    :param request: request to send
    :param deployment: name of the Jina deployment to send the message to
    :param metadata: metadata to send with the request
    :param head: If True it is sent to the head, otherwise to the worker pods
    :param shard_id: Send to a specific shard of the deployment
    :param timeout: timeout for sending the requests
    :param retries: number of retries per gRPC call. If <0 it defaults to max(3, num_replicas)
    :return: asyncio.Task representing the send call
    """
    # The list-based API takes the requests positionally.
    single_request_batch = [request]
    forwarded_options = dict(
        deployment=deployment,
        metadata=metadata,
        head=head,
        shard_id=shard_id,
        timeout=timeout,
        retries=retries,
    )
    return self.send_requests_once(single_request_batch, **forwarded_options)

def send_requests_once(
self,
requests: List[Request],
Expand All @@ -927,14 +861,15 @@ def send_requests_once(
"""
replicas = self._connections.get_replicas(deployment, head, shard_id)
if replicas:
return self._send_requests(
result = self._send_requests(
requests,
replicas,
endpoint=endpoint,
metadata=metadata,
timeout=timeout,
retries=retries,
)
return result
else:
self._logger.debug(
f'no available connections for deployment {deployment} and shard {shard_id}'
Expand Down Expand Up @@ -1005,7 +940,7 @@ async def _handle_aiorpcerror(
current_address: str = '', # the specific address that was contacted during this attempt
current_deployment: str = '', # the specific deployment that was contacted during this attempt
connection_list: Optional[ReplicaList] = None,
):
) -> 'Optional[Union[AioRpcError, InternalNetworkError]]':
# connection failures, cancelled requests, and timed out requests should be retried
# all other cases should not be retried and will be raised immediately
# connection failures have the code grpc.StatusCode.UNAVAILABLE
Expand All @@ -1018,7 +953,7 @@ async def _handle_aiorpcerror(
and error.code() != grpc.StatusCode.CANCELLED
and error.code() != grpc.StatusCode.DEADLINE_EXCEEDED
):
raise
return error
elif (
error.code() == grpc.StatusCode.UNAVAILABLE
or error.code() == grpc.StatusCode.DEADLINE_EXCEEDED
Expand All @@ -1031,7 +966,7 @@ async def _handle_aiorpcerror(
if connection_list:
await connection_list.reset_connection(current_address, current_deployment)

raise InternalNetworkError(
return InternalNetworkError(
og_exception=error,
request_id=request_id,
dest_addr=tried_addresses,
Expand All @@ -1042,6 +977,7 @@ async def _handle_aiorpcerror(
f'GRPC call failed with code {error.code()}, retry attempt {retry_i + 1}/{total_num_tries - 1}.'
f' Trying next replica, if available.'
)
return None

def _send_requests(
self,
Expand All @@ -1051,7 +987,7 @@ def _send_requests(
metadata: Optional[Dict[str, str]] = None,
timeout: Optional[float] = None,
retries: Optional[int] = -1,
) -> asyncio.Task:
) -> 'asyncio.Task[Union[Tuple, AioRpcError, InternalNetworkError]]':
# this wraps the awaitable object from grpc as a coroutine so it can be used as a task
# the grpc call function is not a coroutine but some _AioCall

Expand Down Expand Up @@ -1084,7 +1020,7 @@ async def task_wrapper():
timeout=timeout,
)
except AioRpcError as e:
await self._handle_aiorpcerror(
error = await self._handle_aiorpcerror(
error=e,
retry_i=i,
request_id=requests[0].request_id,
Expand All @@ -1094,6 +1030,8 @@ async def task_wrapper():
current_deployment=current_connection.deployment_name,
connection_list=connections,
)
if error:
return error

return asyncio.create_task(task_wrapper())

Expand Down Expand Up @@ -1128,7 +1066,7 @@ async def task_wrapper():
timeout=timeout,
)
except AioRpcError as e:
await self._handle_aiorpcerror(
error = await self._handle_aiorpcerror(
error=e,
retry_i=i,
tried_addresses=tried_addresses,
Expand All @@ -1137,6 +1075,8 @@ async def task_wrapper():
connection_list=connection_list,
total_num_tries=total_num_tries,
)
if error:
raise error
except AttributeError:
return default_endpoints_proto, None

Expand Down
11 changes: 9 additions & 2 deletions jina/serve/runtimes/gateway/graph/topology_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
from typing import Dict, List, Optional, Tuple

import grpc.aio
from grpc.aio import AioRpcError

from jina import __default_endpoint__
from jina.excepts import InternalNetworkError
from jina.serve.networking import GrpcConnectionPool
from jina.serve.runtimes.helper import _parse_specific_params
from jina.serve.runtimes.request_handlers.worker_request_handler import WorkerRequestHandler
from jina.serve.runtimes.request_handlers.worker_request_handler import (
WorkerRequestHandler,
)
from jina.types.request.data import DataRequest


Expand Down Expand Up @@ -157,7 +160,7 @@ async def _wait_previous_and_send(
return request, metadata
# otherwise, send to executor and get response
try:
resp, metadata = await connection_pool.send_requests_once(
result = await connection_pool.send_requests_once(
requests=self.parts_to_send,
deployment=self.name,
metadata=self._metadata,
Expand All @@ -166,6 +169,10 @@ async def _wait_previous_and_send(
timeout=self._timeout_send,
retries=self._retries,
)
if issubclass(type(result), BaseException):
raise result
else:
resp, metadata = result
if WorkerRequestHandler._KEY_RESULT in resp.parameters:
# Accumulate results from each Node and then add them to the original
self.result_in_params_returned = resp.parameters[
Expand Down
Loading

0 comments on commit 38bec69

Please sign in to comment.