From 3e854d19bc3ee7a768cf822526a094cfadf86ede Mon Sep 17 00:00:00 2001 From: Matin Nuhamunada Date: Mon, 22 Apr 2024 15:06:35 +0000 Subject: [PATCH] feat: addability to train with question-sql pair --- chatbgc/app.py | 1 - chatbgc/train.py | 8 +++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/chatbgc/app.py b/chatbgc/app.py index 0298b4c..245addd 100644 --- a/chatbgc/app.py +++ b/chatbgc/app.py @@ -47,7 +47,6 @@ def start_app(duckdb_path, model="duckdb-nsql", llm_type="ollama"): model = "gpt-4" config["model"] = model logging.info(f"Using {llm_type} model: {model}") - logging.debug(f"OpenAI API key: {openai_api_key}") if llm_type == "ollama": vn = MyVannaOllama(config=config) diff --git a/chatbgc/train.py b/chatbgc/train.py index cfa4b41..6a30871 100644 --- a/chatbgc/train.py +++ b/chatbgc/train.py @@ -1,3 +1,4 @@ +import json import logging import os from pathlib import Path @@ -39,7 +40,6 @@ def train_model( model = "gpt-4" config["model"] = model logging.info(f"Using {llm_type} model: {model}") - logging.debug(f"OpenAI API key: {openai_api_key}") if llm_type == "ollama": vn = MyVannaOllama(config=config) @@ -59,6 +59,12 @@ def train_model( docs = f.read() vn.train(documentation=docs) logging.info(f"Trained on documentation file: {file}") + elif file.suffix == ".json": + with open(file, "r") as f: + data = json.load(f) + for item in data: + vn.train(question=item["question"], sql=item["sql"]) + logging.info(f"Trained on question-sql pair: {file}") # Get the information schema query df_information_schema = vn.run_sql("SELECT * FROM INFORMATION_SCHEMA.COLUMNS")