Skip to content

Commit

Permalink
Add SQL agent strategy (#1039)
Browse files Browse the repository at this point in the history
* initial code for sql agent llama

Signed-off-by: minmin-intel <[email protected]>

* add test for sql agent

Signed-off-by: minmin-intel <[email protected]>

* update sql agent test

Signed-off-by: minmin-intel <[email protected]>

* fix bugs and use vllm to test sql agent

Signed-off-by: minmin-intel <[email protected]>

* add tag-bench test and google search tool

Signed-off-by: minmin-intel <[email protected]>

* test sql agent with hints

Signed-off-by: minmin-intel <[email protected]>

* fix bugs for sql agent with hints and update test

Signed-off-by: minmin-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add readme for sql agent and fix ci bugs

Signed-off-by: minmin-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add sql agent using openai models

Signed-off-by: minmin-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix bugs in sql agent openai

Signed-off-by: minmin-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make wait time longer for sql agent microservice to be ready

Signed-off-by: minmin-intel <[email protected]>

* update readme

Signed-off-by: minmin-intel <[email protected]>

* fix test bug

Signed-off-by: minmin-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* skip planexec with vllm due to vllm-gaudi bug

Signed-off-by: minmin-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* debug ut issue

Signed-off-by: minmin-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use vllm for all uts

Signed-off-by: minmin-intel <[email protected]>

* debug ci issue

Signed-off-by: minmin-intel <[email protected]>

* change vllm port

Signed-off-by: minmin-intel <[email protected]>

* update ut

Signed-off-by: minmin-intel <[email protected]>

* remove tgi server

Signed-off-by: minmin-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* align vllm port

Signed-off-by: minmin-intel <[email protected]>

* remove unnecessary files and fix bugs

Signed-off-by: minmin-intel <[email protected]>

* connect to db with full uri

Signed-off-by: minmin-intel <[email protected]>

* update readme and use vllm mainstream

Signed-off-by: minmin-intel <[email protected]>

* rm unnecessary log

Signed-off-by: minmin-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update readme

Signed-off-by: minmin-intel <[email protected]>

* update test script

Signed-off-by: minmin-intel <[email protected]>

---------

Signed-off-by: minmin-intel <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
minmin-intel and pre-commit-ci[bot] authored Dec 18, 2024
1 parent 70c151d commit 717c3c1
Show file tree
Hide file tree
Showing 28 changed files with 1,532 additions and 115 deletions.
64 changes: 30 additions & 34 deletions comps/agent/langchain/README.md

Large diffs are not rendered by default.

File renamed without changes
Binary file added comps/agent/langchain/assets/sql_agent.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added comps/agent/langchain/assets/sql_agent_llama.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 3 additions & 3 deletions comps/agent/langchain/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# used by microservice
docarray[full]

#used by tools
duckduckgo-search
fastapi
huggingface_hub
langchain

#used by tools
langchain-google-community
langchain-huggingface
langchain-openai
langchain_community
Expand Down
10 changes: 10 additions & 0 deletions comps/agent/langchain/src/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,15 @@ def instantiate_agent(args, strategy="react_langchain", with_memory=False):
from .strategy.ragagent import RAGAgent

return RAGAgent(args, with_memory, custom_prompt=custom_prompt)
elif strategy == "sql_agent_llama":
print("Initializing SQL Agent Llama")
from .strategy.sqlagent import SQLAgentLlama

return SQLAgentLlama(args, with_memory, custom_prompt=custom_prompt)
elif strategy == "sql_agent":
print("Initializing SQL Agent")
from .strategy.sqlagent import SQLAgent

return SQLAgent(args, with_memory, custom_prompt=custom_prompt)
else:
raise ValueError(f"Agent strategy: {strategy} not supported!")
13 changes: 13 additions & 0 deletions comps/agent/langchain/src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,16 @@

if os.environ.get("timeout") is not None:
env_config += ["--timeout", os.environ["timeout"]]

# for sql agent
if os.environ.get("db_path") is not None:
env_config += ["--db_path", os.environ["db_path"]]

if os.environ.get("db_name") is not None:
env_config += ["--db_name", os.environ["db_name"]]

if os.environ.get("use_hints") is not None:
env_config += ["--use_hints", os.environ["use_hints"]]

if os.environ.get("hints_file") is not None:
env_config += ["--hints_file", os.environ["hints_file"]]
34 changes: 33 additions & 1 deletion comps/agent/langchain/src/strategy/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,37 @@ def compile(self):
def execute(self, state: dict):
pass

def non_streaming_run(self, query, config):
def prepare_initial_state(self, query):
raise NotImplementedError

async def stream_generator(self, query, config):
initial_state = self.prepare_initial_state(query)
try:
async for event in self.app.astream(initial_state, config=config):
for node_name, node_state in event.items():
yield f"--- CALL {node_name} ---\n"
for k, v in node_state.items():
if v is not None:
yield f"{k}: {v}\n"

yield f"data: {repr(event)}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
yield str(e)

async def non_streaming_run(self, query, config):
initial_state = self.prepare_initial_state(query)
print("@@@ Initial State: ", initial_state)
try:
async for s in self.app.astream(initial_state, config=config, stream_mode="values"):
message = s["messages"][-1]
if isinstance(message, tuple):
print(message)
else:
message.pretty_print()

last_message = s["messages"][-1]
print("******Response: ", last_message.content)
return last_message.content
except Exception as e:
return str(e)
44 changes: 44 additions & 0 deletions comps/agent/langchain/src/strategy/sqlagent/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# SQL Agents

We currently have two types of SQL agents:

1. `sql_agent_llama`: for using with open-source LLMs, especially `meta-llama/Llama-3.1-70B-Instruct` model.
2. `sql_agent`: for using with OpenAI models, we developed and validated with GPT-4o-mini.

## Overview of sql_agent_llama

The architecture of `sql_agent_llama` is shown in the figure below.
The agent node takes user question, hints (optional) and history (when available), and thinks step by step to solve the problem.

![SQL Agent Llama Architecture](../../../assets/sql_agent_llama.png)

### Database schema:

We use langchain's [SQLDatabase](https://python.langchain.com/docs/integrations/tools/sql_database/) API to get table names and schemas from the SQL database. User just need to specify `db_path` and `db_name`. The table schemas are incorporated into the prompts for the agent.

### Hints module:

If you want to use the hints module, you need to prepare a csv file that has 3 columns: `table_name`, `column_name`, `description`, and make this file available to the agent microservice. The `description` should include useful information (for example, domain knowledge) about a certain column in a table in the database. The hints module will pick up to five relevant columns together with their descriptions based on the user question using similarity search. The hints module will then pass these column descriptions to the agent node.

### Output parser:

Due to the current limitations of open source LLMs and serving frameworks (tgi and vllm) in generating tool call objects, we developed and optimized a custom output parser, together with our specially designed prompt templates. The output parser has 3 functions:

1. Decide if a valid final answer presents in the raw agent output. This is needed because: a) we found sometimes agent would make guess or hallucinate data, so it is critical to double check, b) sometimes LLM does not strictly follow instructions on output format so simple string parsing can fail. We use one additional LLM call to perform this function.
2. Pick out tool calls from raw agent output. And check if the agent has made same tool calls before. If yes, remove the repeated tool calls.
3. Parse and review SQL query, and fix SQL query if there are errors. This proved to improve SQL agent performance since the initial query may contain errors and having a "second pair of eyes" can often spot the errors while the agent node itself may not be able to identify the errors in subsequent execution steps.

