feat: capture shard failures in the head runtime #5338

Merged: 18 commits, Nov 8, 2022
Changes from 7 commits
92 changes: 17 additions & 75 deletions jina/serve/networking.py
@@ -11,10 +11,9 @@
from grpc_health.v1 import health_pb2, health_pb2_grpc
from grpc_reflection.v1alpha.reflection_pb2 import ServerReflectionRequest
from grpc_reflection.v1alpha.reflection_pb2_grpc import ServerReflectionStub

from jina import __default_endpoint__
from jina.enums import PollingType
from jina.excepts import EstablishGrpcConnectionError
from jina.excepts import EstablishGrpcConnectionError, InternalNetworkError
from jina.importer import ImportExtensions
from jina.logging.logger import JinaLogger
from jina.proto import jina_pb2, jina_pb2_grpc
@@ -26,7 +25,7 @@

from typing import TYPE_CHECKING

if TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING: # pragma: no cover
from grpc.aio._interceptor import ClientInterceptor
from opentelemetry.instrumentation.grpc._client import (
OpenTelemetryClientInterceptor,
@@ -739,42 +738,6 @@ def __init__(
)
self._deployment_address_map = {}

def send_request(
self,
request: Request,
deployment: str,
head: bool = False,
shard_id: Optional[int] = None,
polling_type: PollingType = PollingType.ANY,
endpoint: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
timeout: Optional[float] = None,
retries: Optional[int] = -1,
) -> List[asyncio.Task]:
"""Send a single message to target via one or all of the pooled connections, depending on polling_type. Convenience function wrapper around send_request.
:param request: a single request to send
:param deployment: name of the Jina deployment to send the message to
:param head: If True it is send to the head, otherwise to the worker pods
:param shard_id: Send to a specific shard of the deployment, ignored for polling ALL
:param polling_type: defines if the message should be send to any or all pooled connections for the target
:param endpoint: endpoint to target with the request
:param metadata: metadata to send with the request
:param timeout: timeout for sending the requests
:param retries: number of retries per gRPC call. If <0 it defaults to max(3, num_replicas)
:return: list of asyncio.Task items for each send call
"""
return self.send_requests(
requests=[request],
deployment=deployment,
head=head,
shard_id=shard_id,
polling_type=polling_type,
endpoint=endpoint,
metadata=metadata,
timeout=timeout,
retries=retries,
)

def send_requests(
self,
requests: List[Request],
@@ -856,36 +819,6 @@ def send_discover_endpoint(
)
return None

def send_request_once(
self,
request: Request,
deployment: str,
metadata: Optional[Dict[str, str]] = None,
head: bool = False,
shard_id: Optional[int] = None,
timeout: Optional[float] = None,
retries: Optional[int] = -1,
) -> asyncio.Task:
"""Send msg to target via only one of the pooled connections
:param request: request to send
:param deployment: name of the Jina deployment to send the message to
:param metadata: metadata to send with the request
:param head: If True it is send to the head, otherwise to the worker pods
:param shard_id: Send to a specific shard of the deployment, ignored for polling ALL
:param timeout: timeout for sending the requests
:param retries: number of retries per gRPC call. If <0 it defaults to max(3, num_replicas)
:return: asyncio.Task representing the send call
"""
return self.send_requests_once(
[request],
deployment=deployment,
metadata=metadata,
head=head,
shard_id=shard_id,
timeout=timeout,
retries=retries,
)

def send_requests_once(
self,
requests: List[Request],
Expand All @@ -911,14 +844,18 @@ def send_requests_once(
"""
replicas = self._connections.get_replicas(deployment, head, shard_id)
if replicas:
return self._send_requests(
result = self._send_requests(
requests,
replicas,
endpoint=endpoint,
metadata=metadata,
timeout=timeout,
retries=retries,
)
if isinstance(result, (AioRpcError, InternalNetworkError)):
raise result
else:
return result
else:
self._logger.debug(
f'no available connections for deployment {deployment} and shard {shard_id}'
@@ -988,7 +925,7 @@ async def _handle_aiorpcerror(
total_num_tries: int = 1, # number of retries + 1
current_address: str = '', # the specific address that was contacted during this attempt
connection_list: Optional[ReplicaList] = None,
):
) -> Optional[Union[AioRpcError, InternalNetworkError]]:
# connection failures, cancelled requests, and timed out requests should be retried
# all other cases should not be retried and will be raised immediately
# connection failures have the code grpc.StatusCode.UNAVAILABLE
@@ -1001,7 +938,7 @@
and error.code() != grpc.StatusCode.CANCELLED
and error.code() != grpc.StatusCode.DEADLINE_EXCEEDED
):
raise
return error
elif (
error.code() == grpc.StatusCode.UNAVAILABLE
or error.code() == grpc.StatusCode.DEADLINE_EXCEEDED
@@ -1014,7 +951,7 @@
if connection_list:
await connection_list.reset_connection(current_address)

raise InternalNetworkError(
return InternalNetworkError(
og_exception=error,
request_id=request_id,
dest_addr=tried_addresses,
@@ -1025,6 +962,7 @@
f'GRPC call failed with code {error.code()}, retry attempt {retry_i + 1}/{total_num_tries - 1}.'
f' Trying next replica, if available.'
)
return None

def _send_requests(
self,
@@ -1067,7 +1005,7 @@ async def task_wrapper():
timeout=timeout,
)
except AioRpcError as e:
await self._handle_aiorpcerror(
error = await self._handle_aiorpcerror(
error=e,
retry_i=i,
request_id=requests[0].request_id,
@@ -1076,6 +1014,8 @@
current_address=current_connection.address,
connection_list=connections,
)
if error:
return error

return asyncio.create_task(task_wrapper())

@@ -1110,14 +1050,16 @@ async def task_wrapper():
timeout=timeout,
)
except AioRpcError as e:
await self._handle_aiorpcerror(
error = await self._handle_aiorpcerror(
error=e,
retry_i=i,
tried_addresses=tried_addresses,
current_address=connection.address,
connection_list=connection_list,
total_num_tries=total_num_tries,
)
if error:
raise error
except AttributeError:
return default_endpoints_proto, None

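The change above inverts the error handling in the connection pool: _handle_aiorpcerror now logs retryable failures and returns None so the next replica can be tried, and hands terminal failures back as AioRpcError or InternalNetworkError objects instead of raising them inside the coroutine, leaving callers such as send_requests_once to decide whether to re-raise. A minimal sketch of that caller-side contract, with made-up names (FakeRpcError, handle_error, send_once) standing in for the real gRPC plumbing:

import asyncio
from typing import Optional, Tuple, Union


class FakeRpcError(Exception):
    """Stand-in for grpc.aio.AioRpcError in this sketch."""


async def handle_error(error: Exception, retryable: bool) -> Optional[Exception]:
    # Mirrors the new contract: retryable failures are swallowed (None is
    # returned so the next replica can be tried); terminal failures are
    # returned to the caller rather than raised inside the coroutine.
    return None if retryable else error


async def send_once(payload: str, fail: bool) -> Union[Tuple[str, dict], Exception]:
    try:
        if fail:
            raise FakeRpcError('UNAVAILABLE')
        return f'response for {payload}', {}
    except FakeRpcError as e:
        error = await handle_error(e, retryable=False)
        if error:
            return error  # handed back; the caller decides whether to raise


async def main():
    ok = await send_once('doc-1', fail=False)
    assert isinstance(ok, tuple)
    failed = await send_once('doc-2', fail=True)
    if isinstance(failed, Exception):
        print(f'caller may now raise or tolerate: {failed!r}')


asyncio.run(main())

In the real pool, send_requests_once performs the same isinstance check against (AioRpcError, InternalNetworkError) and re-raises, while the head runtime below chooses to tolerate a subset of failed shards.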
11 changes: 9 additions & 2 deletions jina/serve/runtimes/gateway/graph/topology_graph.py
@@ -6,12 +6,15 @@
from typing import Dict, List, Optional, Tuple

import grpc.aio
from grpc.aio import AioRpcError

from jina import __default_endpoint__
from jina.excepts import InternalNetworkError
from jina.serve.networking import GrpcConnectionPool
from jina.serve.runtimes.helper import _parse_specific_params
from jina.serve.runtimes.request_handlers.worker_request_handler import WorkerRequestHandler
from jina.serve.runtimes.request_handlers.worker_request_handler import (
WorkerRequestHandler,
)
from jina.types.request.data import DataRequest


@@ -157,7 +160,7 @@ async def _wait_previous_and_send(
return request, metadata
# otherwise, send to executor and get response
try:
resp, metadata = await connection_pool.send_requests_once(
result = await connection_pool.send_requests_once(
requests=self.parts_to_send,
deployment=self.name,
metadata=self._metadata,
@@ -166,6 +169,10 @@
timeout=self._timeout_send,
retries=self._retries,
)
if isinstance(result, (AioRpcError, InternalNetworkError)):
raise result
else:
resp, metadata = result
if WorkerRequestHandler._KEY_RESULT in resp.parameters:
# Accumulate results from each Node and then add them to the original
self.result_in_params_returned = resp.parameters[
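The gateway node above simply re-raises the returned error, but returning errors from shard tasks, rather than raising inside them, is what lets the head runtime (next file) gather every shard's outcome and only fail when it has to. A rough standalone illustration of that gather-and-filter pattern; query_shard and the shard ids are invented for the sketch and have no counterpart in Jina:

import asyncio
from typing import Tuple, Union


async def query_shard(shard_id: int) -> Union[Tuple[str, dict], Exception]:
    # Each shard task returns either a (response, metadata) tuple or an error
    # object; nothing is raised, so one bad shard cannot poison the gather.
    if shard_id == 1:
        return RuntimeError(f'shard {shard_id} unreachable')
    return f'docs from shard {shard_id}', {'shard': str(shard_id)}


async def main():
    tasks = [asyncio.create_task(query_shard(i)) for i in range(3)]
    all_results = await asyncio.gather(*tasks)

    results = [r for r in all_results if isinstance(r, tuple)]
    errors = [r for r in all_results if isinstance(r, Exception)]

    print(f'{len(errors)} out of {len(tasks)} shards failed.')
    if not results:
        raise errors[0]  # only fail hard when every shard failed


asyncio.run(main())

The head applies the same filtering to the tasks created by send_requests and raises the first collected error only if no shard returned a response.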
68 changes: 51 additions & 17 deletions jina/serve/runtimes/head/__init__.py
@@ -1,16 +1,15 @@
import argparse
import asyncio
import contextlib
import json
import os
from abc import ABC
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import grpc
from grpc.aio import AioRpcError
from grpc_health.v1 import health, health_pb2, health_pb2_grpc
from grpc_reflection.v1alpha import reflection

from jina.enums import PollingType
from jina.excepts import InternalNetworkError
from jina.helper import get_full_version
@@ -20,7 +19,9 @@
from jina.serve.networking import GrpcConnectionPool
from jina.serve.runtimes.asyncio import AsyncNewLoopRuntime
from jina.serve.runtimes.helper import _get_grpc_server_options
from jina.serve.runtimes.request_handlers.worker_request_handler import WorkerRequestHandler
from jina.serve.runtimes.request_handlers.worker_request_handler import (
WorkerRequestHandler,
)
from jina.types.request.data import DataRequest, Response


@@ -311,16 +312,18 @@ async def _handle_data_request(

uses_before_metadata = None
if self.uses_before_address:
(
response,
uses_before_metadata,
) = await self.connection_pool.send_requests_once(
uses_before_result = await self.connection_pool.send_requests_once(
requests,
deployment='uses_before',
timeout=self.timeout_send,
retries=self._retries,
)
requests = [response]

if isinstance(uses_before_result, (AioRpcError, InternalNetworkError)):
raise uses_before_result
else:
(response, uses_before_metadata) = uses_before_result
requests = [response]

worker_send_tasks = self.connection_pool.send_requests(
requests=requests,
@@ -329,10 +332,26 @@
timeout=self.timeout_send,
retries=self._retries,
)
total_shards = len(worker_send_tasks)

worker_results = await asyncio.gather(*worker_send_tasks)
all_worker_results = await asyncio.gather(*worker_send_tasks)
worker_results = list(
filter(lambda x: isinstance(x, Tuple), all_worker_results)
)
exceptions = list(
filter(
lambda x: isinstance(x, (AioRpcError, InternalNetworkError)),
all_worker_results,
)
)
failed_shards = len(exceptions)
if failed_shards:
self.logger.warning(f'{failed_shards} shards out of {total_shards} failed.')

if len(worker_results) == 0:
if exceptions:
# raise the underlying error first
raise exceptions[0]
raise RuntimeError(
f'Head {self.name} did not receive a response when sending message to worker pods'
)
@@ -342,31 +361,43 @@
response_request = worker_results[0]
uses_after_metadata = None
if self.uses_after_address:
(
response_request,
uses_after_metadata,
) = await self.connection_pool.send_requests_once(
uses_after_result = await self.connection_pool.send_requests_once(
worker_results,
deployment='uses_after',
timeout=self.timeout_send,
retries=self._retries,
)
if isinstance(uses_after_result, (AioRpcError, InternalNetworkError)):
raise uses_after_result
else:
(response_request, uses_after_metadata) = uses_after_result
elif len(worker_results) > 1 and self._reduce:
WorkerRequestHandler.reduce_requests(worker_results)
response_request = WorkerRequestHandler.reduce_requests(worker_results)
elif len(worker_results) > 1 and not self._reduce:
# worker returned multiple responsed, but the head is configured to skip reduction
# worker returned multiple responses, but the head is configured to skip reduction
# just concatenate the docs in this case
response_request.data.docs = WorkerRequestHandler.get_docs_from_request(
requests, field='docs'
)

merged_metadata = self._merge_metadata(
metadata, uses_after_metadata, uses_before_metadata
metadata,
uses_after_metadata,
uses_before_metadata,
total_shards,
failed_shards,
)

return response_request, merged_metadata

def _merge_metadata(self, metadata, uses_after_metadata, uses_before_metadata):
def _merge_metadata(
self,
metadata,
uses_after_metadata,
uses_before_metadata,
total_shards,
failed_shards,
):
merged_metadata = {}
if uses_before_metadata:
for key, value in uses_before_metadata:
@@ -377,6 +408,9 @@ def _merge_metadata(self, metadata, uses_after_metadata, uses_before_metadata):
if uses_after_metadata:
for key, value in uses_after_metadata:
merged_metadata[key] = value

merged_metadata['total_shards'] = str(total_shards)
merged_metadata['failed_shards'] = str(failed_shards)
return merged_metadata

async def _status(self, empty, context) -> jina_pb2.JinaInfoProto:
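Besides tolerating failures, the head now reports shard health upstream: _merge_metadata appends total_shards and failed_shards as string values, since gRPC metadata values are strings. A small sketch of how a consumer might read those counters back from the merged metadata; parse_shard_stats is a hypothetical helper, and the metadata literal below is only an example shape:

from typing import Dict, Sequence, Tuple


def parse_shard_stats(metadata: Sequence[Tuple[str, str]]) -> Dict[str, int]:
    # Metadata arrives as (key, value) string pairs; convert the shard
    # counters back to integers, defaulting to 0 when absent.
    as_dict = dict(metadata)
    return {
        'total_shards': int(as_dict.get('total_shards', '0')),
        'failed_shards': int(as_dict.get('failed_shards', '0')),
    }


merged_metadata = [('total_shards', '3'), ('failed_shards', '1')]
stats = parse_shard_stats(merged_metadata)
assert stats['failed_shards'] == 1
print(f"{stats['failed_shards']}/{stats['total_shards']} shards failed")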
@@ -432,7 +432,7 @@ def reduce(docs_matrix: List['DocumentArray']) -> Optional['DocumentArray']:
@staticmethod
def reduce_requests(requests: List['DataRequest']) -> 'DataRequest':
"""
Reduces a list of requests containing DocumentArrays inton one request object. Changes are applied to the first
Reduces a list of requests containing DocumentArrays into one request object. Changes are applied to the first
request object in-place.

Reduction consists in reducing every DocumentArray in `requests` sequentially using