Skip to content

Commit

Permalink
feat: upgrade vanna and added streamlit interface
Browse files Browse the repository at this point in the history
  • Loading branch information
matinnuhamunada committed Jun 6, 2024
1 parent 1a2fee4 commit 8648d94
Show file tree
Hide file tree
Showing 6 changed files with 812 additions and 100 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# required by black, https://github.com/psf/black/blob/master/.flake8
max-line-length = 88
max-complexity = 18
ignore = E203, E266, E501, W503, F403, F401
ignore = E203, E266, E501, W503, F403, F401, C901
select = B,C,E,F,W,T4,B9
docstring-convention = google
per-file-ignores =
Expand Down
30 changes: 30 additions & 0 deletions chatbgc/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""Console script for chatbgc."""

import logging
import os
import subprocess
import sys
from pathlib import Path

Expand Down Expand Up @@ -77,6 +79,34 @@ def train(
duckdb_path, model=model, training_folder=training_folder, llm_type=llm_type
)

def run_streamlit(self, duckdb_path, model="llama3", llm_type="ollama"):
    """
    Starts the chatBGC interface as a Streamlit app.

    The settings are handed to the app through the environment variables
    CHATBGC_DUCKDB_PATH, CHATBGC_MODEL and CHATBGC_LLM_TYPE, which are set on
    the current process before `streamlit run` is launched on the
    streamlit_app.py file that sits next to this module. The call blocks
    until the Streamlit process exits; its exit status is not checked.

    Parameters:
        duckdb_path (str): The path to the DuckDB database.
        model (str, optional): The model to use. Defaults to "llama3" for "ollama" and "gpt-4o" for "openai_chat".
        llm_type (str, optional): The type of language model to use. Defaults to "ollama". Other option is "openai_chat".
    Returns:
        None
    """
    # Environment values must be strings; os.fspath also accepts path-like
    # objects and leaves plain strings untouched.
    os.environ["CHATBGC_DUCKDB_PATH"] = os.fspath(duckdb_path)
    os.environ["CHATBGC_MODEL"] = model
    os.environ["CHATBGC_LLM_TYPE"] = llm_type

    # The Streamlit script lives in the same directory as this file.
    streamlit_app_path = Path(__file__).parent / "streamlit_app.py"

    # Pass an argument list and skip the shell, so an install path that
    # contains spaces or shell metacharacters reaches Streamlit unchanged.
    subprocess.run(["streamlit", "run", str(streamlit_app_path)])

@staticmethod
def version():
"""
Expand Down
200 changes: 200 additions & 0 deletions chatbgc/streamlit_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import argparse
import logging
import os
import time

import streamlit as st
from code_editor import code_editor

from chatbgc.vanna_calls import (
generate_followup_cached,
generate_plot_cached,
generate_plotly_code_cached,
generate_questions_cached,
generate_sql_cached,
generate_summary_cached,
is_sql_valid_cached,
run_sql_cached,
setup_vanna,
should_generate_chart_cached,
)


