Skip to content

Commit

Permalink
use MultiModalSynthesizer as default
Browse files Browse the repository at this point in the history
  • Loading branch information
leehuwuj committed Nov 29, 2024
1 parent 81a8bfe commit 2e8e89a
Showing 1 changed file with 64 additions and 62 deletions.
126 changes: 64 additions & 62 deletions templates/components/engines/python/agent/tools/query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@
from llama_index.core.prompts.default_prompt_selectors import (
DEFAULT_TEXT_QA_PROMPT_SEL,
)
from llama_index.core.query_engine import (
RetrieverQueryEngine,
)
from llama_index.core.query_engine.multi_modal import _get_image_and_text_nodes
from llama_index.core.response_synthesizers.base import BaseSynthesizer, QueryTextType
from llama_index.core.schema import (
Expand All @@ -24,6 +21,64 @@
from app.settings import get_multi_modal_llm


def create_query_engine(index, **kwargs) -> BaseQueryEngine:
    """
    Create a query engine for the given index.

    Args:
        index: The index to create a query engine for.
        **kwargs: Additional parameters forwarded to ``index.as_query_engine``,
            e.g. ``similarity_top_k``.

    Returns:
        The query engine returned by ``index.as_query_engine(**kwargs)``.

    Raises:
        ValueError: If the ``TOP_K`` environment variable is set to a
            non-empty value that is not an integer.
    """
    # An unset, empty or blank TOP_K (e.g. a bare `TOP_K=` line in a .env file)
    # means "not configured" instead of crashing on int("").
    raw_top_k = os.getenv("TOP_K", "").strip()
    top_k = int(raw_top_k) if raw_top_k else 0
    # Only apply the env-configured top-k when the caller passed no filters.
    if top_k != 0 and kwargs.get("filters") is None:
        kwargs["similarity_top_k"] = top_k

    # When a multi-modal LLM is configured, answer through MultiModalSynthesizer.
    multimodal_llm = get_multi_modal_llm()
    if multimodal_llm:
        kwargs["response_synthesizer"] = MultiModalSynthesizer(
            multimodal_model=multimodal_llm,
            response_synthesizer=get_response_synthesizer(),
        )

    # If the index is a LlamaCloudIndex,
    # use auto_routed mode for better query results
    # NOTE(review): nesting below was reconstructed from a flattened diff view;
    # confirm the image-node flag belongs under the `retrieval_mode is None` check.
    if index.__class__.__name__ == "LlamaCloudIndex":
        retrieval_mode = kwargs.get("retrieval_mode")
        if retrieval_mode is None:
            kwargs["retrieval_mode"] = "auto_routed"
            if multimodal_llm:
                kwargs["retrieve_image_nodes"] = True
    return index.as_query_engine(**kwargs)


def get_query_engine_tool(
    index,
    name: Optional[str] = None,
    description: Optional[str] = None,
    **kwargs,
) -> QueryEngineTool:
    """
    Wrap a query engine built for ``index`` in a ``QueryEngineTool``.

    Args:
        index: The index to create a query engine for.
        name (optional): The name of the tool; defaults to "query_index".
        description (optional): The description of the tool; a generic
            retrieval description is used when omitted.
        **kwargs: Passed through to ``create_query_engine``.
    """
    # Only an explicit None triggers the defaults; empty strings are kept.
    tool_name = "query_index" if name is None else name
    tool_description = (
        "Use this tool to retrieve information about the text corpus from an index."
        if description is None
        else description
    )
    return QueryEngineTool.from_defaults(
        query_engine=create_query_engine(index, **kwargs),
        name=tool_name,
        description=tool_description,
    )


class MultiModalSynthesizer(BaseSynthesizer):
"""
A synthesizer that summarizes text nodes and uses a multi-modal LLM to generate a response.
Expand Down Expand Up @@ -70,6 +125,9 @@ async def asynthesize(
) -> RESPONSE_TYPE:
image_nodes, text_nodes = _get_image_and_text_nodes(nodes)

if len(image_nodes) == 0:
return await self._response_synthesizer.asynthesize(query, text_nodes)

# Summarize the text nodes to avoid exceeding the token limit
text_response = str(
await self._response_synthesizer.asynthesize(query, text_nodes)
Expand Down Expand Up @@ -104,6 +162,9 @@ def synthesize(
) -> RESPONSE_TYPE:
image_nodes, text_nodes = _get_image_and_text_nodes(nodes)

if len(image_nodes) == 0:
return self._response_synthesizer.synthesize(query, text_nodes)

# Summarize the text nodes to avoid exceeding the token limit
text_response = str(self._response_synthesizer.synthesize(query, text_nodes))

Expand All @@ -126,62 +187,3 @@ def synthesize(
source_nodes=nodes,
metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
)


def create_query_engine(index, **kwargs) -> BaseQueryEngine:
    """
    Create a query engine for the given index.

    Args:
        index: The index to create a query engine for.
        **kwargs: Additional parameters for the retriever / query engine,
            e.g. ``similarity_top_k``.

    Returns:
        A ``RetrieverQueryEngine`` using ``MultiModalSynthesizer`` for a
        LlamaCloudIndex when a multi-modal LLM is configured; otherwise the
        result of ``index.as_query_engine(**kwargs)``.

    Raises:
        ValueError: If the ``TOP_K`` environment variable is set to a
            non-empty value that is not an integer.
    """
    # An unset, empty or blank TOP_K (e.g. a bare `TOP_K=` line in a .env file)
    # means "not configured" instead of crashing on int("").
    raw_top_k = os.getenv("TOP_K", "").strip()
    top_k = int(raw_top_k) if raw_top_k else 0
    # Only apply the env-configured top-k when the caller passed no filters.
    if top_k != 0 and kwargs.get("filters") is None:
        kwargs["similarity_top_k"] = top_k

    # If the index is a LlamaCloudIndex,
    # use auto_routed mode for better query results
    # NOTE(review): nesting below was reconstructed from a flattened diff view;
    # confirm the multi-modal branch belongs under the `retrieval_mode is None` check.
    if index.__class__.__name__ == "LlamaCloudIndex":
        retrieval_mode = kwargs.get("retrieval_mode")
        if retrieval_mode is None:
            kwargs["retrieval_mode"] = "auto_routed"
            # Look the multi-modal LLM up once and reuse it for the synthesizer.
            multimodal_llm = get_multi_modal_llm()
            if multimodal_llm:
                kwargs["retrieve_image_nodes"] = True
                return RetrieverQueryEngine(
                    retriever=index.as_retriever(**kwargs),
                    response_synthesizer=MultiModalSynthesizer(
                        multimodal_model=multimodal_llm,
                        response_synthesizer=get_response_synthesizer(),
                    ),
                )

    return index.as_query_engine(**kwargs)


def get_query_engine_tool(
    index,
    name: Optional[str] = None,
    description: Optional[str] = None,
    **kwargs,
) -> QueryEngineTool:
    """
    Build a ``QueryEngineTool`` backed by a query engine for ``index``.

    Args:
        index: The index to create a query engine for.
        name (optional): The name of the tool; "query_index" when None.
        description (optional): The description of the tool; a generic
            retrieval description when None.
        **kwargs: Forwarded unchanged to ``create_query_engine``.
    """
    # Fill in defaults only for arguments that were left as None.
    tool_options = {
        "name": name if name is not None else "query_index",
        "description": (
            description
            if description is not None
            else "Use this tool to retrieve information about the text corpus from an index."
        ),
    }
    engine = create_query_engine(index, **kwargs)
    return QueryEngineTool.from_defaults(query_engine=engine, **tool_options)

0 comments on commit 2e8e89a

Please sign in to comment.