Refactor text2sql based on ERAG (#1080)
Signed-off-by: Yao, Qing <[email protected]>
yao531441 authored Jan 2, 2025
1 parent 90a8634 commit 2cfd014
Showing 20 changed files with 479 additions and 421 deletions.
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
@@ -15,7 +15,7 @@
/comps/prompt_registry/ [email protected]
/comps/feedback_management/ [email protected]
/comps/chathistory/ [email protected]
-/comps/texttosql/ [email protected]
+/comps/text2sql/ [email protected]
/comps/text2image/ [email protected]
/comps/reranks/ [email protected]
/comps/retrievers/ [email protected]
@@ -3,7 +3,7 @@

# this file should be run in the root of the repo
services:
-  texttosql:
+  text2sql:
    build:
-      dockerfile: comps/texttosql/langchain/Dockerfile
-    image: ${REGISTRY:-opea}/texttosql:${TAG:-latest}
+      dockerfile: comps/text2sql/src/Dockerfile
+    image: ${REGISTRY:-opea}/text2sql:${TAG:-latest}
1 change: 1 addition & 0 deletions comps/cores/mega/constants.py
@@ -33,6 +33,7 @@ class ServiceType(Enum):
    TEXT2IMAGE = 16
    ANIMATION = 17
    IMAGE2IMAGE = 18
+    TEXT2SQL = 19


class MegaServiceEndpoint(Enum):
Empty file.
@@ -32,9 +32,9 @@ services:
    volumes:
      - ./chinook.sql:/docker-entrypoint-initdb.d/chinook.sql

-  texttosql_service:
-    image: opea/texttosql:latest
-    container_name: texttosql_service
+  text2sql_service:
+    image: opea/text2sql:latest
+    container_name: text2sql_service
    ports:
      - "9090:8090"
    environment:
Empty file.
@@ -21,13 +21,13 @@ COPY comps /home/user/comps

RUN pip install --no-cache-dir --upgrade pip setuptools && \
    if [ ${ARCH} = "cpu" ]; then \
-    pip install --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu -r /home/user/comps/texttosql/langchain/requirements.txt; \
+    pip install --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu -r /home/user/comps/text2sql/src/requirements.txt; \
    else \
-    pip install --no-cache-dir -r /home/user/comps/texttosql/langchain/requirements.txt; \
+    pip install --no-cache-dir -r /home/user/comps/text2sql/src/requirements.txt; \
    fi

ENV PYTHONPATH=$PYTHONPATH:/home/user

-WORKDIR /home/user/comps/texttosql/langchain/
+WORKDIR /home/user/comps/text2sql/src/

-ENTRYPOINT ["python", "main.py"]
+ENTRYPOINT ["python", "opea_text2sql_microservice.py"]
154 changes: 154 additions & 0 deletions comps/text2sql/src/README.md
@@ -0,0 +1,154 @@
# 🛢 Text-to-SQL Microservice

In today's data-driven world, the ability to efficiently extract insights from databases is crucial. However, querying databases often requires specialized knowledge of SQL (Structured Query Language) and database schemas, which can be a barrier for non-technical users. This is where the Text-to-SQL microservice comes into play, leveraging the power of LLMs and agentic frameworks to bridge the gap between human language and database queries. This microservice is built on the LangChain/LangGraph frameworks.

The microservice enables a wide range of use cases, making it a versatile tool for businesses, researchers, and individuals alike. Users can generate queries based on natural language questions, enabling them to quickly retrieve relevant data from their databases. Additionally, the service can be integrated into ChatBots, allowing for natural language interactions and providing accurate responses based on the underlying data. Furthermore, it can be utilized to build custom dashboards, enabling users to visualize and analyze insights based on their specific requirements, all through the power of natural language.

---

## 🛠️ Features

**Implement SQL queries based on input text**: Transforms user-provided natural language into SQL queries and executes them to retrieve data from SQL databases.
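
For example, given the question "Find the total number of Albums." (the same request used in the examples below), the service would generate and run a query along the lines of `SELECT COUNT(*) FROM "Album";` against the connected database; the exact SQL is produced by the LLM and may vary.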

---

## ⚙️ Implementation

The text-to-SQL microservice can be implemented with various frameworks and supports various types of SQL databases.

### 🔗 Utilizing Text-to-SQL with the LangChain framework

The following guide provides setup instructions and comprehensive details for the Text-to-SQL microservice via LangChain. In this configuration, we will employ PostgresDB as our example database to showcase this microservice.

---

### 🚀 Start Microservice with Python (Option 1)

#### Install Requirements

```bash
pip install -r requirements.txt
```

#### Start PostgresDB Service

We will use the [Chinook](https://github.com/lerocha/chinook-database) sample database by default to test the Text-to-SQL microservice. Chinook is a sample database ideal for demos and for testing ORM tools targeting single and multiple database servers.

```bash
export POSTGRES_USER=postgres
export POSTGRES_PASSWORD=testpwd
export POSTGRES_DB=chinook

cd comps/text2sql

docker run --name postgres-db --ipc=host -e POSTGRES_USER=${POSTGRES_USER} -e POSTGRES_HOST_AUTH_METHOD=trust -e POSTGRES_DB=${POSTGRES_DB} -e POSTGRES_PASSWORD=${POSTGRES_PASSWORD} -p 5442:5432 -d -v ./chinook.sql:/docker-entrypoint-initdb.d/chinook.sql postgres:latest
```
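
Optionally, verify that the Chinook schema was loaded by listing the tables from inside the container (a quick sanity check using the `psql` client bundled with the Postgres image):

```bash
# List the tables created by chinook.sql (expects Album, Artist, Track, etc.)
docker exec postgres-db psql -U ${POSTGRES_USER} -d ${POSTGRES_DB} -c '\dt'
```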

#### Start TGI Service

```bash
export HUGGINGFACEHUB_API_TOKEN=${HUGGINGFACEHUB_API_TOKEN}
export LLM_MODEL_ID="mistralai/Mistral-7B-Instruct-v0.3"
export TGI_PORT=8008

docker run -d --name="text2sql-tgi-endpoint" --ipc=host -p $TGI_PORT:80 -v ./data:/data --shm-size 1g -e HF_TOKEN=${HUGGINGFACEHUB_API_TOKEN} -e model=${LLM_MODEL_ID} ghcr.io/huggingface/text-generation-inference:2.1.0 --model-id $LLM_MODEL_ID
```

#### Verify the TGI Service

```bash
export your_ip=$(hostname -I | awk '{print $1}')
curl http://${your_ip}:${TGI_PORT}/generate \
-X POST \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":17, "do_sample": true}}' \
-H 'Content-Type: application/json'
```

#### Setup Environment Variables

```bash
export TGI_LLM_ENDPOINT="http://${your_ip}:${TGI_PORT}"
```

#### Start Text-to-SQL Microservice with Python Script

Start the Text-to-SQL microservice with the command below.

```bash
python3 opea_text2sql_microservice.py
```

---

### 🚀 Start Microservice with Docker (Option 2)

#### Start PostgreSQL Database Service

Please refer to the section [Start PostgresDB Service](#start-postgresdb-service).

#### Start TGI Service

Please refer to the section [Start TGI Service](#start-tgi-service).

#### Setup Environment Variables

```bash
export TGI_LLM_ENDPOINT="http://${your_ip}:${TGI_PORT}"
```

#### Build Docker Image

```bash
cd GenAIComps/
docker build -t opea/text2sql:latest -f comps/text2sql/src/Dockerfile .
```
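
If the build succeeds, the image should now be available locally (optional check):

```bash
# Confirm the freshly built image exists
docker images opea/text2sql:latest
```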

#### Run Docker with CLI (Option A)

```bash
export TGI_LLM_ENDPOINT="http://${your_ip}:${TGI_PORT}"

docker run --runtime=runc --name="comps-langchain-text2sql" -p 9090:8080 --ipc=host -e llm_endpoint_url=${TGI_LLM_ENDPOINT} opea/text2sql:latest
```
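
Once the container is up, you can follow its logs to confirm the microservice started (the container name comes from the command above):

```bash
# Tail the microservice logs; stop with Ctrl+C
docker logs -f comps-langchain-text2sql
```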

#### Run via Docker Compose (Option B)

- Setup Environment Variables.

```bash
export TGI_LLM_ENDPOINT=http://${your_ip}:${TGI_PORT}
export HF_TOKEN=${HUGGINGFACEHUB_API_TOKEN}
export LLM_MODEL_ID="mistralai/Mistral-7B-Instruct-v0.3"
export POSTGRES_USER=postgres
export POSTGRES_PASSWORD=testpwd
export POSTGRES_DB=chinook
```

- Start the services.

```bash
docker compose -f docker_compose_text2sql.yaml up
```
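
To confirm the services are running, check the container status and logs (the container name `text2sql_service` comes from the compose file shown earlier in this commit):

```bash
# Show the running text2sql container and inspect its startup logs
docker ps --filter "name=text2sql_service"
docker logs text2sql_service
```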

---

### ✅ Invoke the Microservice

The Text-to-SQL microservice exposes the following API endpoints:

- Test Database Connection

```bash
curl --location http://${your_ip}:9090/v1/postgres/health \
--header 'Content-Type: application/json' \
--data '{"user": "'${POSTGRES_USER}'","password": "'${POSTGRES_PASSWORD}'","host": "'${your_ip}'", "port": "5442", "database": "'${POSTGRES_DB}'"}'
```

- Execute SQL Query from input text

```bash
curl http://${your_ip}:9090/v1/text2sql \
-X POST \
-d '{"input_text": "Find the total number of Albums.","conn_str": {"user": "'${POSTGRES_USER}'","password": "'${POSTGRES_PASSWORD}'","host": "'${your_ip}'", "port": "5442", "database": "'${POSTGRES_DB}'"}}' \
-H 'Content-Type: application/json'
```
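
The response is a JSON object containing the agent's final answer together with its intermediate steps, plus a `sql` field holding the generated query (populated from `result["sql"]` in `comps/text2sql/src/integrations/opea.py` below); the exact answer text depends on the LLM.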
2 changes: 2 additions & 0 deletions comps/text2sql/src/__init__.py
@@ -0,0 +1,2 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
File renamed without changes.
2 changes: 2 additions & 0 deletions comps/text2sql/src/integrations/__init__.py
@@ -0,0 +1,2 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
122 changes: 122 additions & 0 deletions comps/text2sql/src/integrations/opea.py
@@ -0,0 +1,122 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import os
from typing import Annotated, Optional

from langchain.agents.agent_types import AgentType
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_huggingface import HuggingFaceEndpoint
from pydantic import BaseModel, Field
from sqlalchemy import create_engine
from sqlalchemy.exc import SQLAlchemyError

from comps import CustomLogger, OpeaComponent, ServiceType
from comps.text2sql.src.integrations.sql_agent import CustomSQLDatabaseToolkit, custom_create_sql_agent

logger = CustomLogger("comps-text2sql")
logflag = os.getenv("LOGFLAG", False)

sql_params = {
"max_string_length": 3600,
}

generation_params = {
"max_new_tokens": 1024,
"top_k": 10,
"top_p": 0.95,
"temperature": 0.01,
"repetition_penalty": 1.03,
"streaming": True,
}

TGI_LLM_ENDPOINT = os.environ.get("TGI_LLM_ENDPOINT")

llm = HuggingFaceEndpoint(
endpoint_url=TGI_LLM_ENDPOINT,
task="text-generation",
**generation_params,
)


class PostgresConnection(BaseModel):
user: Annotated[str, Field(min_length=1)]
password: Annotated[str, Field(min_length=1)]
host: Annotated[str, Field(min_length=1)]
port: Annotated[int, Field(ge=1, le=65535)] # Default PostgreSQL port with constraints
database: Annotated[str, Field(min_length=1)]

def connection_string(self) -> str:
return f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"

def test_connection(self) -> bool:
"""Test the connection to the PostgreSQL database."""
connection_string = self.connection_string()
try:
engine = create_engine(connection_string)
with engine.connect() as _:
# If the connection is successful, return True
return True
except SQLAlchemyError as e:
print(f"Connection failed: {e}")
return False


class Input(BaseModel):
input_text: str
conn_str: Optional[PostgresConnection] = None


class OpeaText2SQL(OpeaComponent):
"""A specialized text to sql component derived from OpeaComponent for interacting with TGI services and Database.
Attributes:
client: An instance of the client for text to sql generation and execution.
"""

def __init__(self, name: str, description: str, config: dict = None):
super().__init__(name, ServiceType.TEXT2SQL.name.lower(), description, config)

async def check_health(self) -> bool:
"""Checks the health of the TGI service.
Returns:
bool: True if the service is reachable and healthy, False otherwise.
"""
try:
response = llm.generate(["Hello, how are you?"])
return True
except Exception as e:
return False

    async def invoke(self, input: Input):
        """Execute a SQL query generated from the user's input text using the custom SQL agent.

        Args:
            input (Input): The user's natural-language question plus the database connection details.

        Returns:
            dict: The result of the SQL execution, including intermediate steps and the generated SQL.
        """
        url = input.conn_str.connection_string()
db = SQLDatabase.from_uri(url, **sql_params)
logger.info("Starting Agent")
agent_executor = custom_create_sql_agent(
llm=llm,
verbose=True,
toolkit=CustomSQLDatabaseToolkit(llm=llm, db=db),
agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
agent_executor_kwargs={"return_intermediate_steps": True},
)

result = await agent_executor.ainvoke(input)

        # Collect the SQL statements the agent executed via the sql_db_query tool.
        query = []
        for log, _ in result["intermediate_steps"]:
            if log.tool == "sql_db_query":
                query.append(log.tool_input)
        # Expose the first generated query in the response under the "sql" key.
        result["sql"] = query[0].replace("Observation", "")
return result