-
Notifications
You must be signed in to change notification settings - Fork 145
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor text2sql based on ERAG (#1080)
Signed-off-by: Yao, Qing <[email protected]>
- Loading branch information
Showing
20 changed files
with
479 additions
and
421 deletions.
There are no files selected for viewing
Validating CODEOWNERS rules …
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,7 @@ | |
/comps/prompt_registry/ [email protected] | ||
/comps/feedback_management/ [email protected] | ||
/comps/chathistory/ [email protected] | ||
/comps/texttosql/ [email protected] | ||
/comps/text2sql/ [email protected] | ||
/comps/text2image/ [email protected] | ||
/comps/reranks/ [email protected] | ||
/comps/retrievers/ [email protected] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
# 🛢 Text-to-SQL Microservice | ||
|
||
In today's data-driven world, the ability to efficiently extract insights from databases is crucial. However, querying databases often requires specialized knowledge of SQL(Structured Query Language) and database schemas, which can be a barrier for non-technical users. This is where the Text-to-SQL microservice comes into play, leveraging the power of LLMs and agentic frameworks to bridge the gap between human language and database queries. This microservice is built on LangChain/LangGraph frameworks. | ||
|
||
The microservice enables a wide range of use cases, making it a versatile tool for businesses, researchers, and individuals alike. Users can generate queries based on natural language questions, enabling them to quickly retrieve relevant data from their databases. Additionally, the service can be integrated into ChatBots, allowing for natural language interactions and providing accurate responses based on the underlying data. Furthermore, it can be utilized to build custom dashboards, enabling users to visualize and analyze insights based on their specific requirements, all through the power of natural language. | ||
|
||
--- | ||
|
||
## 🛠️ Features | ||
|
||
**Implement SQL Query based on input text**: Transform user-provided natural language into SQL queries, subsequently executing them to retrieve data from SQL databases. | ||
|
||
--- | ||
|
||
## ⚙️ Implementation | ||
|
||
The text-to-sql microservice able to implement with various framework and support various types of SQL databases. | ||
|
||
### 🔗 Utilizing Text-to-SQL with Langchain framework | ||
|
||
The follow guide provides set-up instructions and comprehensive details regarding the Text-to-SQL microservices via LangChain. In this configuration, we will employ PostgresDB as our example database to showcase this microservice. | ||
|
||
--- | ||
|
||
#### 🚀 Start Microservice with Python(Option 1) | ||
|
||
#### Install Requirements | ||
|
||
```bash | ||
pip install -r requirements.txt | ||
``` | ||
|
||
#### Start PostgresDB Service | ||
|
||
We will use [Chinook](https://github.com/lerocha/chinook-database) sample database as a default to test the Text-to-SQL microservice. Chinook database is a sample database ideal for demos and testing ORM tools targeting single and multiple database servers. | ||
|
||
```bash | ||
export POSTGRES_USER=postgres | ||
export POSTGRES_PASSWORD=testpwd | ||
export POSTGRES_DB=chinook | ||
|
||
cd comps/text2sql | ||
|
||
docker run --name postgres-db --ipc=host -e POSTGRES_USER=${POSTGRES_USER} -e POSTGRES_HOST_AUTH_METHOD=trust -e POSTGRES_DB=${POSTGRES_DB} -e POSTGRES_PASSWORD=${POSTGRES_PASSWORD} -p 5442:5432 -d -v ./chinook.sql:/docker-entrypoint-initdb.d/chinook.sql postgres:latest | ||
``` | ||
|
||
#### Start TGI Service | ||
|
||
```bash | ||
export HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN} | ||
export LLM_MODEL_ID="mistralai/Mistral-7B-Instruct-v0.3" | ||
export TGI_PORT=8008 | ||
|
||
docker run -d --name="text2sql-tgi-endpoint" --ipc=host -p $TGI_PORT:80 -v ./data:/data --shm-size 1g -e HF_TOKEN=${HUGGINGFACEHUB_API_TOKEN} -e model=${LLM_MODEL_ID} ghcr.io/huggingface/text-generation-inference:2.1.0 --model-id $LLM_MODEL_ID | ||
``` | ||
|
||
#### Verify the TGI Service | ||
|
||
```bash | ||
export your_ip=$(hostname -I | awk '{print $1}') | ||
curl http://${your_ip}:${TGI_PORT}/generate \ | ||
-X POST \ | ||
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17, "do_sample": true}}' \ | ||
-H 'Content-Type: application/json' | ||
``` | ||
|
||
#### Setup Environment Variables | ||
|
||
```bash | ||
export TGI_LLM_ENDPOINT="http://${your_ip}:${TGI_PORT}" | ||
``` | ||
|
||
#### Start Text-to-SQL Microservice with Python Script | ||
|
||
Start Text-to-SQL microservice with below command. | ||
|
||
```bash | ||
python3 opea_text2sql_microservice.py | ||
``` | ||
|
||
--- | ||
|
||
### 🚀 Start Microservice with Docker (Option 2) | ||
|
||
#### Start PostGreSQL Database Service | ||
|
||
Please refer to section [Start PostgresDB Service](#start-postgresdb-service) | ||
|
||
#### Start TGI Service | ||
|
||
Please refer to section [Start TGI Service](#start-tgi-service) | ||
|
||
#### Setup Environment Variables | ||
|
||
```bash | ||
export TGI_LLM_ENDPOINT="http://${your_ip}:${TGI_PORT}" | ||
``` | ||
|
||
#### Build Docker Image | ||
|
||
```bash | ||
cd GenAIComps/ | ||
docker build -t opea/text2sql:latest -f comps/text2sql/src/Dockerfile . | ||
``` | ||
|
||
#### Run Docker with CLI (Option A) | ||
|
||
```bash | ||
export TGI_LLM_ENDPOINT="http://${your_ip}:${TGI_PORT}" | ||
|
||
docker run --runtime=runc --name="comps-langchain-text2sql" -p 9090:8080 --ipc=host -e llm_endpoint_url=${TGI_LLM_ENDPOINT} opea/text2sql:latest | ||
``` | ||
|
||
#### Run via docker compose (Option B) | ||
|
||
- Setup Environment Variables. | ||
|
||
```bash | ||
export TGI_LLM_ENDPOINT=http://${your_ip}:${TGI_PORT} | ||
export HF_TOKEN=${HUGGINGFACEHUB_API_TOKEN} | ||
export LLM_MODEL_ID="mistralai/Mistral-7B-Instruct-v0.3" | ||
export POSTGRES_USER=postgres | ||
export POSTGRES_PASSWORD=testpwd | ||
export POSTGRES_DB=chinook | ||
``` | ||
|
||
- Start the services. | ||
|
||
```bash | ||
docker compose -f docker_compose_text2sql.yaml up | ||
``` | ||
|
||
--- | ||
|
||
### ✅ Invoke the microservice. | ||
|
||
The Text-to-SQL microservice exposes the following API endpoints: | ||
|
||
- Test Database Connection | ||
|
||
```bash | ||
curl --location http://${your_ip}:9090/v1/postgres/health \ | ||
--header 'Content-Type: application/json' \ | ||
--data '{"user": "'${POSTGRES_USER}'","password": "'${POSTGRES_PASSWORD}'","host": "'${your_ip}'", "port": "5442", "database": "'${POSTGRES_DB}'"}' | ||
``` | ||
|
||
- Execute SQL Query from input text | ||
|
||
```bash | ||
curl http://${your_ip}:9090/v1/text2sql\ | ||
-X POST \ | ||
-d '{"input_text": "Find the total number of Albums.","conn_str": {"user": "'${POSTGRES_USER}'","password": "'${POSTGRES_PASSWORD}'","host": "'${your_ip}'", "port": "5442", "database": "'${POSTGRES_DB}'"}}' \ | ||
-H 'Content-Type: application/json' | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from __future__ import annotations | ||
|
||
import os | ||
from typing import Annotated, Optional | ||
|
||
from langchain.agents.agent_types import AgentType | ||
from langchain_community.utilities.sql_database import SQLDatabase | ||
from langchain_huggingface import HuggingFaceEndpoint | ||
from pydantic import BaseModel, Field | ||
from sqlalchemy import create_engine | ||
from sqlalchemy.exc import SQLAlchemyError | ||
|
||
from comps import CustomLogger, OpeaComponent, ServiceType | ||
from comps.text2sql.src.integrations.sql_agent import CustomSQLDatabaseToolkit, custom_create_sql_agent | ||
|
||
logger = CustomLogger("comps-text2sql") | ||
logflag = os.getenv("LOGFLAG", False) | ||
|
||
sql_params = { | ||
"max_string_length": 3600, | ||
} | ||
|
||
generation_params = { | ||
"max_new_tokens": 1024, | ||
"top_k": 10, | ||
"top_p": 0.95, | ||
"temperature": 0.01, | ||
"repetition_penalty": 1.03, | ||
"streaming": True, | ||
} | ||
|
||
TGI_LLM_ENDPOINT = os.environ.get("TGI_LLM_ENDPOINT") | ||
|
||
llm = HuggingFaceEndpoint( | ||
endpoint_url=TGI_LLM_ENDPOINT, | ||
task="text-generation", | ||
**generation_params, | ||
) | ||
|
||
|
||
class PostgresConnection(BaseModel): | ||
user: Annotated[str, Field(min_length=1)] | ||
password: Annotated[str, Field(min_length=1)] | ||
host: Annotated[str, Field(min_length=1)] | ||
port: Annotated[int, Field(ge=1, le=65535)] # Default PostgreSQL port with constraints | ||
database: Annotated[str, Field(min_length=1)] | ||
|
||
def connection_string(self) -> str: | ||
return f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}" | ||
|
||
def test_connection(self) -> bool: | ||
"""Test the connection to the PostgreSQL database.""" | ||
connection_string = self.connection_string() | ||
try: | ||
engine = create_engine(connection_string) | ||
with engine.connect() as _: | ||
# If the connection is successful, return True | ||
return True | ||
except SQLAlchemyError as e: | ||
print(f"Connection failed: {e}") | ||
return False | ||
|
||
|
||
class Input(BaseModel): | ||
input_text: str | ||
conn_str: Optional[PostgresConnection] = None | ||
|
||
|
||
class OpeaText2SQL(OpeaComponent): | ||
"""A specialized text to sql component derived from OpeaComponent for interacting with TGI services and Database. | ||
Attributes: | ||
client: An instance of the client for text to sql generation and execution. | ||
""" | ||
|
||
def __init__(self, name: str, description: str, config: dict = None): | ||
super().__init__(name, ServiceType.TEXT2SQL.name.lower(), description, config) | ||
|
||
async def check_health(self) -> bool: | ||
"""Checks the health of the TGI service. | ||
Returns: | ||
bool: True if the service is reachable and healthy, False otherwise. | ||
""" | ||
try: | ||
response = llm.generate(["Hello, how are you?"]) | ||
return True | ||
except Exception as e: | ||
return False | ||
|
||
async def invoke(self, input: Input): | ||
url = input.conn_str.connection_string() | ||
"""Execute a SQL query using the custom SQL agent. | ||
Args: | ||
input (str): The user's input. | ||
url (str): The URL of the database to connect to. | ||
Returns: | ||
dict: The result of the SQL execution. | ||
""" | ||
db = SQLDatabase.from_uri(url, **sql_params) | ||
logger.info("Starting Agent") | ||
agent_executor = custom_create_sql_agent( | ||
llm=llm, | ||
verbose=True, | ||
toolkit=CustomSQLDatabaseToolkit(llm=llm, db=db), | ||
agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION, | ||
agent_executor_kwargs={"return_intermediate_steps": True}, | ||
) | ||
|
||
result = await agent_executor.ainvoke(input) | ||
|
||
query = [] | ||
for log, _ in result["intermediate_steps"]: | ||
if log.tool == "sql_db_query": | ||
query.append(log.tool_input) | ||
result["sql"] = query[0].replace("Observation", "") | ||
return result |
Oops, something went wrong.