Skip to content

Commit

Permalink
Ky/dynamic pg triplet retrieval limit (run-llama#16928)
Browse files (browse the repository at this point in the history)
  • Loading branch information
kevin-yang-racap authored Nov 12, 2024
1 parent f60869a commit d322a32
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 14 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
] = DEFAULT_SYNONYM_EXPAND_TEMPLATE,
max_keywords: int = 10,
path_depth: int = 1,
limit: int = 30,
output_parsing_fn: Optional[Callable] = None,
llm: Optional[LLM] = None,
**kwargs: Any,
Expand All @@ -70,6 +71,7 @@ def __init__(
self._output_parsing_fn = output_parsing_fn
self._max_keywords = max_keywords
self._path_depth = path_depth
self._limit = limit
super().__init__(
graph_store=graph_store,
include_text=include_text,
Expand All @@ -86,38 +88,46 @@ def _parse_llm_output(self, output: str) -> List[str]:
# capitalize to normalize with ingestion
return [x.strip().capitalize() for x in matches if x.strip()]

def _prepare_matches(self, matches: List[str]) -> List[NodeWithScore]:
def _prepare_matches(
self, matches: List[str], limit: Optional[int] = None
) -> List[NodeWithScore]:
kg_nodes = self._graph_store.get(ids=matches)
triplets = self._graph_store.get_rel_map(
kg_nodes,
depth=self._path_depth,
limit=limit or self._limit,
ignore_rels=[KG_SOURCE_REL],
)

return self._get_nodes_with_score(triplets)

async def _aprepare_matches(self, matches: List[str]) -> List[NodeWithScore]:
async def _aprepare_matches(
self, matches: List[str], limit: Optional[int] = None
) -> List[NodeWithScore]:
kg_nodes = await self._graph_store.aget(ids=matches)
triplets = await self._graph_store.aget_rel_map(
kg_nodes,
depth=self._path_depth,
limit=limit or self._limit,
ignore_rels=[KG_SOURCE_REL],
)

return self._get_nodes_with_score(triplets)

def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
def retrieve_from_graph(
self, query_bundle: QueryBundle, limit: Optional[int] = None
) -> List[NodeWithScore]:
response = self._llm.predict(
self._synonym_prompt,
query_str=query_bundle.query_str,
max_keywords=self._max_keywords,
)
matches = self._parse_llm_output(response)

return self._prepare_matches(matches)
return self._prepare_matches(matches, limit=limit or self._limit)

async def aretrieve_from_graph(
self, query_bundle: QueryBundle
self, query_bundle: QueryBundle, limit: Optional[int] = None
) -> List[NodeWithScore]:
response = await self._llm.apredict(
self._synonym_prompt,
Expand All @@ -126,4 +136,4 @@ async def aretrieve_from_graph(
)
matches = self._parse_llm_output(response)

return await self._aprepare_matches(matches)
return await self._aprepare_matches(matches, limit=limit or self._limit)
Original file line number | Diff line number | Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
vector_store: Optional[BasePydanticVectorStore] = None,
similarity_top_k: int = 4,
path_depth: int = 1,
limit: int = 30,
similarity_score: Optional[float] = None,
filters: Optional[MetadataFilters] = None,
**kwargs: Any,
Expand All @@ -58,6 +59,7 @@ def __init__(
self._similarity_top_k = similarity_top_k
self._vector_store = vector_store
self._path_depth = path_depth
self._limit = limit
self._similarity_score = similarity_score
self._filters = filters

Expand Down Expand Up @@ -112,7 +114,9 @@ async def _aget_vector_store_query(
**self._retriever_kwargs,
)

def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
def retrieve_from_graph(
self, query_bundle: QueryBundle, limit: Optional[int] = None
) -> List[NodeWithScore]:
vector_store_query = self._get_vector_store_query(query_bundle)

triplets = []
Expand All @@ -126,7 +130,10 @@ def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]:

kg_ids = [node.id for node in kg_nodes]
triplets = self._graph_store.get_rel_map(
kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
kg_nodes,
depth=self._path_depth,
limit=limit or self._limit,
ignore_rels=[KG_SOURCE_REL],
)

elif self._vector_store is not None:
Expand All @@ -136,15 +143,21 @@ def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
scores = query_result.similarities
kg_nodes = self._graph_store.get(ids=kg_ids)
triplets = self._graph_store.get_rel_map(
kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
kg_nodes,
depth=self._path_depth,
limit=limit or self._limit,
ignore_rels=[KG_SOURCE_REL],
)

elif query_result.ids is not None and query_result.similarities is not None:
kg_ids = query_result.ids
scores = query_result.similarities
kg_nodes = self._graph_store.get(ids=kg_ids)
triplets = self._graph_store.get_rel_map(
kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
kg_nodes,
depth=self._path_depth,
limit=limit or self._limit,
ignore_rels=[KG_SOURCE_REL],
)

for triplet in triplets:
Expand Down Expand Up @@ -174,7 +187,7 @@ def retrieve_from_graph(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
return self._get_nodes_with_score([x[0] for x in top_k], [x[1] for x in top_k])

async def aretrieve_from_graph(
self, query_bundle: QueryBundle
self, query_bundle: QueryBundle, limit: Optional[int] = None
) -> List[NodeWithScore]:
vector_store_query = await self._aget_vector_store_query(query_bundle)

Expand All @@ -189,7 +202,10 @@ async def aretrieve_from_graph(
kg_nodes, scores = result
kg_ids = [node.id for node in kg_nodes]
triplets = await self._graph_store.aget_rel_map(
kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
kg_nodes,
depth=self._path_depth,
limit=limit or self._limit,
ignore_rels=[KG_SOURCE_REL],
)

elif self._vector_store is not None:
Expand All @@ -199,15 +215,21 @@ async def aretrieve_from_graph(
scores = query_result.similarities
kg_nodes = await self._graph_store.aget(ids=kg_ids)
triplets = await self._graph_store.aget_rel_map(
kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
kg_nodes,
depth=self._path_depth,
limit=limit or self._limit,
ignore_rels=[KG_SOURCE_REL],
)

elif query_result.ids is not None and query_result.similarities is not None:
kg_ids = query_result.ids
scores = query_result.similarities
kg_nodes = await self._graph_store.aget(ids=kg_ids)
triplets = await self._graph_store.aget_rel_map(
kg_nodes, depth=self._path_depth, ignore_rels=[KG_SOURCE_REL]
kg_nodes,
depth=self._path_depth,
limit=limit or self._limit,
ignore_rels=[KG_SOURCE_REL],
)

for triplet in triplets:
Expand Down

0 comments on commit d322a32

Please sign in to comment.