## Overview of sql_agent

The architecture of `sql_agent` is shown in the figure below.
The agent node takes user question, hints (optional) and history (when available), and thinks step by step to solve the problem. The basic idea is the same as `sql_agent_llama`. However, since OpenAI APIs produce well-structured tool call objects, we don't need a special output parser. Instead, we only keep the query fixer.

![SQL Agent Architecture](../../../assets/sql_agent.png)

## Limitations

1. Agent is only allowed to issue "SELECT" commands to databases, i.e., agent can only query databases but cannot update databases.
2. We currently does not support "streaming" agent outputs on the fly for `sql_agent_llama`.
3. Users need to pass the SQL database URI to the agent with the `db_path` environment variable. We have only validated SQLite database connected in such way.

Please submit issues if you want new features to be added. We also welcome community contributions!
5 changes: 5 additions & 0 deletions comps/agent/langchain/src/strategy/sqlagent/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .planner import SQLAgentLlama
from .planner import SQLAgent
58 changes: 58 additions & 0 deletions comps/agent/langchain/src/strategy/sqlagent/hint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import glob
import os

import pandas as pd


def read_hints(hints_file):
"""
hints_file: csv with columns: table_name, column_name, description
"""
hints_df = pd.read_csv(hints_file)
cols_descriptions = []
values_descriptions = []
for _, row in hints_df.iterrows():
table_name = row["table_name"]
col_name = row["column_name"]
description = row["description"]
if not pd.isnull(description):
cols_descriptions.append(f"{table_name}.{col_name}: {description}")
values_descriptions.append(f"{col_name}: {description}")
return cols_descriptions, values_descriptions


def sort_list(list1, list2):
import numpy as np

# Use numpy's argsort function to get the indices that would sort the second list
idx = np.argsort(list2) # ascending order
return np.array(list1)[idx].tolist()[::-1], np.array(list2)[idx].tolist()[::-1] # descending order


def get_topk_cols(topk, cols_descriptions, similarities):
sorted_cols, similarities = sort_list(cols_descriptions, similarities)
top_k_cols = sorted_cols[:topk]
output = []
for col, sim in zip(top_k_cols, similarities[:topk]):
# print(f"{col}: {sim}")
if sim > 0.5:
output.append(col)
return output


def pick_hints(query, model, column_embeddings, complete_descriptions, topk=5):
if len(complete_descriptions) < topk:
topk_cols_descriptions = complete_descriptions
else:
# use similarity to get the topk columns
query_embedding = model.encode(query, convert_to_tensor=True)
similarities = model.similarity(query_embedding, column_embeddings).flatten()
topk_cols_descriptions = get_topk_cols(topk, complete_descriptions, similarities)

hint = ""
for col in topk_cols_descriptions:
hint += col + "\n"
return hint
Loading

0 comments on commit 717c3c1

Please sign in to comment.