Make the query replication factor configurable.

* chromadb/config.py: add the `chroma_query_replication_factor` setting
  (default 2).
* chromadb/execution/executor/distributed.py: DistributedExecutor reads that
  setting in __init__ and passes it to `get_endpoint()` in place of the
  hard-coded 3.
* chromadb/segment/impl/distributed/segment_directory.py:
  `get_segment_endpoints()` clamps `n` to the size of the current memberlist
  before assigning endpoints.

Note: the key passed to `system.settings.require()` must be spelled exactly
like the field declared in config.py (`chroma_query_replication_factor`).
The `index` lines are carried over from the original patch and do not
reflect the corrected key in distributed.py.

diff --git a/chromadb/config.py b/chromadb/config.py
index a773c698737..0557fab0abd 100644
--- a/chromadb/config.py
+++ b/chromadb/config.py
@@ -264,6 +264,7 @@ def empty_str_to_none(cls, v: str) -> Optional[str]:
         "chromadb.segment.impl.manager.local.LocalSegmentManager"
     )
     chroma_executor_impl: str = "chromadb.execution.executor.local.LocalExecutor"
+    chroma_query_replication_factor: int = 2
 
     chroma_logservice_host = "localhost"
     chroma_logservice_port = 50052
diff --git a/chromadb/execution/executor/distributed.py b/chromadb/execution/executor/distributed.py
index ebc58721e6b..0d91c882c04 100644
--- a/chromadb/execution/executor/distributed.py
+++ b/chromadb/execution/executor/distributed.py
@@ -40,6 +40,7 @@ class DistributedExecutor(Executor):
     _grpc_stub_pool: Dict[str, QueryExecutorStub]
     _manager: DistributedSegmentManager
     _request_timeout_seconds: int
+    _query_replication_factor: int
 
     def __init__(self, system: System):
         super().__init__(system)
@@ -49,6 +50,9 @@ def __init__(self, system: System):
         self._request_timeout_seconds = system.settings.require(
             "chroma_query_request_timeout_seconds"
         )
+        self._query_replication_factor = system.settings.require(
+            "chroma_query_replication_factor"
+        )
 
     @overrides
     def count(self, plan: CountPlan) -> int:
@@ -164,7 +168,7 @@ def knn(self, plan: KNNPlan) -> QueryResult:
     def _grpc_executor_stub(self, scan: Scan) -> QueryExecutorStub:
         # Since grpc endpoint is endpoint is determined by collection uuid,
         # the endpoint should be the same for all segments of the same collection
-        grpc_url = self._manager.get_endpoint(scan.record, 3)
+        grpc_url = self._manager.get_endpoint(scan.record, self._query_replication_factor)
         with self._mtx:
             if grpc_url not in self._grpc_stub_pool:
                 channel = grpc.insecure_channel(
diff --git a/chromadb/segment/impl/distributed/segment_directory.py b/chromadb/segment/impl/distributed/segment_directory.py
index 8df4afbbaf6..97b4b13c9ca 100644
--- a/chromadb/segment/impl/distributed/segment_directory.py
+++ b/chromadb/segment/impl/distributed/segment_directory.py
@@ -257,8 +257,14 @@ def get_segment_endpoints(self, segment: Segment, n: int) -> List[str]:
         if self._curr_memberlist is None or len(self._curr_memberlist) == 0:
             raise ValueError("Memberlist is not initialized")
 
+        # assign() will throw an error if n is greater than the number of members
+        # clamp n to the number of members to align with the contract of this method
+        # which is to return at most n endpoints
+        n = min(n, len(self._curr_memberlist))
+
         # Check if all members in the memberlist have a node set,
         # if so, route using the node
+
         # NOTE(@hammadb) 1/8/2024: This is to handle the migration between routing
         # using the member id and routing using the node name
         # We want to route using the node name over the member id