def start_streamlit_app():
    """
    Build the ChatBGC Streamlit page and answer the active question.

    The page consists of a sidebar with output toggles and a Reset button, an
    optional list of suggested questions, and a chat area. For the active
    question the function generates SQL, validates and runs it and then,
    depending on the sidebar toggles, shows the SQL, the result table (only
    the first 10 rows when it is longer), the Plotly code and chart, a
    summary, and up to five follow-up questions.

    Environment variables are used to configure the application:
    - CHATBGC_DUCKDB_PATH: The path to the DuckDB database. Required; the page
      shows an error and stops when it is missing or empty.
    - CHATBGC_MODEL: The model to use for generating SQL queries. Defaults to 'llama3'.
    - CHATBGC_LLM_TYPE: The type of the LLM. Defaults to 'ollama'.

    No arguments are required to call this function.
    This function does not return a value.
    """
    # Get the parameters from the environment variables
    duckdb_path = os.getenv("CHATBGC_DUCKDB_PATH")
    model = os.getenv("CHATBGC_MODEL", "llama3")
    llm_type = os.getenv("CHATBGC_LLM_TYPE", "ollama")

    # Log the values of the environment variables
    logging.info(f"CHATBGC_DUCKDB_PATH: {duckdb_path}")
    logging.info(f"CHATBGC_MODEL: {model}")
    logging.info(f"CHATBGC_LLM_TYPE: {llm_type}")

    avatar_url = "https://raw.githubusercontent.com/NBChub/chatBGC/main/chatbgc/assets/bgcflow_logo.png"

    def set_question(question):
        """Store the active question in the session state (None clears it)."""
        st.session_state["my_question"] = question

    # Streamlit start
    st.set_page_config(layout="wide")

    st.sidebar.title("Output Settings")
    st.sidebar.checkbox("Show SQL", value=True, key="show_sql")
    st.sidebar.checkbox("Show Table", value=True, key="show_table")
    st.sidebar.checkbox("Show Plotly Code", value=True, key="show_plotly_code")
    st.sidebar.checkbox("Show Chart", value=True, key="show_chart")
    st.sidebar.checkbox("Show Summary", value=True, key="show_summary")
    st.sidebar.checkbox("Show Follow-up Questions", value=True, key="show_followup")
    # Same callback style as the question buttons below; None clears the question.
    st.sidebar.button(
        "Reset", on_click=set_question, args=(None,), use_container_width=True
    )

    st.title("ChatBGC")

    # Without a database path there is nothing to connect to; report it on
    # the page instead of passing None on to setup_vanna.
    if not duckdb_path:
        st.error("CHATBGC_DUCKDB_PATH environment variable is not set")
        st.stop()

    # start vanna
    vn = setup_vanna(duckdb_path=duckdb_path, model=model, llm_type=llm_type)

    assistant_message_suggested = st.chat_message("assistant", avatar=avatar_url)
    if assistant_message_suggested.button("Click to show suggested questions"):
        st.session_state["my_question"] = None
        questions = generate_questions_cached(vn)
        for question in questions:
            # Short pause between buttons, kept from the original code
            # (NOTE(review): purpose not evident from this file).
            time.sleep(0.05)
            st.button(
                question,
                on_click=set_question,
                args=(question,),
            )

    my_question = st.session_state.get("my_question", default=None)

    if my_question is None:
        my_question = st.chat_input(
            "Ask me a question about your data",
        )

    if my_question:
        st.session_state["my_question"] = my_question
        user_message = st.chat_message("user")
        user_message.write(f"{my_question}")

        sql = generate_sql_cached(vn, question=my_question)

        if sql:
            if is_sql_valid_cached(vn, sql=sql):
                if st.session_state.get("show_sql", True):
                    assistant_message_sql = st.chat_message(
                        "assistant", avatar=avatar_url
                    )
                    assistant_message_sql.code(sql, language="sql", line_numbers=True)
            else:
                # The generated text did not validate as SQL: show it as a
                # plain message and end this script run.
                assistant_message = st.chat_message("assistant", avatar=avatar_url)
                assistant_message.write(sql)
                st.stop()

            df = run_sql_cached(vn, sql=sql)

            # Always overwrite the stored frame, so a run that yields no data
            # cannot fall back to the result of an earlier question.
            st.session_state["df"] = df

            if df is not None:
                if st.session_state.get("show_table", True):
                    assistant_message_table = st.chat_message(
                        "assistant",
                        avatar=avatar_url,
                    )
                    if len(df) > 10:
                        assistant_message_table.text("First 10 rows of data")
                        assistant_message_table.dataframe(df.head(10))
                    else:
                        assistant_message_table.dataframe(df)

                if should_generate_chart_cached(
                    vn, question=my_question, sql=sql, df=df
                ):
                    code = generate_plotly_code_cached(
                        vn, question=my_question, sql=sql, df=df
                    )

                    # Fallback default matches the sidebar checkbox (value=True).
                    if st.session_state.get("show_plotly_code", True):
                        assistant_message_plotly_code = st.chat_message(
                            "assistant",
                            avatar=avatar_url,
                        )
                        assistant_message_plotly_code.code(
                            code, language="python", line_numbers=True
                        )

                    if code is not None and code != "":
                        if st.session_state.get("show_chart", True):
                            assistant_message_chart = st.chat_message(
                                "assistant",
                                avatar=avatar_url,
                            )
                            fig = generate_plot_cached(vn, code=code, df=df)
                            if fig is not None:
                                assistant_message_chart.plotly_chart(fig)
                            else:
                                assistant_message_chart.error(
                                    "I couldn't generate a chart"
                                )

                if st.session_state.get("show_summary", True):
                    assistant_message_summary = st.chat_message(
                        "assistant",
                        avatar=avatar_url,
                    )
                    summary = generate_summary_cached(vn, question=my_question, df=df)
                    if summary is not None:
                        assistant_message_summary.text(summary)

                if st.session_state.get("show_followup", True):
                    assistant_message_followup = st.chat_message(
                        "assistant",
                        avatar=avatar_url,
                    )
                    followup_questions = generate_followup_cached(
                        vn, question=my_question, sql=sql, df=df
                    )
                    st.session_state["df"] = None

                    # Truthiness check also covers a None result.
                    if followup_questions:
                        assistant_message_followup.text(
                            "Here are some possible follow-up questions"
                        )
                        # Print the first 5 follow-up questions
                        for question in followup_questions[:5]:
                            assistant_message_followup.button(
                                question, on_click=set_question, args=(question,)
                            )

        else:
            assistant_message_error = st.chat_message("assistant", avatar=avatar_url)
            assistant_message_error.error(
                "I wasn't able to generate SQL for that question"
            )


