diff --git a/chromadb/execution/executor/distributed.py b/chromadb/execution/executor/distributed.py
index 0d91c882c043..a2b9359a08bb 100644
--- a/chromadb/execution/executor/distributed.py
+++ b/chromadb/execution/executor/distributed.py
@@ -1,6 +1,6 @@
 import threading
 import random
-from typing import Dict, Optional
+from typing import Callable, Dict, List, Optional, TypeVar
 import grpc
 from overrides import overrides
 from chromadb.api.types import GetResult, Metadata, QueryResult
@@ -12,6 +12,14 @@
 from chromadb.proto.query_executor_pb2_grpc import QueryExecutorStub
 from chromadb.segment.impl.manager.distributed import DistributedSegmentManager
 from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
+from tenacity import (
+    RetryCallState,
+    Retrying,
+    stop_after_attempt,
+    wait_exponential_jitter,
+    retry_if_exception,
+)
+from opentelemetry.trace import Span
 
 
 def _clean_metadata(metadata: Optional[Metadata]) -> Optional[Metadata]:
@@ -35,6 +43,10 @@ def _uri(metadata: Optional[Metadata]) -> Optional[str]:
     return None
 
 
+I = TypeVar("I")  # noqa: E741
+O = TypeVar("O")  # noqa: E741
+
+
 class DistributedExecutor(Executor):
     _mtx: threading.Lock
     _grpc_stub_pool: Dict[str, QueryExecutorStub]
@@ -45,31 +57,68 @@ class DistributedExecutor(Executor):
     def __init__(self, system: System):
         super().__init__(system)
         self._mtx = threading.Lock()
-        self._grpc_stub_pool = dict()
+        self._grpc_stub_pool = {}
         self._manager = self.require(DistributedSegmentManager)
         self._request_timeout_seconds = system.settings.require(
             "chroma_query_request_timeout_seconds"
         )
         self._query_replication_factor = system.settings.require(
-            "chroma_query_replication"
+            "chroma_query_replication_factor"
         )
 
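+    # NOTE: round-robin retry treats `funcs` as one callable per replica
+    # endpoint. Retryable failures (UNAVAILABLE/UNKNOWN) advance to the next
+    # replica, so up to five attempts are spread across all replicas, with
+    # exponential backoff and jitter between attempts.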
+    def _round_robin_retry(self, funcs: List[Callable[[I], O]], args: I) -> O:
+        attempt_count = 0
+        sleep_span: Optional[Span] = None
+
+        def before_sleep(_: RetryCallState) -> None:
+            # HACK(hammadb) 1/14/2024 - this is a hack to avoid the fact that tracer is not yet available and there are boot order issues
+            # This should really use our component system to get the tracer. Since our grpc utils use this pattern
+            # we are copying it here. This should be removed once we have a better way to get the tracer
+            from chromadb.telemetry.opentelemetry import tracer
+
+            nonlocal sleep_span
+            if tracer is not None:
+                sleep_span = tracer.start_span("Waiting to retry RPC")
+
+        for attempt in Retrying(
+            stop=stop_after_attempt(5),
+            wait=wait_exponential_jitter(0.1, jitter=0.1),
+            reraise=True,
+            retry=retry_if_exception(
+                lambda x: isinstance(x, grpc.RpcError)
+                and x.code() in [grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNKNOWN]
+            ),
+            before_sleep=before_sleep,
+        ):
+            if sleep_span is not None:
+                sleep_span.end()
+                sleep_span = None
+
+            with attempt:
+                return funcs[attempt_count % len(funcs)](args)
+            attempt_count += 1
+
+        # NOTE(hammadb) because Retrying() will always either return or raise an exception, this line should never be reached
+        raise Exception("Unreachable code error - should never reach here")
+
     @overrides
     def count(self, plan: CountPlan) -> int:
-        executor = self._grpc_executor_stub(plan.scan)
-        try:
-            count_result = executor.Count(convert.to_proto_count_plan(plan))
-        except grpc.RpcError as rpc_error:
-            raise rpc_error
+        endpoints = self._get_grpc_endpoints(plan.scan)
+        count_funcs = [self._get_stub(endpoint).Count for endpoint in endpoints]
+        count_result = self._round_robin_retry(
+            count_funcs, convert.to_proto_count_plan(plan)
+        )
         return convert.from_proto_count_result(count_result)
 
     @overrides
     def get(self, plan: GetPlan) -> GetResult:
-        executor = self._grpc_executor_stub(plan.scan)
-        try:
-            get_result = executor.Get(convert.to_proto_get_plan(plan))
-        except grpc.RpcError as rpc_error:
-            raise rpc_error
+        endpoints = self._get_grpc_endpoints(plan.scan)
+        get_funcs = [self._get_stub(endpoint).Get for endpoint in endpoints]
+        get_result = self._round_robin_retry(get_funcs, convert.to_proto_get_plan(plan))
 
         records = convert.from_proto_get_result(get_result)
         ids = [record["id"] for record in records]
@@ -107,11 +156,9 @@ def get(self, plan: GetPlan) -> GetResult:
 
     @overrides
     def knn(self, plan: KNNPlan) -> QueryResult:
-        executor = self._grpc_executor_stub(plan.scan)
-        try:
-            knn_result = executor.KNN(convert.to_proto_knn_plan(plan))
-        except grpc.RpcError as rpc_error:
-            raise rpc_error
+        endpoints = self._get_grpc_endpoints(plan.scan)
+        knn_funcs = [self._get_stub(endpoint).KNN for endpoint in endpoints]
+        knn_result = self._round_robin_retry(knn_funcs, convert.to_proto_knn_plan(plan))
 
         results = convert.from_proto_knn_batch_result(knn_result)
         ids = [[record["record"]["id"] for record in records] for records in results]
@@ -165,17 +212,25 @@ def knn(self, plan: KNNPlan) -> QueryResult:
             included=plan.projection.included,
         )
 
-    def _grpc_executor_stub(self, scan: Scan) -> QueryExecutorStub:
+    def _get_grpc_endpoints(self, scan: Scan) -> List[str]:
         # Since the grpc endpoint is determined by the collection uuid,
         # the endpoint should be the same for all segments of the same collection
-        grpc_url = self._manager.get_endpoint(scan.record, self._query_replication_factor)
-        with self._mtx:
-            if grpc_url not in self._grpc_stub_pool:
-                channel = grpc.insecure_channel(
-                    grpc_url, options=[("grpc.max_concurrent_streams", 1000)]
-                )
-                interceptors = [OtelInterceptor()]
-                channel = grpc.intercept_channel(channel, *interceptors)
-                self._grpc_stub_pool[grpc_url] = QueryExecutorStub(channel)
-
-        return self._grpc_stub_pool[grpc_url]
+        grpc_urls = self._manager.get_endpoints(
+            scan.record, self._query_replication_factor
+        )
+        # Shuffle the grpc urls to distribute the load evenly
+        random.shuffle(grpc_urls)
+        return grpc_urls
+
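+    # NOTE: stubs are cached per endpoint so each channel is created once and
+    # then reused; gRPC channels and stubs are safe to share across threads.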
+    def _get_stub(self, grpc_url: str) -> QueryExecutorStub:
+        with self._mtx:
+            if grpc_url not in self._grpc_stub_pool:
+                channel = grpc.insecure_channel(
+                    grpc_url, options=[("grpc.max_concurrent_streams", 1000)]
+                )
+                interceptors = [OtelInterceptor()]
+                channel = grpc.intercept_channel(channel, *interceptors)
+                self._grpc_stub_pool[grpc_url] = QueryExecutorStub(channel)  # type: ignore[no-untyped-call]
+            return self._grpc_stub_pool[grpc_url]
diff --git a/chromadb/segment/impl/distributed/segment_directory.py b/chromadb/segment/impl/distributed/segment_directory.py
index 97b4b13c9ca0..1366052e3748 100644
--- a/chromadb/segment/impl/distributed/segment_directory.py
+++ b/chromadb/segment/impl/distributed/segment_directory.py
@@ -270,10 +270,11 @@ def get_segment_endpoints(self, segment: Segment, n: int) -> List[str]:
         # We want to route using the node name over the member id
         # because the node may have a disk cache that we want a
         # stable identifier for over deploys.
-        can_use_node_routing = all(
-            [m.node != "" and len(m.node) != 0 for m in self._curr_memberlist]
+        can_use_node_routing = (
+            all([m.node != "" and len(m.node) != 0 for m in self._curr_memberlist])
+            and self._routing_mode == RoutingMode.NODE
         )
-        if can_use_node_routing and self._routing_mode == RoutingMode.NODE:
+        if can_use_node_routing:
             # If we are using node routing and the segments
             assignments = assign(
                 segment["collection"].hex,
@@ -289,7 +290,6 @@
                 murmur3hasher,
                 n,
             )
-            assignments_set = set(assignments)
             out_endpoints = []
             for member in self._curr_memberlist: