Skip to content

Commit

Permalink
Add retry round robin
Browse files Browse the repository at this point in the history
  • Loading branch information
HammadB committed Jan 15, 2025
1 parent 317e629 commit 7fadd8c
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 34 deletions.
106 changes: 76 additions & 30 deletions chromadb/execution/executor/distributed.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import threading
import random
from typing import Dict, Optional
from typing import Callable, Dict, List, Optional, TypeVar
import grpc
from overrides import overrides
from chromadb.api.types import GetResult, Metadata, QueryResult
Expand All @@ -12,6 +12,14 @@
from chromadb.proto.query_executor_pb2_grpc import QueryExecutorStub
from chromadb.segment.impl.manager.distributed import DistributedSegmentManager
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
from tenacity import (
RetryCallState,
Retrying,
stop_after_attempt,
wait_exponential_jitter,
retry_if_exception,
)
from opentelemetry.trace import Span


def _clean_metadata(metadata: Optional[Metadata]) -> Optional[Metadata]:
Expand All @@ -35,6 +43,10 @@ def _uri(metadata: Optional[Metadata]) -> Optional[str]:
return None


# Generic input/output type parameters for _round_robin_retry: each retried
# callable takes an I (a proto plan) and returns an O (a proto result).
I = TypeVar("I")  # noqa: E741
O = TypeVar("O")  # noqa: E741


class DistributedExecutor(Executor):
_mtx: threading.Lock
_grpc_stub_pool: Dict[str, QueryExecutorStub]
Expand All @@ -45,31 +57,64 @@ class DistributedExecutor(Executor):
def __init__(self, system: System):
    """Initialize the distributed executor.

    Args:
        system: The component system, used to resolve the segment manager
            and required settings.
    """
    super().__init__(system)
    # Guards lazy creation of entries in the grpc stub pool.
    self._mtx = threading.Lock()
    # Cache of endpoint URL -> QueryExecutorStub, populated on demand.
    self._grpc_stub_pool = {}
    self._manager = self.require(DistributedSegmentManager)
    self._request_timeout_seconds = system.settings.require(
        "chroma_query_request_timeout_seconds"
    )
    # Number of replicas eligible to serve each query (round-robin targets).
    self._query_replication_factor = system.settings.require(
        "chroma_query_replication_factor"
    )

def _round_robin_retry(self, funcs: List[Callable[[I], O]], args: I) -> O:
    """Invoke callables in round-robin order with retries until one succeeds.

    Starts with funcs[0]; on a retryable gRPC failure (UNAVAILABLE or
    UNKNOWN status) advances to the next callable, sleeping with
    exponential backoff plus jitter between attempts, for at most 5
    attempts total. Wraps around if there are fewer callables than
    attempts. Non-retryable exceptions propagate immediately.

    Args:
        funcs: Callables to try — typically the same RPC method bound to
            stubs for different replica endpoints.
        args: The single argument passed to each callable.

    Returns:
        The result of the first successful callable.

    Raises:
        The last underlying exception if all attempts fail (reraise=True).
    """
    attempt_count = 0
    # Tracing span covering the backoff sleep between attempts.
    sleep_span: Optional[Span] = None

    def before_sleep(_: RetryCallState) -> None:
        # HACK(hammadb) 1/14/2025 - this is a hack to avoid the fact that tracer is not yet available and there are boot order issues
        # This should really use our component system to get the tracer. Since our grpc utils use this pattern
        # we are copying it here. This should be removed once we have a better way to get the tracer
        from chromadb.telemetry.opentelemetry import tracer

        nonlocal sleep_span
        if tracer is not None:
            sleep_span = tracer.start_span("Waiting to retry RPC")

    for attempt in Retrying(
        stop=stop_after_attempt(5),
        wait=wait_exponential_jitter(0.1, jitter=0.1),
        reraise=True,
        retry=retry_if_exception(
            lambda x: isinstance(x, grpc.RpcError)
            and x.code() in [grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNKNOWN]
        ),
        before_sleep=before_sleep,
    ):
        # Close the span opened before the preceding backoff sleep, if any.
        if sleep_span is not None:
            sleep_span.end()
            sleep_span = None

        with attempt:
            # On success this returns directly; on a retryable failure the
            # `attempt` context swallows the exception and we fall through.
            return funcs[attempt_count % len(funcs)](args)
        # Only reached when the attempt failed: rotate to the next endpoint.
        attempt_count += 1

    # NOTE(hammadb) because Retrying() will always either return or raise an exception, this line should never be reached
    raise Exception("Unreachable code error - should never reach here")