# Called at module level on purpose: cli.py launches this file with
# `streamlit run`, so the page is built whenever the script is executed.
start_streamlit_app()
121 changes: 121 additions & 0 deletions chatbgc/vanna_calls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import logging
import os

import requests
import streamlit as st
from vanna.chromadb import ChromaDB_VectorStore
from vanna.flask import VannaFlaskApp
from vanna.ollama import Ollama
from vanna.openai import OpenAI_Chat

# Configure the root logger at DEBUG level as soon as this module is imported.
logging.basicConfig(level=logging.DEBUG)


class MyVannaOllama(ChromaDB_VectorStore, Ollama):
    """Vanna class built from the `ChromaDB_VectorStore` and `Ollama` bases."""

    def __init__(self, config=None):
        """Initialise both base classes with the same `config`.

        Parameters:
            config (dict, optional): Configuration passed unchanged to each base initialiser.
        """
        # Each base initialiser is called explicitly, vector store first.
        for base in (ChromaDB_VectorStore, Ollama):
            base.__init__(self, config=config)


class MyVannaOpenAI(ChromaDB_VectorStore, OpenAI_Chat):
    """Vanna class built from the `ChromaDB_VectorStore` and `OpenAI_Chat` bases."""

    def __init__(self, config=None):
        """Initialise both base classes with the same `config`.

        Parameters:
            config (dict, optional): Configuration passed unchanged to each base initialiser.
        """
        # Each base initialiser is called explicitly, vector store first.
        for base in (ChromaDB_VectorStore, OpenAI_Chat):
            base.__init__(self, config=config)


@st.cache_resource(ttl=3600)
def setup_vanna(duckdb_path, model="llama3", llm_type="ollama"):
    """
    Create a Vanna instance for the chosen LLM backend and connect it to DuckDB.

    The function is wrapped in `st.cache_resource(ttl=3600)`. For
    "openai_chat" the API key is read from the OPENAI_API_KEY environment
    variable and the requested model is checked against the model list
    returned by the OpenAI API before the instance is created.

    Parameters:
        duckdb_path (str): The path to the DuckDB database.
        model (str, optional): The model to use. Defaults to "llama3" for "ollama" and "gpt-4o" for "openai_chat".
        llm_type (str, optional): The type of language model to use. Defaults to "ollama". Other option is "openai_chat".
    Returns:
        MyVannaOllama or MyVannaOpenAI: The configured instance, already connected to `duckdb_path`.
    Raises:
        ValueError: If `llm_type` is not supported, OPENAI_API_KEY is missing or empty, or `model` is not among the models listed by the OpenAI API.
        requests.RequestException: If the model-list request fails, times out, or returns an error status.
    """
    if llm_type not in ("ollama", "openai_chat"):
        raise ValueError(f"Invalid LLM type: {llm_type}")

    config = {"model": model}

    if llm_type == "openai_chat":
        # "llama3" is the Ollama default; replace it with the OpenAI default.
        if model == "llama3":
            model = "gpt-4o"
            config["model"] = model

        openai_api_key = os.environ.get("OPENAI_API_KEY")
        # An empty key is as unusable as a missing one.
        if not openai_api_key:
            raise ValueError("OPENAI_API_KEY environment variable is not set")
        config["api_key"] = openai_api_key

        # get available models; the timeout keeps the app from hanging
        # indefinitely when the API cannot be reached.
        response = requests.get(
            "https://api.openai.com/v1/models",
            headers={"Authorization": f"Bearer {openai_api_key}"},
            timeout=30,
        )
        response.raise_for_status()

        available_models = [entry["id"] for entry in response.json()["data"]]
        if model not in available_models:
            raise ValueError(f"Invalid model. Expected one of: {available_models}")

    logging.info(f"Using {llm_type} model: {model}")

    # llm_type was validated above, so only these two cases remain.
    if llm_type == "ollama":
        vn = MyVannaOllama(config=config)
    else:
        vn = MyVannaOpenAI(config=config)

    vn.connect_to_duckdb(url=duckdb_path)
    return vn


@st.cache_data(show_spinner="Generating sample questions ...")
def generate_questions_cached(_vn):
    """
    Return `_vn.generate_questions()`, cached through `st.cache_data`.

    Parameters:
        _vn: The Vanna instance. The leading underscore tells Streamlit not to hash this argument for the cache key.
    Returns:
        Whatever `_vn.generate_questions()` returns.
    """
    questions = _vn.generate_questions()
    return questions


