Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Advanced text to sql sample rows #17479

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
NLStructStoreQueryEngine,
SQLStructStoreQueryEngine,
SQLTableRetrieverQueryEngine,
NLSQLRetrieverWithSampleRows,
NLSQLTableQueryEngineWithSampleRows
)

__all__ = [
Expand All @@ -33,4 +35,6 @@
"GPTSQLStructStoreQueryEngine",
"SQLTableRetrieverQueryEngine",
"NLSQLTableQueryEngine",
"NLSQLRetrieverWithSampleRows",
"NLSQLTableQueryEngineWithSampleRows"
]
175 changes: 173 additions & 2 deletions llama-index-core/llama_index/core/indices/struct_store/sql_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union, cast

from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.core.base.base_query_engine import BaseQueryEngine
from llama_index.core.base.response.schema import (
RESPONSE_TYPE,
Expand Down Expand Up @@ -35,7 +35,7 @@
from llama_index.core.response_synthesizers import (
get_response_synthesizer,
)
from llama_index.core.schema import NodeWithScore, QueryBundle
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
from llama_index.core.settings import Settings
from llama_index.core.utilities.sql_wrapper import SQLDatabase
from sqlalchemy import Table
Expand Down Expand Up @@ -627,6 +627,177 @@ def sql_retriever(self) -> NLSQLRetriever:
return self._sql_retriever


class NLSQLRetrieverWithSampleRows(NLSQLRetriever):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ooc, does this actually require a brand new class? We can't just use the existing class and add an arg for the retriever?

Copy link
Contributor Author

@osamadel osamadel Jan 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, in a nutshell, we can, but the issue is that both NLSQLTableQueryEngine and SQLTableRetrieverQueryEngine use the NLSQLRetriever with its defaults arguments.
For example:

self._sql_retriever = NLSQLRetriever(
            sql_database,
            llm=llm,
            text_to_sql_prompt=text_to_sql_prompt,
            context_query_kwargs=context_query_kwargs,
            tables=tables,
            sql_parser_mode=SQLParserMode.PGVECTOR,
            context_str_prefix=context_str_prefix,
            sql_only=sql_only,
            callback_manager=callback_manager,

So in order to make the distinction clear, I opted for a separate class.
If we are to go with the argument option, we'll probably still need the query engine class that internally instantiate a NLSQLRetriever with the sample rows argument used.

What do you think?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cant we just update the query engine to also pass down the row retriever?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I did that, please review and let me know if something needs fixing.

"""Text-to-SQL Retriever.

Retrieves via text. Retrieves sample rows with each retrieved table schema.
Overwrites the `_get_table_context` method in the parent class `NLSQLRetriever`
to include sample rows in the context.

Args:
sql_database (SQLDatabase): SQL database.
text_to_sql_prompt (BasePromptTemplate): Prompt template for text-to-sql.
Defaults to DEFAULT_TEXT_TO_SQL_PROMPT.
context_query_kwargs (dict): Mapping from table name to context query.
Defaults to None.
tables (Union[List[str], List[Table]]): List of table names or Table objects.
table_retriever (ObjectRetriever[SQLTableSchema]): Object retriever for
SQLTableSchema objects. Defaults to None.
context_str_prefix (str): Prefix for context string. Defaults to None.
return_raw (bool): Whether to return plain-text dump of SQL results, or parsed into Nodes.
handle_sql_errors (bool): Whether to handle SQL errors. Defaults to True.
sql_only (bool) : Whether to get only sql and not the sql query result.
Default to False.
llm (Optional[LLM]): Language model to use.
similarity_top_k (int): how many rows to retrieve for each table

"""
def __init__(
self,
sql_database: SQLDatabase,
text_to_sql_prompt: Optional[BasePromptTemplate] = None,
context_query_kwargs: Optional[dict] = None,
tables: Optional[Union[List[str], List[Table]]] = None,
table_retriever: Optional[ObjectRetriever[SQLTableSchema]] = None,
rows_retrievers: Optional[dict[str, VectorStoreIndex]] = None,
context_str_prefix: Optional[str] = None,
sql_parser_mode: SQLParserMode = SQLParserMode.DEFAULT,
llm: Optional[LLM] = None,
embed_model: Optional[BaseEmbedding] = None,
return_raw: bool = True,
handle_sql_errors: bool = True,
sql_only: bool = False,
callback_manager: Optional[CallbackManager] = None,
verbose: bool = False,
similarity_top_k: int = 5,
**kwargs: Any,
) -> None:
"""Initialize params."""
super().__init__(
sql_database,
text_to_sql_prompt,
context_query_kwargs,
tables,
table_retriever,
context_str_prefix,
sql_parser_mode,
llm,
embed_model,
return_raw,
handle_sql_errors,
sql_only,
callback_manager,
verbose,
**kwargs,
)
self._rows_retrievers = rows_retrievers
self._similarity_top_k = similarity_top_k


def _get_table_context(
self,
query_bundle: QueryBundle
):
"""Get table context string."""
table_schema_objs = self._get_tables(query_bundle.query_str)
context_strs = []
for table_schema_obj in table_schema_objs:
# first append table info + additional context
table_info = self._sql_database.get_single_table_info(
table_schema_obj.table_name
)
if table_schema_obj.context_str:
table_opt_context = " The table description is: "
table_opt_context += table_schema_obj.context_str
table_info += table_opt_context

# also lookup vector index to return relevant table rows
if self._rows_retrievers is not None:
rows_retriever = self._rows_retrievers[
table_schema_obj.table_name
].as_retriever(similarity_top_k=self._similarity_top_k)
relevant_nodes = rows_retriever.retrieve(query_bundle.query_str)
else:
# Retrieve the top `similarity_top_k` rows from each table and add them to the context
relevant_nodes = [TextNode(text=str(t)) for t in self._sql_database.run_sql(
f"SELECT TOP {self._similarity_top_k} * "
f"from [{self._sql_database._schema}].[{table_schema_obj.table_name}]")]
if len(relevant_nodes) > 0:
table_row_context = "\nHere are some relevant example rows (values in the same order as columns above)\n"
for node in relevant_nodes:
table_row_context += str(node.get_content()) + "\n"
table_info += table_row_context

if self._verbose:
print(f"> Table Info: {table_info}")
context_strs.append(table_info)

return "\n\n".join(context_strs)


class NLSQLTableQueryEngineWithSampleRows(BaseSQLTableQueryEngine):
"""
Advanced natural language SQL Table query engine. It uses NLSQLRetrieverWithSampleRows
to add sample rows to the schema and table's description in the context
to the BaseSQLTableQueryEngine to generate better SQL.

Read NLStructStoreQueryEngine's docstring for more info on NL SQL.

NOTE: Any Text-to-SQL application should be aware that executing
arbitrary SQL queries can be a security risk. It is recommended to
take precautions as needed, such as using restricted roles, read-only
databases, sandboxing, etc.
"""

def __init__(
self,
sql_database: SQLDatabase,
rows_retrievers: Optional[dict[str, VectorStoreIndex]] = None,
table_retriever: Optional[ObjectRetriever[SQLTableSchema]] = None,
llm: Optional[LLM] = None,
text_to_sql_prompt: Optional[BasePromptTemplate] = None,
context_query_kwargs: Optional[dict] = None,
synthesize_response: bool = True,
markdown_response: bool = False,
response_synthesis_prompt: Optional[BasePromptTemplate] = None,
refine_synthesis_prompt: Optional[BasePromptTemplate] = None,
tables: Optional[Union[List[str], List[Table]]] = None,
context_str_prefix: Optional[str] = None,
embed_model: Optional[BaseEmbedding] = None,
sql_only: bool = False,
callback_manager: Optional[CallbackManager] = None,
verbose: bool = False,
similarity_top_k: int = 5,
**kwargs: Any
) -> None:
"""Initialize params."""
self._sql_retriever = NLSQLRetrieverWithSampleRows(
sql_database,
rows_retrievers=rows_retrievers,
table_retriever=table_retriever,
llm=llm,
text_to_sql_prompt=text_to_sql_prompt,
context_query_kwargs=context_query_kwargs,
tables=tables,
context_str_prefix=context_str_prefix,
embed_model=embed_model,
sql_only=sql_only,
callback_manager=callback_manager,
similarity_top_k=similarity_top_k,
verbose=verbose,
)
super().__init__(
synthesize_response=synthesize_response,
markdown_response=markdown_response,
response_synthesis_prompt=response_synthesis_prompt,
refine_synthesis_prompt=refine_synthesis_prompt,
llm=llm,
callback_manager=callback_manager,
verbose=verbose,
**kwargs,
)


# legacy
GPTNLStructStoreQueryEngine = NLStructStoreQueryEngine
GPTSQLStructStoreQueryEngine = SQLStructStoreQueryEngine
Loading