-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Add RetrySqlQueryCreatorTool for handling failed SQL query generation #15
base: main
Are you sure you want to change the base?
Changes from 1 commit
c8ad59b
e308c78
852d2e7
3147269
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -14,6 +14,7 @@ | |||||||||||||||
from langchain_core.tools import StateTool | ||||||||||||||||
import re | ||||||||||||||||
|
||||||||||||||||
ERROR = "" | ||||||||||||||||
class BaseSQLDatabaseTool(BaseModel): | ||||||||||||||||
"""Base tool for interacting with a SQL database.""" | ||||||||||||||||
|
||||||||||||||||
|
@@ -43,7 +44,7 @@ class Config(StateTool.Config): | |||||||||||||||
description: str = """ | ||||||||||||||||
Input to this tool is a detailed and correct SQL query, output is a result from the database. | ||||||||||||||||
If the query is not correct, an error message will be returned. | ||||||||||||||||
If an error is returned, re-run the sql_db_query_creator tool to get the correct query. | ||||||||||||||||
If an error is returned, re-run the retry_sql_db_query_creator tool to get the correct query. | ||||||||||||||||
""" | ||||||||||||||||
|
||||||||||||||||
def __init__(__pydantic_self__, **data: Any) -> None: | ||||||||||||||||
|
@@ -65,6 +66,7 @@ def _run( | |||||||||||||||
) | ||||||||||||||||
executable_query = executable_query.strip('\"') | ||||||||||||||||
executable_query = re.sub('\\n```', '',executable_query) | ||||||||||||||||
self.db.run_no_throw(executable_query) | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue: Duplicate call to The method |
||||||||||||||||
return self.db.run_no_throw(executable_query) | ||||||||||||||||
|
||||||||||||||||
async def _arun( | ||||||||||||||||
|
@@ -75,14 +77,98 @@ async def _arun( | |||||||||||||||
raise NotImplementedError("QuerySparkSQLDataBaseTool does not support async") | ||||||||||||||||
|
||||||||||||||||
def _extract_sql_query(self): | ||||||||||||||||
for value in self.state: | ||||||||||||||||
for value in reversed(self.state): | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. question (bug_risk): Reversing the state list might have unintended consequences. Reversing the state list could lead to unexpected behavior if the order of states is important. Ensure that this change is intentional and won't cause issues. |
||||||||||||||||
for key, input_string in value.items(): | ||||||||||||||||
if "sql_db_query_creator" in key: | ||||||||||||||||
if "tool='retry_sql_db_query_creator'" in key: | ||||||||||||||||
return input_string | ||||||||||||||||
elif "tool='sql_db_query_creator'" in key: | ||||||||||||||||
Comment on lines
+84
to
+86
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (code-quality): We've found these issues:
Suggested change
|
||||||||||||||||
return input_string | ||||||||||||||||
return None | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
class RetrySqlQueryCreatorTool(StateTool): | ||||||||||||||||
"""Tool for re-creating SQL query.Use this to retry creation of sql query.""" | ||||||||||||||||
|
||||||||||||||||
name = "retry_sql_db_query_creator" | ||||||||||||||||
description = """ | ||||||||||||||||
This is a tool used to re-create sql query for user input based on the incorrect query generated and error returned from sql_db_query tool. | ||||||||||||||||
Input to this tool is user prompt, incorrect sql query and error message | ||||||||||||||||
Output is a sql query | ||||||||||||||||
After running this tool, you can run sql_db_query tool to get the result | ||||||||||||||||
""" | ||||||||||||||||
sqlcreatorllm: BaseLanguageModel = Field(exclude=True) | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
class Config(StateTool.Config): | ||||||||||||||||
"""Configuration for this pydantic object.""" | ||||||||||||||||
|
||||||||||||||||
arbitrary_types_allowed = True | ||||||||||||||||
extra = Extra.allow | ||||||||||||||||
|
||||||||||||||||
def __init__(__pydantic_self__, **data: Any) -> None: | ||||||||||||||||
"""Initialize the tool.""" | ||||||||||||||||
super().__init__(**data) | ||||||||||||||||
|
||||||||||||||||
def _run( | ||||||||||||||||
self, | ||||||||||||||||
user_input: str, | ||||||||||||||||
run_manager: Optional[CallbackManagerForToolRun] = None, | ||||||||||||||||
) -> str: | ||||||||||||||||
"""Get the SQL query for the incorrect query.""" | ||||||||||||||||
return self._create_sql_query(user_input) | ||||||||||||||||
|
||||||||||||||||
async def _arun( | ||||||||||||||||
self, | ||||||||||||||||
table_name: str, | ||||||||||||||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None, | ||||||||||||||||
) -> str: | ||||||||||||||||
raise NotImplementedError("SqlQueryCreatorTool does not support async") | ||||||||||||||||
|
||||||||||||||||
def _create_sql_query(self,user_input): | ||||||||||||||||
|
||||||||||||||||
sql_query = self._extract_sql_query() | ||||||||||||||||
error_message = self._extract_error_message() | ||||||||||||||||
if sql_query is None: | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (bug_risk): Consider logging when sql_query is None. It might be useful to log a message when
Suggested change
|
||||||||||||||||
return "This tool is not meant to be run directly. Start with a SQLQueryCreatorTool" | ||||||||||||||||
|
||||||||||||||||
prompt_input = PromptTemplate( | ||||||||||||||||
input_variables=["user_input","sql_query", "error_message"], | ||||||||||||||||
template=SQL_QUERY_CREATOR_RETRY | ||||||||||||||||
) | ||||||||||||||||
query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) | ||||||||||||||||
|
||||||||||||||||
sql_query = query_creator_chain.run( | ||||||||||||||||
( | ||||||||||||||||
{ | ||||||||||||||||
"sql_query": sql_query, | ||||||||||||||||
"error_message": error_message, | ||||||||||||||||
"user_input": user_input | ||||||||||||||||
} | ||||||||||||||||
) | ||||||||||||||||
) | ||||||||||||||||
sql_query = sql_query.replace("```","") | ||||||||||||||||
sql_query = sql_query.replace("sql","") | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (bug_risk): Removing 'sql' from the query might cause issues. The line |
||||||||||||||||
|
||||||||||||||||
return sql_query | ||||||||||||||||
|
||||||||||||||||
def _extract_sql_query(self): | ||||||||||||||||
for value in reversed(self.state): | ||||||||||||||||
for key, input_string in value.items(): | ||||||||||||||||
if "tool='retry_sql_db_query_creator'" in key: | ||||||||||||||||
return input_string | ||||||||||||||||
elif "tool='sql_db_query_creator'" in key: | ||||||||||||||||
Comment on lines
+160
to
+162
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (code-quality): We've found these issues:
Suggested change
|
||||||||||||||||
return input_string | ||||||||||||||||
return None | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion: Consider raising an exception instead of returning None. Returning
Suggested change
|
||||||||||||||||
|
||||||||||||||||
def _extract_error_message(self): | ||||||||||||||||
for value in reversed(self.state): | ||||||||||||||||
for key, input_string in value.items(): | ||||||||||||||||
if "tool='sql_db_query'" in key: | ||||||||||||||||
if "Error" in input_string: | ||||||||||||||||
return input_string | ||||||||||||||||
Comment on lines
+169
to
+171
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (code-quality): Merge nested if conditions (
Suggested change
ExplanationToo much nesting can make code difficult to understand, and this is especiallytrue in Python, where there are no brackets to help out with the delineation of different nesting levels. Reading deeply nested code is confusing, since you have to keep track of which |
||||||||||||||||
return None | ||||||||||||||||
|
||||||||||||||||
class SqlQueryCreatorTool(StateTool): | ||||||||||||||||
"""Tool for creating SQL query.Use this to create sql query.""" | ||||||||||||||||
|
||||||||||||||||
|
@@ -147,43 +233,24 @@ def _parse_data_model_context(self): | |||||||||||||||
def _create_sql_query(self,user_input): | ||||||||||||||||
|
||||||||||||||||
few_shot_examples = self._parse_few_shot_examples() | ||||||||||||||||
sql_query = self._extract_sql_query() | ||||||||||||||||
db_schema = self._parse_db_schema() | ||||||||||||||||
data_model_context = self._parse_data_model_context() | ||||||||||||||||
if sql_query is None: | ||||||||||||||||
prompt_input = PromptTemplate( | ||||||||||||||||
input_variables=["db_schema", "user_input", "few_shot_examples","data_model_context"], | ||||||||||||||||
template=self.SQL_QUERY_CREATOR_TEMPLATE, | ||||||||||||||||
) | ||||||||||||||||
query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) | ||||||||||||||||
|
||||||||||||||||
sql_query = query_creator_chain.run( | ||||||||||||||||
( | ||||||||||||||||
{ | ||||||||||||||||
"db_schema": db_schema, | ||||||||||||||||
"user_input": user_input, | ||||||||||||||||
"few_shot_examples": few_shot_examples, | ||||||||||||||||
"data_model_context": data_model_context | ||||||||||||||||
} | ||||||||||||||||
) | ||||||||||||||||
) | ||||||||||||||||
else: | ||||||||||||||||
prompt_input = PromptTemplate( | ||||||||||||||||
input_variables=["db_schema", "user_input", "few_shot_examples","data_model_context"], | ||||||||||||||||
template=SQL_QUERY_CREATOR_RETRY | ||||||||||||||||
) | ||||||||||||||||
query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) | ||||||||||||||||
|
||||||||||||||||
sql_query = query_creator_chain.run( | ||||||||||||||||
( | ||||||||||||||||
{ | ||||||||||||||||
"db_schema": db_schema, | ||||||||||||||||
"user_input": user_input, | ||||||||||||||||
"few_shot_examples": few_shot_examples, | ||||||||||||||||
"data_model_context": data_model_context | ||||||||||||||||
} | ||||||||||||||||
) | ||||||||||||||||
prompt_input = PromptTemplate( | ||||||||||||||||
input_variables=["db_schema", "user_input", "few_shot_examples","data_model_context"], | ||||||||||||||||
template=self.SQL_QUERY_CREATOR_TEMPLATE, | ||||||||||||||||
) | ||||||||||||||||
query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) | ||||||||||||||||
|
||||||||||||||||
sql_query = query_creator_chain.run( | ||||||||||||||||
( | ||||||||||||||||
{ | ||||||||||||||||
"db_schema": db_schema, | ||||||||||||||||
"user_input": user_input, | ||||||||||||||||
"few_shot_examples": few_shot_examples, | ||||||||||||||||
"data_model_context": data_model_context | ||||||||||||||||
} | ||||||||||||||||
) | ||||||||||||||||
) | ||||||||||||||||
sql_query = sql_query.replace("```","") | ||||||||||||||||
sql_query = sql_query.replace("sql","") | ||||||||||||||||
|
||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,20 @@ | ||
|
||
|
||
SQL_QUERY_CREATOR_RETRY = """ | ||
You have failed in the first attempt to generate correct sql query. Please try again to rewrite correct sql query. | ||
""" | ||
Your task is convert an incorrect query resulting from user question to a correct query which is databricks sql compatible. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nitpick (typo): Typo in the prompt template. The sentence should be 'Your task is to convert an incorrect query resulting from a user question to a correct query which is Databricks SQL compatible.' |
||
Adhere to these rules: | ||
- **Deliberately go through the question and database schema word by word** to appropriately answer the question | ||
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`. | ||
- When creating a ratio, always cast the numerator as float | ||
|
||
### Task: | ||
Generate a correct SQL query that answers the question [QUESTION]`{user_input}`[/QUESTION]. | ||
The query you will correct is: {sql_query} | ||
The error message is: {error_message} | ||
|
||
### Response: | ||
Based on your instructions, here is the SQL query I have generated | ||
[SQL]""" | ||
|
||
SQL_QUERY_CREATOR_7b = """ | ||
### Instructions: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion: Consider removing the unused ERROR variable.
The variable
ERROR
is defined but never used in the code. If it's not needed, it would be better to remove it to keep the code clean.