-
Notifications
You must be signed in to change notification settings - Fork 1
/
sql_gen.py
54 lines (46 loc) · 1.35 KB
/
sql_gen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from openai import OpenAI
import utils
import dotenv
import os
from streamlit_db_chat import hf_api_call
# Pull variables from a local .env file (if any) into the process environment
# before anything below reads them.
dotenv.load_dotenv()

# NOTE(review): HF_TOKEN is not used anywhere in this module -- the client below
# takes its key from hf_api_call() instead. Presumably kept for other importers;
# confirm before removing.
HF_TOKEN = os.getenv("HF_TOKEN")

# Endpoint URL and API key for the Hugging Face inference endpoint, supplied by
# the Streamlit app's helper.
hf_url, hf_token = hf_api_call()

# Module-wide OpenAI-compatible client pointed at that endpoint; used by
# llm_request().
client = OpenAI(api_key=hf_token, base_url=hf_url)
def llm_request(content, schema, database_uri):
    """Ask the LLM endpoint to write SQL for a natural-language request.

    The system prompt carries the database schema plus sample rows fetched
    from the database, so the model can ground its query in real data.

    Args:
        content: The user's request, sent verbatim as the user message.
        schema: Schema description; interpolated into the system prompt via
            ``str.format``, so anything with a useful ``str()`` works.
        database_uri: Connection URI handed to ``utils.DatabaseUtil`` to fetch
            the sample rows.

    Returns:
        The model's reply text (intended to be SQL; it is not validated here).
        An empty string if the model returned no text.
    """
    # Fetch sample data from the database for context.
    db_util = utils.DatabaseUtil(database_uri)
    top_five_rows_data = db_util.top_five_rows()

    # Explicit newlines keep the schema and the sample rows visually separate
    # from the instructions; without them the pieces were glued together.
    system_prompt = (
        "Below is the schema to the database. "
        "Create and send only the accurate sql queries for the schema.\n"
        "{}\n"
        "These are the top 5 distinct rows:\n"
        "{}"
    ).format(schema, top_five_rows_data)

    ### Connecting to HF endpoint ###
    chat_completion = client.chat.completions.create(
        model="tgi",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": content},
        ],
        # Change the below as per requirements (stream=True is supported by
        # _collect_content as well as the default non-streamed response):
        # top_p=None,
        # temperature=None,
        # max_tokens=150,
        # stream=True,
        # seed=None,
        # frequency_penalty=None,
        # presence_penalty=None
    )
    return _collect_content(chat_completion)


def _collect_content(response):
    """Return the generated text from a chat-completion response.

    Handles both shapes the OpenAI client can return: a complete
    (non-streamed) completion, whose text is on ``choices[0].message.content``,
    and a stream of chunks, whose text is spread across
    ``choices[0].delta.content``. ``None`` content is treated as empty.
    """
    choices = getattr(response, "choices", None)
    if choices is not None:
        # Non-streamed completion: the whole reply is on the first choice.
        if not choices:
            return ""
        return choices[0].message.content or ""

    parts = []
    for chunk in response:
        # Some chunks may carry no choices (e.g. a trailing usage chunk), and
        # a chunk's delta.content may be None (typically the final one).
        if chunk.choices and chunk.choices[0].delta.content:
            parts.append(chunk.choices[0].delta.content)
    return "".join(parts)