diff --git a/llama-index-core/llama_index/core/indices/property_graph/sub_retrievers/llm_synonym.py b/llama-index-core/llama_index/core/indices/property_graph/sub_retrievers/llm_synonym.py index 17938c1ed77a5..4248af17fc41b 100644 --- a/llama-index-core/llama_index/core/indices/property_graph/sub_retrievers/llm_synonym.py +++ b/llama-index-core/llama_index/core/indices/property_graph/sub_retrievers/llm_synonym.py @@ -59,6 +59,7 @@ def __init__( ] = DEFAULT_SYNONYM_EXPAND_TEMPLATE, max_keywords: int = 10, path_depth: int = 1, + limit: int = 30, output_parsing_fn: Optional[Callable] = None, llm: Optional[LLM] = None, **kwargs: Any, @@ -70,6 +71,7 @@ def __init__( self._output_parsing_fn = output_parsing_fn self._max_keywords = max_keywords self._path_depth = path_depth + self._limit = limit super().__init__( graph_store=graph_store, include_text=include_text, @@ -86,27 +88,35 @@ def _parse_llm_output(self, output: str) -> List[str]: # capitalize to normalize with ingestion return [x.strip().capitalize() for x in matches if x.strip()] - def _prepare_matches(self, matches: List[str]) -> List[NodeWithScore]: + def _prepare_matches( + self, matches: List[str], limit: Optional[int] = None + ) -> List[NodeWithScore]: kg_nodes = self._graph_store.get(ids=matches) triplets = self._graph_store.get_rel_map( kg_nodes, depth=self._path_depth, + limit=limit or self._limit, ignore_rels=[KG_SOURCE_REL], ) return self._get_nodes_with_score(triplets) - async def _aprepare_matches(self, matches: List[str]) -> List[NodeWithScore]: + async def _aprepare_matches( + self, matches: List[str], limit: Optional[int] = None + ) -> List[NodeWithScore]: kg_nodes = await self._graph_store.aget(ids=matches) triplets = await self._graph_store.aget_rel_map( kg_nodes, depth=self._path_depth, + limit=limit or self._limit, ignore_rels=[KG_SOURCE_REL], ) return self._get_nodes_with_score(triplets) - def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + def retrieve_from_graph( + self, query_bundle: QueryBundle, limit: Optional[int] = None + ) -> List[NodeWithScore]: response = self._llm.predict( self._synonym_prompt, query_str=query_bundle.query_str, @@ -114,10 +124,10 @@ def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]: ) matches = self._parse_llm_output(response) - return self._prepare_matches(matches) + return self._prepare_matches(matches, limit=limit or self._limit) async def aretrieve_from_graph( - self, query_bundle: QueryBundle + self, query_bundle: QueryBundle, limit: Optional[int] = None ) -> List[NodeWithScore]: response = await self._llm.apredict( self._synonym_prompt, @@ -126,4 +136,4 @@ async def aretrieve_from_graph( ) matches = self._parse_llm_output(response) - return await self._aprepare_matches(matches) + return await self._aprepare_matches(matches, limit=limit or self._limit) diff --git a/llama-index-core/llama_index/core/indices/property_graph/sub_retrievers/vector.py b/llama-index-core/llama_index/core/indices/property_graph/sub_retrievers/vector.py index effbcd7e5471e..93611fd3b14be 100644 --- a/llama-index-core/llama_index/core/indices/property_graph/sub_retrievers/vector.py +++ b/llama-index-core/llama_index/core/indices/property_graph/sub_retrievers/vector.py @@ -49,6 +49,7 @@ def __init__( vector_store: Optional[BasePydanticVectorStore] = None, similarity_top_k: int = 4, path_depth: int = 1, + limit: int = 30, similarity_score: Optional[float] = None, filters: Optional[MetadataFilters] = None, **kwargs: Any, @@ -58,6 +59,7 @@ def __init__( self._similarity_top_k = similarity_top_k self._vector_store = vector_store self._path_depth = path_depth + self._limit = limit self._similarity_score = similarity_score self._filters = filters @@ -112,7 +114,9 @@ async def _aget_vector_store_query( **self._retriever_kwargs, ) - def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + def retrieve_from_graph( + self, query_bundle: QueryBundle, limit: Optional[int] = None + ) -> List[NodeWithScore]: vector_store_query = self._get_vector_store_query(query_bundle) triplets = [] @@ -126,7 +130,10 @@ def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]: kg_ids = [node.id for node in kg_nodes] triplets = self._graph_store.get_rel_map( - kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL] + kg_nodes, + depth=self._path_depth, + limit=limit or self._limit, + ignore_rels=[KG_SOURCE_REL], ) elif self._vector_store is not None: @@ -136,7 +143,10 @@ def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]: scores = query_result.similarities kg_nodes = self._graph_store.get(ids=kg_ids) triplets = self._graph_store.get_rel_map( - kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL] + kg_nodes, + depth=self._path_depth, + limit=limit or self._limit, + ignore_rels=[KG_SOURCE_REL], ) elif query_result.ids is not None and query_result.similarities is not None: @@ -144,7 +154,10 @@ def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]: scores = query_result.similarities kg_nodes = self._graph_store.get(ids=kg_ids) triplets = self._graph_store.get_rel_map( - kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL] + kg_nodes, + depth=self._path_depth, + limit=limit or self._limit, + ignore_rels=[KG_SOURCE_REL], ) for triplet in triplets: @@ -174,7 +187,7 @@ def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]: return self._get_nodes_with_score([x[0] for x in top_k], [x[1] for x in top_k]) async def aretrieve_from_graph( - self, query_bundle: QueryBundle + self, query_bundle: QueryBundle, limit: Optional[int] = None ) -> List[NodeWithScore]: vector_store_query = await self._aget_vector_store_query(query_bundle) @@ -189,7 +202,10 @@ async def aretrieve_from_graph( kg_nodes, scores = result kg_ids = [node.id for node in kg_nodes] triplets = await self._graph_store.aget_rel_map( - kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL] + kg_nodes, + depth=self._path_depth, + limit=limit or self._limit, + ignore_rels=[KG_SOURCE_REL], ) elif self._vector_store is not None: @@ -199,7 +215,10 @@ async def aretrieve_from_graph( scores = query_result.similarities kg_nodes = await self._graph_store.aget(ids=kg_ids) triplets = await self._graph_store.aget_rel_map( - kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL] + kg_nodes, + depth=self._path_depth, + limit=limit or self._limit, + ignore_rels=[KG_SOURCE_REL], ) elif query_result.ids is not None and query_result.similarities is not None: @@ -207,7 +226,10 @@ async def aretrieve_from_graph( scores = query_result.similarities kg_nodes = await self._graph_store.aget(ids=kg_ids) triplets = await self._graph_store.aget_rel_map( - kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL] + kg_nodes, + depth=self._path_depth, + limit=limit or self._limit, + ignore_rels=[KG_SOURCE_REL], ) for triplet in triplets: