diff --git a/chromadb/config.py b/chromadb/config.py index 2b285e594ca..39cb1bc53e1 100644 --- a/chromadb/config.py +++ b/chromadb/config.py @@ -263,6 +263,7 @@ def empty_str_to_none(cls, v: str) -> Optional[str]: "chromadb.segment.impl.manager.local.LocalSegmentManager" ) chroma_executor_impl: str = "chromadb.execution.executor.local.LocalExecutor" + chroma_query_replication_factor: int = 2 chroma_logservice_host = "localhost" chroma_logservice_port = 50052 diff --git a/chromadb/execution/executor/distributed.py b/chromadb/execution/executor/distributed.py index baa691840ef..c77fe7158fe 100644 --- a/chromadb/execution/executor/distributed.py +++ b/chromadb/execution/executor/distributed.py @@ -39,6 +39,7 @@ class DistributedExecutor(Executor): _grpc_stub_pool: Dict[str, QueryExecutorStub] _manager: DistributedSegmentManager _request_timeout_seconds: int + _query_replication_factor: int def __init__(self, system: System): super().__init__(system) @@ -47,6 +48,9 @@ def __init__(self, system: System): self._request_timeout_seconds = system.settings.require( "chroma_query_request_timeout_seconds" ) + self._query_replication_factor = system.settings.require( + "chroma_query_replication" + ) @overrides def count(self, plan: CountPlan) -> int: @@ -162,8 +166,9 @@ def knn(self, plan: KNNPlan) -> QueryResult: def _grpc_executuor_stub(self, scan: Scan) -> QueryExecutorStub: # Since grpc endpoint is endpoint is determined by collection uuid, # the endpoint should be the same for all segments of the same collection - # TODO: configure the number of endpoints to fetch - grpc_urls = self._manager.get_endpoints(scan.record, 3) + grpc_urls = self._manager.get_endpoints( + scan.record, self._query_replication_factor + ) grpc_url = grpc_urls[random.randint(0, len(grpc_urls) - 1)] if grpc_url not in self._grpc_stub_pool: channel = grpc.insecure_channel(grpc_url) diff --git a/chromadb/segment/impl/distributed/segment_directory.py b/chromadb/segment/impl/distributed/segment_directory.py index 8df4afbbaf6..97b4b13c9ca 100644 --- a/chromadb/segment/impl/distributed/segment_directory.py +++ b/chromadb/segment/impl/distributed/segment_directory.py @@ -257,8 +257,14 @@ def get_segment_endpoints(self, segment: Segment, n: int) -> List[str]: if self._curr_memberlist is None or len(self._curr_memberlist) == 0: raise ValueError("Memberlist is not initialized") + # assign() will throw an error if n is greater than the number of members + # clamp n to the number of members to align with the contract of this method + # which is to return at most n endpoints + n = min(n, len(self._curr_memberlist)) + # Check if all members in the memberlist have a node set, # if so, route using the node + # NOTE(@hammadb) 1/8/2024: This is to handle the migration between routing # using the member id and routing using the node name # We want to route using the node name over the member id