-
Notifications
You must be signed in to change notification settings - Fork 108
/
Copy pathdbcopilot.py
115 lines (88 loc) · 4.56 KB
/
dbcopilot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
from dotenv import load_dotenv
import streamlit as st
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.tools import BaseTool, Tool, tool
from langchain.callbacks.base import BaseCallbackHandler
from langchain import PromptTemplate
import pandas as pd
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.sql_database import SQLDatabase
from langchain.agents import AgentExecutor
from langchain.agents.agent_types import AgentType
from langchain.callbacks import StreamlitCallbackHandler
# Page chrome. set_page_config is kept as the very first Streamlit call
# (older Streamlit releases require it to precede every other st.* call).
st.set_page_config(page_title="DBCopilot", page_icon="📊")
st.header('📊 Welcome to DBCopilot, your copilot for structured databases.')

# Load variables (e.g. OPENAI_API_KEY) from a local .env file, if one exists.
load_dotenv()

# The LangChain OpenAI wrappers read OPENAI_API_KEY from the environment on their own;
# checking it here turns a missing key into a readable in-app message instead of a
# bare KeyError traceback.
openai_api_key = os.environ.get('OPENAI_API_KEY')
if not openai_api_key:
    st.error("OPENAI_API_KEY is not set. Add it to your environment or to a .env file.")
    st.stop()

# SQLite creates an empty database file when the path does not exist, which would leave
# the agent with no tables to query. Fail loudly instead. The path is resolved relative
# to the current working directory, not to this script.
DB_PATH = 'chinook.db'
if not os.path.isfile(DB_PATH):
    st.error(f"Database file '{DB_PATH}' was not found in {os.getcwd()}.")
    st.stop()
db = SQLDatabase.from_uri(f'sqlite:///{DB_PATH}')
# Language models. Both use the public OpenAI endpoints by default, with the API key taken
# from the environment. To target an Azure OpenAI (AOAI) deployment instead, replace the
# two constructors below with:
#     from langchain.llms import AzureOpenAI
#     from langchain.chat_models import AzureChatOpenAI
#     llm = AzureOpenAI(deployment_name="text-davinci-003", model_name="text-davinci-003")
#     model = AzureChatOpenAI(deployment_name='gpt-35-turbo', openai_api_type="azure")
llm = OpenAI()        # completion model; drives the SQL agent and its toolkit
model = ChatOpenAI()  # chat model; not referenced anywhere else in this script

# Expose the database to the agent as a set of tools. The toolkit also receives the LLM,
# presumably for LangChain's query-checking tool (confirm against the installed version).
toolkit = SQLDatabaseToolkit(llm=llm, db=db)
# Instruction prefix for the SQL agent's prompt. The {dialect} and {top_k} placeholders are
# meant to be filled in by create_sql_agent (database dialect and its top_k argument), so the
# text must not contain any other brace pairs. It ends with a "##Tools:" header because the
# agent's tool descriptions are expected to follow it in the assembled prompt.
prompt_prefix = """
##Instructions:
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
As part of your final answer, ALWAYS include an explanation of how you got to the final answer, including the SQL query you ran. Include the explanation and the SQL query in the section that starts with "Explanation:".
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
##Tools:
"""
# ReAct-style output format the agent must follow. The only placeholder is {tool_names},
# which is expected to be filled with the agent's tool names. The example explanation after
# "Explanation:" must therefore stay free of braces. It shows the model the shape of the
# explanation and SQL query that the prefix instructions require.
prompt_format_instructions = """
Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question.
Explanation:
<===Beginning of an Example of Explanation:
I joined the invoices and customers tables on the customer_id column, which is the common key between them. This allowed me to access the Total and Country columns from both tables. Then I grouped the records by the country column, calculated the sum of the Total column for each country, ordered them in descending order and limited the SELECT to the top 5.
```sql
SELECT c.country AS Country, SUM(i.total) AS Sales
FROM customer c
JOIN invoice i ON c.customer_id = i.customer_id
GROUP BY Country
ORDER BY Sales DESC
LIMIT 5;
```
===>End of an Example of Explanation
"""
# Assemble the SQL agent from the completion model and the database toolkit. LangChain's
# default prompt prefix and format instructions are overridden with the custom texts above.
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    prefix=prompt_prefix,
    format_instructions=prompt_format_instructions,
    top_k=10,      # presumably substituted for {top_k} in the prefix
    verbose=True,  # have LangChain print each agent step to the console
)
# Draw the "clear" button on every run. If the button call sat on the right-hand side of
# `or` after the session-state check, it would be short-circuited on the very first run
# (when "messages" is not yet in session state) and would not appear in the sidebar.
clear_history = st.sidebar.button("Clear message history")
if clear_history or "messages" not in st.session_state:
    # Start (or restart) the conversation with a single assistant greeting.
    st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]

# Replay the stored conversation; Streamlit re-executes the script top to bottom on every
# interaction, so the history must be redrawn each time.
for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])
# One chat turn: record and echo the user's question, then answer it with the SQL agent.
user_query = st.chat_input(placeholder="Ask me anything!")
if user_query:
    user_turn = dict(role="user", content=user_query)
    st.session_state.messages.append(user_turn)
    st.chat_message("user").write(user_query)

    with st.chat_message("assistant"):
        # The callback renders the agent's intermediate steps in a container inside the
        # assistant bubble while the agent runs.
        agent_callback = StreamlitCallbackHandler(st.container())
        answer = agent_executor.run(user_query, callbacks=[agent_callback])

        # Persist the reply so the history replay shows it on later reruns, then display it.
        assistant_turn = dict(role="assistant", content=answer)
        st.session_state.messages.append(assistant_turn)
        st.write(answer)