Skip to content

Commit

Permalink
feat: add ability to train with question-sql pair
Browse files Browse the repository at this point in the history
  • Loading branch information
matinnuhamunada committed Apr 22, 2024
1 parent c0db2a4 commit 3e854d1
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
1 change: 0 additions & 1 deletion chatbgc/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def start_app(duckdb_path, model="duckdb-nsql", llm_type="ollama"):
model = "gpt-4"
config["model"] = model
logging.info(f"Using {llm_type} model: {model}")
logging.debug(f"OpenAI API key: {openai_api_key}")

if llm_type == "ollama":
vn = MyVannaOllama(config=config)
Expand Down
8 changes: 7 additions & 1 deletion chatbgc/train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import os
from pathlib import Path
Expand Down Expand Up @@ -39,7 +40,6 @@ def train_model(
model = "gpt-4"
config["model"] = model
logging.info(f"Using {llm_type} model: {model}")
logging.debug(f"OpenAI API key: {openai_api_key}")

if llm_type == "ollama":
vn = MyVannaOllama(config=config)
Expand All @@ -59,6 +59,12 @@ def train_model(
docs = f.read()
vn.train(documentation=docs)
logging.info(f"Trained on documentation file: {file}")
elif file.suffix == ".json":
with open(file, "r") as f:
data = json.load(f)
for item in data:
vn.train(question=item["question"], sql=item["sql"])
logging.info(f"Trained on question-sql pair: {file}")

# Get the information schema query
df_information_schema = vn.run_sql("SELECT * FROM INFORMATION_SCHEMA.COLUMNS")
Expand Down

0 comments on commit 3e854d1

Please sign in to comment.