@overrides
def count(self, plan: CountPlan) -> int:
    """Execute a count plan against the distributed query executors.

    Builds one Count RPC callable per candidate endpoint and delegates to
    _round_robin_retry, which rotates across replicas on retryable
    failures.

    Args:
        plan: The count plan to execute.

    Returns:
        The number of matching records.
    """
    endpoints = self._get_grpc_endpoints(plan.scan)
    count_funcs = [self._get_stub(endpoint).Count for endpoint in endpoints]
    count_result = self._round_robin_retry(
        count_funcs, convert.to_proto_count_plan(plan)
    )
    return convert.from_proto_count_result(count_result)

@overrides
def get(self, plan: GetPlan) -> GetResult:
executor = self._grpc_executor_stub(plan.scan)
try:
get_result = executor.Get(convert.to_proto_get_plan(plan))
except grpc.RpcError as rpc_error:
raise rpc_error
endpoints = self._get_grpc_endpoints(plan.scan)
get_funcs = [self._get_stub(endpoint).Get for endpoint in endpoints]
get_result = self._round_robin_retry(get_funcs, convert.to_proto_get_plan(plan))
records = convert.from_proto_get_result(get_result)

ids = [record["id"] for record in records]
Expand Down Expand Up @@ -107,11 +152,9 @@ def get(self, plan: GetPlan) -> GetResult:

@overrides
def knn(self, plan: KNNPlan) -> QueryResult:
executor = self._grpc_executor_stub(plan.scan)
try:
knn_result = executor.KNN(convert.to_proto_knn_plan(plan))
except grpc.RpcError as rpc_error:
raise rpc_error
endpoints = self._get_grpc_endpoints(plan.scan)
knn_funcs = [self._get_stub(endpoint).KNN for endpoint in endpoints]
knn_result = self._round_robin_retry(knn_funcs, convert.to_proto_knn_plan(plan))
results = convert.from_proto_knn_batch_result(knn_result)

ids = [[record["record"]["id"] for record in records] for records in results]
Expand Down Expand Up @@ -165,17 +208,20 @@ def knn(self, plan: KNNPlan) -> QueryResult:
included=plan.projection.included,
)

def _get_grpc_endpoints(self, scan: Scan) -> List[str]:
    """Return candidate query-executor endpoints for a scan, shuffled.

    Since the grpc endpoint is determined by the collection uuid, the
    endpoints are the same for all segments of the same collection.

    Args:
        scan: The scan whose record segment determines endpoint assignment.

    Returns:
        A list of up to `_query_replication_factor` grpc endpoint URLs, in
        random order so repeated queries spread load across replicas.
    """
    grpc_urls = self._manager.get_endpoints(
        scan.record, self._query_replication_factor
    )
    # Shuffle the grpc urls to distribute the load evenly
    random.shuffle(grpc_urls)
    return grpc_urls

def _get_stub(self, grpc_url: str) -> QueryExecutorStub:
    """Return a cached QueryExecutorStub for the given endpoint URL.

    Pool access is guarded by the instance mutex (as the replaced
    `_grpc_executor_stub` did): without it, the check-then-insert on
    `_grpc_stub_pool` races when multiple threads resolve the same
    endpoint concurrently, creating duplicate channels and dropping all
    but the last.

    Args:
        grpc_url: The query executor endpoint to connect to.

    Returns:
        A QueryExecutorStub bound to a channel for `grpc_url`.
    """
    with self._mtx:
        if grpc_url not in self._grpc_stub_pool:
            channel = grpc.insecure_channel(
                grpc_url, options=[("grpc.max_concurrent_streams", 1000)]
            )
            interceptors = [OtelInterceptor()]
            channel = grpc.intercept_channel(channel, *interceptors)
            self._grpc_stub_pool[grpc_url] = QueryExecutorStub(channel)  # type: ignore[no-untyped-call]
        return self._grpc_stub_pool[grpc_url]
8 changes: 4 additions & 4 deletions chromadb/segment/impl/distributed/segment_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,10 +270,11 @@ def get_segment_endpoints(self, segment: Segment, n: int) -> List[str]:
# We want to route using the node name over the member id
# because the node may have a disk cache that we want a
# stable identifier for over deploys.
can_use_node_routing = all(
[m.node != "" and len(m.node) != 0 for m in self._curr_memberlist]
can_use_node_routing = (
all([m.node != "" and len(m.node) != 0 for m in self._curr_memberlist])
and self._routing_mode == RoutingMode.NODE
)
if can_use_node_routing and self._routing_mode == RoutingMode.NODE:
if can_use_node_routing:
# If we are using node routing and the segments
assignments = assign(
segment["collection"].hex,
Expand All @@ -289,7 +290,6 @@ def get_segment_endpoints(self, segment: Segment, n: int) -> List[str]:
murmur3hasher,
n,
)

assignments_set = set(assignments)
out_endpoints = []
for member in self._curr_memberlist:
Expand Down

0 comments on commit 7fadd8c

Please sign in to comment.