@st.cache_data(show_spinner="Generating SQL query ...")
def generate_sql_cached(_vn, question: str):
    """
    Return the SQL generated by `_vn.generate_sql` for `question`, cached through `st.cache_data`.

    The call always passes `allow_llm_to_see_data=True`.

    Parameters:
        _vn: The Vanna instance. The leading underscore tells Streamlit not to hash this argument for the cache key.
        question (str): The user's question.
    Returns:
        Whatever `_vn.generate_sql(...)` returns.
    """
    sql = _vn.generate_sql(question=question, allow_llm_to_see_data=True)
    return sql


@st.cache_data(show_spinner="Checking for valid SQL ...")
def is_sql_valid_cached(_vn, sql: str):
    """
    Return `_vn.is_sql_valid(sql=sql)`, cached through `st.cache_data`.

    Parameters:
        _vn: The Vanna instance. The leading underscore tells Streamlit not to hash this argument for the cache key.
        sql (str): The SQL text to check.
    Returns:
        Whatever `_vn.is_sql_valid(...)` returns.
    """
    verdict = _vn.is_sql_valid(sql=sql)
    return verdict


@st.cache_data(show_spinner="Running SQL query ...")
def run_sql_cached(_vn, sql: str):
    """
    Return `_vn.run_sql(sql=sql)`, cached through `st.cache_data`.

    Parameters:
        _vn: The Vanna instance. The leading underscore tells Streamlit not to hash this argument for the cache key.
        sql (str): The SQL query to run.
    Returns:
        Whatever `_vn.run_sql(...)` returns.
    """
    result = _vn.run_sql(sql=sql)
    return result


@st.cache_data(show_spinner="Checking if we should generate a chart ...")
def should_generate_chart_cached(_vn, question, sql, df):
    """
    Return `_vn.should_generate_chart(df=df)`, cached through `st.cache_data`.

    Parameters:
        _vn: The Vanna instance. The leading underscore tells Streamlit not to hash this argument for the cache key.
        question: The user's question. Not forwarded to Vanna; it is only an argument of the cached function.
        sql: The SQL query. Not forwarded to Vanna; it is only an argument of the cached function.
        df: The query result passed to `should_generate_chart`.
    Returns:
        Whatever `_vn.should_generate_chart(...)` returns.
    """
    wants_chart = _vn.should_generate_chart(df=df)
    return wants_chart


@st.cache_data(show_spinner="Generating Plotly code ...")
def generate_plotly_code_cached(_vn, question, sql, df):
    """
    Return the Plotly code from `_vn.generate_plotly_code`, cached through `st.cache_data`.

    Parameters:
        _vn: The Vanna instance. The leading underscore tells Streamlit not to hash this argument for the cache key.
        question: The user's question.
        sql: The SQL query that produced `df`.
        df: The query result.
    Returns:
        Whatever `_vn.generate_plotly_code(...)` returns.
    """
    return _vn.generate_plotly_code(question=question, sql=sql, df=df)


@st.cache_data(show_spinner="Running Plotly code ...")
def generate_plot_cached(_vn, code, df):
    """
    Return `_vn.get_plotly_figure(plotly_code=code, df=df)`, cached through `st.cache_data`.

    Parameters:
        _vn: The Vanna instance. The leading underscore tells Streamlit not to hash this argument for the cache key.
        code: The Plotly code, passed on as `plotly_code`.
        df: The query result.
    Returns:
        Whatever `_vn.get_plotly_figure(...)` returns.
    """
    figure = _vn.get_plotly_figure(plotly_code=code, df=df)
    return figure


@st.cache_data(show_spinner="Generating followup questions ...")
def generate_followup_cached(_vn, question, sql, df):
    """
    Return the follow-up questions from `_vn.generate_followup_questions`, cached through `st.cache_data`.

    Parameters:
        _vn: The Vanna instance. The leading underscore tells Streamlit not to hash this argument for the cache key.
        question: The user's question.
        sql: The SQL query that produced `df`.
        df: The query result.
    Returns:
        Whatever `_vn.generate_followup_questions(...)` returns.
    """
    followups = _vn.generate_followup_questions(question=question, sql=sql, df=df)
    return followups


@st.cache_data(show_spinner="Generating summary ...")
def generate_summary_cached(_vn, question, df):
    """
    Return `_vn.generate_summary(question=question, df=df)`, cached through `st.cache_data`.

    Parameters:
        _vn: The Vanna instance. The leading underscore tells Streamlit not to hash this argument for the cache key.
        question: The user's question.
        df: The query result to summarise.
    Returns:
        Whatever `_vn.generate_summary(...)` returns.
    """
    summary = _vn.generate_summary(question=question, df=df)
    return summary
Loading

0 comments on commit 8648d94

Please sign in to comment.