-
Notifications
You must be signed in to change notification settings - Fork 144
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial code for sql agent llama Signed-off-by: minmin-intel <[email protected]> * add test for sql agent Signed-off-by: minmin-intel <[email protected]> * update sql agent test Signed-off-by: minmin-intel <[email protected]> * fix bugs and use vllm to test sql agent Signed-off-by: minmin-intel <[email protected]> * add tag-bench test and google search tool Signed-off-by: minmin-intel <[email protected]> * test sql agent with hints Signed-off-by: minmin-intel <[email protected]> * fix bugs for sql agent with hints and update test Signed-off-by: minmin-intel <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add readme for sql agent and fix ci bugs Signed-off-by: minmin-intel <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add sql agent using openai models Signed-off-by: minmin-intel <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix bugs in sql agent openai Signed-off-by: minmin-intel <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make wait time longer for sql agent microservice to be ready Signed-off-by: minmin-intel <[email protected]> * update readme Signed-off-by: minmin-intel <[email protected]> * fix test bug Signed-off-by: minmin-intel <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * skip planexec with vllm due to vllm-gaudi bug Signed-off-by: minmin-intel <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * debug ut issue Signed-off-by: minmin-intel <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use vllm for all uts Signed-off-by: minmin-intel <[email protected]> * debug ci issue Signed-off-by: minmin-intel <[email protected]> * change vllm port Signed-off-by: minmin-intel <[email protected]> * update ut Signed-off-by: minmin-intel <[email protected]> * remove tgi server Signed-off-by: minmin-intel <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * align vllm port Signed-off-by: minmin-intel <[email protected]> * remove unnecessary files and fix bugs Signed-off-by: minmin-intel <[email protected]> * connect to db with full uri Signed-off-by: minmin-intel <[email protected]> * update readme and use vllm mainstream Signed-off-by: minmin-intel <[email protected]> * rm unnecessary log Signed-off-by: minmin-intel <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update readme Signed-off-by: minmin-intel <[email protected]> * update test script Signed-off-by: minmin-intel <[email protected]> --------- Signed-off-by: minmin-intel <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
70c151d
commit 717c3c1
Showing
28 changed files
with
1,532 additions
and
115 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
File renamed without changes
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# SQL Agents | ||
|
||
We currently have two types of SQL agents: | ||
|
||
1. `sql_agent_llama`: for using with open-source LLMs, especially `meta-llama/Llama-3.1-70B-Instruct` model. | ||
2. `sql_agent`: for using with OpenAI models, we developed and validated with GPT-4o-mini. | ||
|
||
## Overview of sql_agent_llama | ||
|
||
The architecture of `sql_agent_llama` is shown in the figure below. | ||
The agent node takes user question, hints (optional) and history (when available), and thinks step by step to solve the problem. | ||
|
||
![SQL Agent Llama Architecture](../../../assets/sql_agent_llama.png) | ||
|
||
### Database schema: | ||
|
||
We use langchain's [SQLDatabase](https://python.langchain.com/docs/integrations/tools/sql_database/) API to get table names and schemas from the SQL database. User just need to specify `db_path` and `db_name`. The table schemas are incorporated into the prompts for the agent. | ||
|
||
### Hints module: | ||
|
||
If you want to use the hints module, you need to prepare a csv file that has 3 columns: `table_name`, `column_name`, `description`, and make this file available to the agent microservice. The `description` should include useful information (for example, domain knowledge) about a certain column in a table in the database. The hints module will pick up to five relevant columns together with their descriptions based on the user question using similarity search. The hints module will then pass these column descriptions to the agent node. | ||
|
||
### Output parser: | ||
|
||
Due to the current limitations of open source LLMs and serving frameworks (tgi and vllm) in generating tool call objects, we developed and optimized a custom output parser, together with our specially designed prompt templates. The output parser has 3 functions: | ||
|
||
1. Decide if a valid final answer presents in the raw agent output. This is needed because: a) we found sometimes agent would make guess or hallucinate data, so it is critical to double check, b) sometimes LLM does not strictly follow instructions on output format so simple string parsing can fail. We use one additional LLM call to perform this function. | ||
2. Pick out tool calls from raw agent output. And check if the agent has made same tool calls before. If yes, remove the repeated tool calls. | ||
3. Parse and review SQL query, and fix SQL query if there are errors. This proved to improve SQL agent performance since the initial query may contain errors and having a "second pair of eyes" can often spot the errors while the agent node itself may not be able to identify the errors in subsequent execution steps. | ||
|
||
## Overview of sql_agent | ||
|
||
The architecture of `sql_agent` is shown in the figure below. | ||
The agent node takes user question, hints (optional) and history (when available), and thinks step by step to solve the problem. The basic idea is the same as `sql_agent_llama`. However, since OpenAI APIs produce well-structured tool call objects, we don't need a special output parser. Instead, we only keep the query fixer. | ||
|
||
![SQL Agent Architecture](../../../assets/sql_agent.png) | ||
|
||
## Limitations | ||
|
||
1. Agent is only allowed to issue "SELECT" commands to databases, i.e., agent can only query databases but cannot update databases. | ||
2. We currently does not support "streaming" agent outputs on the fly for `sql_agent_llama`. | ||
3. Users need to pass the SQL database URI to the agent with the `db_path` environment variable. We have only validated SQLite database connected in such way. | ||
|
||
Please submit issues if you want new features to be added. We also welcome community contributions! |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .planner import SQLAgentLlama | ||
from .planner import SQLAgent |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import glob | ||
import os | ||
|
||
import pandas as pd | ||
|
||
|
||
def read_hints(hints_file): | ||
""" | ||
hints_file: csv with columns: table_name, column_name, description | ||
""" | ||
hints_df = pd.read_csv(hints_file) | ||
cols_descriptions = [] | ||
values_descriptions = [] | ||
for _, row in hints_df.iterrows(): | ||
table_name = row["table_name"] | ||
col_name = row["column_name"] | ||
description = row["description"] | ||
if not pd.isnull(description): | ||
cols_descriptions.append(f"{table_name}.{col_name}: {description}") | ||
values_descriptions.append(f"{col_name}: {description}") | ||
return cols_descriptions, values_descriptions | ||
|
||
|
||
def sort_list(list1, list2): | ||
import numpy as np | ||
|
||
# Use numpy's argsort function to get the indices that would sort the second list | ||
idx = np.argsort(list2) # ascending order | ||
return np.array(list1)[idx].tolist()[::-1], np.array(list2)[idx].tolist()[::-1] # descending order | ||
|
||
|
||
def get_topk_cols(topk, cols_descriptions, similarities): | ||
sorted_cols, similarities = sort_list(cols_descriptions, similarities) | ||
top_k_cols = sorted_cols[:topk] | ||
output = [] | ||
for col, sim in zip(top_k_cols, similarities[:topk]): | ||
# print(f"{col}: {sim}") | ||
if sim > 0.5: | ||
output.append(col) | ||
return output | ||
|
||
|
||
def pick_hints(query, model, column_embeddings, complete_descriptions, topk=5): | ||
if len(complete_descriptions) < topk: | ||
topk_cols_descriptions = complete_descriptions | ||
else: | ||
# use similarity to get the topk columns | ||
query_embedding = model.encode(query, convert_to_tensor=True) | ||
similarities = model.similarity(query_embedding, column_embeddings).flatten() | ||
topk_cols_descriptions = get_topk_cols(topk, complete_descriptions, similarities) | ||
|
||
hint = "" | ||
for col in topk_cols_descriptions: | ||
hint += col + "\n" | ||
return hint |
Oops, something went wrong.