Skip to content

Commit

Permalink
add sqlalchemy support and mongodb support
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyongchao committed Oct 24, 2024
1 parent bf7a7a7 commit 74f63cb
Show file tree
Hide file tree
Showing 10 changed files with 865 additions and 455 deletions.
70 changes: 37 additions & 33 deletions lazyllm/docs/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,7 +1002,7 @@
)

add_chinese_doc(
"SqlManager.reset_tables",
"SqlManager.reset_table_info_dict",
"""\
根据描述表结构的字典设置SqlManager所使用的数据表。注意:若表在数据库中不存在将会自动创建,若存在则会校验所有字段的一致性。
字典格式关键字示例如下。
Expand Down Expand Up @@ -1034,7 +1034,7 @@
)

add_english_doc(
"SqlManager.reset_tables",
"SqlManager.reset_table_info_dict",
"""\
Set the data tables used by SqlManager according to the dictionary describing the table structure.
Note that if the table does not exist in the database, it will be automatically created, and if it exists, all field consistencies will be checked.
Expand Down Expand Up @@ -1068,81 +1068,85 @@
)

add_chinese_doc(
"SqlManager.check_connection",
"SqlManagerBase",
"""\
检查当前SqlManager的连接状态
SqlManagerBase是与数据库进行交互的专用工具。它提供了连接数据库,设置、创建、检查数据表,插入数据,执行查询的方法
**Returns:**\n
- bool: 连接成功(True), 连接失败(False)
- str: 连接成功为"Success" 否则为具体的失败信息.
Arguments:
db_type (str): 目前仅支持"PostgreSQL",后续会增加"MySQL", "MS SQL"
user (str): username
password (str): password
host (str): 主机名或IP
port (int): 端口号
db_name (str): 数据仓库名
tables_info_dict (dict): 数据表的描述
options_str (str): k1=v1&k2=v2形式表示的选项设置
""",
)

add_english_doc(
"SqlManager.check_connection",
"SqlManagerBase",
"""\
Check the current connection status of the SqlManager.
SqlManagerBase is a specialized tool for interacting with databases.
It provides methods for creating tables, executing queries, and performing updates on databases.
**Returns:**\n
- bool: True if the connection is successful, False if it fails.
- str: "Success" if the connection is successful; otherwise, it provides specific failure information.
Arguments:
db_type (str): Currently only "PostgreSQL" is supported, with "MySQL" and "MS SQL" to be added later.
user (str): Username for connection
password (str): Password for connection
host (str): Hostname or IP
port (int): Port number
db_name (str): Name of the database
tables_info_dict (dict): Description of the data tables
options_str (str): Options represented in the format k1=v1&k2=v2
""",
)

add_chinese_doc(
"SqlManager.reset_tables",
"SqlManagerBase.check_connection",
"""\
根据提供的表结构设置数据库链接。
若数据库中已存在表项则检查一致性,否则创建数据表
Args:
tables_info_dict (dict): 数据表的描述
检查当前SqlManagerBase的连接状态。
**Returns:**\n
- bool: 设置成功(True), 设置失败(False)
- str: 设置成功为"Success" 否则为具体的失败信息.
- bool: 连接成功(True), 连接失败(False)
- str: 连接成功为"Success" 否则为具体的失败信息.
""",
)

add_english_doc(
"SqlManager.reset_tables",
"SqlManagerBase.check_connection",
"""\
Set database connection based on the provided table structure.
Check consistency if the table items already exist in the database, otherwise create the data table.
Args:
tables_info_dict (dict): Description of the data tables
Check the current connection status of the SqlManagerBase.
**Returns:**\n
- bool: True if set successfully, False if set failed
- str: "Success" if set successfully, otherwise specific failure information.
- bool: True if the connection is successful, False if it fails.
- str: "Success" if the connection is successful; otherwise, it provides specific failure information.
""",
)

add_chinese_doc(
"SqlManager.get_query_result_in_json",
"SqlManagerBase.execute_to_json",
"""\
执行SQL查询并返回JSON格式的结果。
""",
)

add_english_doc(
"SqlManager.get_query_result_in_json",
"SqlManagerBase.execute_to_json",
"""\
Executes a SQL query and returns the result in JSON format.
""",
)

add_chinese_doc(
"SqlManager.execute_sql_update",
"SqlManagerBase.execute",
"""\
在SQLite数据库上执行SQL插入或更新脚本。
""",
)

add_english_doc(
"SqlManager.execute_sql_update",
"SqlManagerBase.execute",
"""\
Execute insert or update script.
""",
Expand Down
7 changes: 6 additions & 1 deletion lazyllm/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
ReWOOAgent,
)
from .classifier import IntentClassifier
from .sql import SQLiteManger, SqlManager, SqlCall
from .sql import SqlManagerBase, SQLiteManger, SqlManager, MonogDBManager, DBResult, DBStatus, SqlCall

from .tools.http_tool import HttpTool

__all__ = [
Expand All @@ -28,8 +29,12 @@
"ReWOOAgent",
"IntentClassifier",
"SentenceSplitter",
"SqlManagerBase",
"SQLiteManger",
"SqlManager",
"MonogDBManager",
"DBResult",
"DBStatus",
"SqlCall",
"HttpTool",
]
6 changes: 4 additions & 2 deletions lazyllm/tools/sql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .sql_tool import SQLiteManger, SqlCall, SqlManager
from .sql_manager import SqlManager, SqlManagerBase, SQLiteManger, DBResult, DBStatus
from .mongodb_manager import MonogDBManager
from .sql_call import SqlCall

__all__ = ["SqlCall", "SQLiteManger", "SqlManager"]
__all__ = ["SqlCall", "SqlManagerBase", "SQLiteManger", "SqlManager", "MonogDBManager", "DBResult", "DBStatus"]
73 changes: 73 additions & 0 deletions lazyllm/tools/sql/db_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from enum import Enum, unique
from typing import List, Union
from pydantic import BaseModel
from abc import ABC, abstractmethod


@unique
class DBStatus(Enum):
    """Status code reported by database-manager operations.

    ``@unique`` guarantees that no two members share the same value.
    """

    # The operation completed without error.
    SUCCESS = 0
    # The operation failed.
    FAIL = 1


class DBResult(BaseModel):
    """Uniform return value for database-manager operations.

    Constructing ``DBResult()`` with no arguments yields a successful result
    (``status=DBStatus.SUCCESS``, ``detail="Success"``, ``result=None``).
    """

    # Whether the operation succeeded; defaults to success.
    status: DBStatus = DBStatus.SUCCESS
    # Human-readable message; carries the error text when status is FAIL.
    detail: str = "Success"
    # Optional list payload produced by the operation (e.g. fetched records).
    result: Union[List, None] = None


class DBManager(ABC):
    """Abstract base class for database managers addressed by a connection URL.

    Class attributes:
        DB_TYPE_SUPPORTED: lower-case database type names accepted by
            ``reset_engine``.
        DB_DRIVER_MAP: database type -> driver name that is appended to the
            URL scheme (for example ``mysql+pymysql://``).

    Instance attributes set by ``reset_engine``:
        db_type: lower-cased database type.
        status / detail: outcome of the last ``reset_engine`` call.
        _conn_url: connection URL built from the credentials (only when the
            database type is supported).
        _desc: description text exposed through the ``desc`` property.
    """

    DB_TYPE_SUPPORTED = {"postgresql", "mysql", "mssql", "sqlite", "mongodb"}
    DB_DRIVER_MAP = {"mysql": "pymysql"}

    def __init__(
        self,
        db_type: str,
        user: str,
        password: str,
        host: str,
        port: int,
        db_name: str,
        options_str: str = "",
    ) -> None:
        """Build the connection URL for the given database.

        Raises:
            ValueError: if ``reset_engine`` reports a failure (for example an
                unsupported ``db_type``).
        """
        db_result = self.reset_engine(db_type, user, password, host, port, db_name, options_str)
        if db_result.status != DBStatus.SUCCESS:
            raise ValueError(db_result.detail)

    def reset_engine(self, db_type, user, password, host, port, db_name, options_str) -> "DBResult":
        """(Re)compute the connection URL and reset the manager state.

        Args:
            db_type: database type name; matched case-insensitively against
                ``DB_TYPE_SUPPORTED``.
            user, password, host, port, db_name: components of the URL.
            options_str: accepted for interface compatibility; it is not
                applied to the URL by this base implementation.

        Returns:
            DBResult: a FAIL result when ``db_type`` is unsupported, otherwise
            a successful result. (The original fell off the end and returned
            ``None`` on success, which made ``__init__`` crash.)
        """
        db_type_lower = db_type.lower()
        self.db_type = db_type_lower
        self._desc = ""
        if db_type_lower not in self.DB_TYPE_SUPPORTED:
            # Keep the instance state in agreement with the returned result.
            self.status = DBStatus.FAIL
            self.detail = f"{db_type} not supported"
            return DBResult(status=self.status, detail=self.detail)

        if db_type_lower in self.DB_DRIVER_MAP:
            scheme = f"{db_type_lower}+{self.DB_DRIVER_MAP[db_type_lower]}"
        else:
            scheme = db_type_lower
        self._conn_url = f"{scheme}://{user}:{password}@{host}:{port}/{db_name}"
        self.status = DBStatus.SUCCESS
        self.detail = ""
        return DBResult()

    @abstractmethod
    def execute_to_json(self, statement):
        """Execute ``statement`` and return the result serialized as JSON."""
        pass

    @property
    def desc(self):
        """Description text of the managed database objects (read-only)."""
        return self._desc

    def _is_str_or_nested_dict(self, value):
        """Return True if ``value`` is a str, or a dict whose values all are
        (recursively) strs or such dicts."""
        if isinstance(value, str):
            return True
        if isinstance(value, dict):
            return all(self._is_str_or_nested_dict(v) for v in value.values())
        return False

    def _validate_desc(self, d):
        """Return True if ``d`` is a dict whose values are strs or nested
        str-valued dicts."""
        return isinstance(d, dict) and all(self._is_str_or_nested_dict(v) for v in d.values())

    def _serialize_uncommon_type(self, obj):
        """Fallback serializer, used as the ``default=`` hook of ``json.dumps``.

        Returns ``str(obj)`` for any object that is not one of the common
        JSON-compatible Python types; returns ``None`` otherwise.
        """
        # isinstance() takes the candidate types as ONE tuple argument; the
        # original passed them positionally, which always raised TypeError.
        common_types = (int, str, float, bool, tuple, list, dict)
        if not isinstance(obj, common_types):
            return str(obj)
        return None
127 changes: 127 additions & 0 deletions lazyllm/tools/sql/mongodb_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import json
import pydantic
from pymongo import MongoClient
from .db_manager import DBManager, DBStatus, DBResult


class CollectionDesc(pydantic.BaseModel):
    """Validated shape of the dictionary describing a collection.

    Both keys are required and both values must be dicts.
    """

    # Structure of the documents stored in the collection.
    # NOTE(review): ``schema`` may shadow an attribute of pydantic.BaseModel
    # -- confirm this does not emit a warning with the pinned pydantic version.
    schema: dict
    # Descriptions accompanying the schema; presumably keyed like ``schema``
    # -- TODO confirm against callers.
    schema_desc: dict


class MonogDBManager(DBManager):
    """Manager for one collection of a MongoDB database.

    ``DBManager.__init__`` is not invoked; ``reset_client`` performs the
    MongoDB-specific setup (client creation and connection check) instead.

    Attributes set by ``reset_client``:
        db_type: always ``"mongodb"``.
        conn_url: ``mongodb://user:password@host:port/``.
        db_name / collection_name: names of the managed database/collection.
        extra_fields: options parsed from ``options_str`` (empty dict if none).
        client / collection: the ``MongoClient`` and the bound collection.
        status / detail: outcome of the most recent setup or failed query.
    """

    def __init__(self, user, password, host, port, db_name, collection_name, options_str=""):
        """Connect to MongoDB and bind to ``db_name.collection_name``.

        Raises:
            ValueError: if the connection check fails.
        """
        result = self.reset_client(user, password, host, port, db_name, collection_name, options_str)
        self.status, self.detail = result.status, result.detail
        if self.status != DBStatus.SUCCESS:
            raise ValueError(self.detail)

    def reset_client(self, user, password, host, port, db_name, collection_name, options_str="") -> DBResult:
        """(Re)create the client and bind to the target collection.

        Args:
            user, password, host, port: components of the MongoDB URL.
            db_name: database to use.
            collection_name: collection to use.
            options_str: options in ``k1=v1&k2=v2`` form; they are parsed into
                ``self.extra_fields`` but not passed to the client.

        Returns:
            DBResult: the failed connection-check result, or a successful
            result.

        Raises:
            ValueError: if an item of ``options_str`` lacks an ``=``.
        """
        db_type_lower = "mongodb"
        # The base-class __init__ is bypassed, so record the type here.
        self.db_type = db_type_lower
        self.status = DBStatus.SUCCESS
        self.detail = ""
        self.conn_url = f"{db_type_lower}://{user}:{password}@{host}:{port}/"
        self.db_name = db_name
        self.collection_name = collection_name
        # Always define extra_fields so readers never hit AttributeError.
        self.extra_fields = {}
        if options_str:
            # Split on the first '=' only, so values may themselves contain '='.
            self.extra_fields = dict(pair.split("=", 1) for pair in options_str.split("&"))
        self.client = MongoClient(self.conn_url)
        result = self.check_connection()
        self.collection = self.client[self.db_name][self.collection_name]
        # Kept as a string for consistency with DBManager and set_desc.
        self._desc = ""
        if result.status != DBStatus.SUCCESS:
            return result
        # Existence of the database/collection is not verified here (that
        # check was disabled in the original code).
        return DBResult()

    def check_connection(self) -> DBResult:
        """Probe the server; return a FAIL result carrying the error text
        if the server cannot be reached."""
        try:
            # server_info() forces a round trip to the server.
            _ = self.client.server_info()
            return DBResult()
        except Exception as e:
            return DBResult(status=DBStatus.FAIL, detail=str(e))

    def drop_database(self) -> DBResult:
        """Drop the managed database unless the manager is in a failed state."""
        if self.status != DBStatus.SUCCESS:
            return DBResult(status=self.status, detail=self.detail, result=None)
        self.client.drop_database(self.db_name)
        return DBResult()

    def drop_collection(self) -> DBResult:
        """Drop the managed collection."""
        db = self.client[self.db_name]
        db[self.collection_name].drop()
        return DBResult()

    def insert(self, statement):
        """Insert one document (dict) or several documents (list).

        Returns:
            DBResult: FAIL if ``statement`` is neither a dict nor a list.
        """
        if isinstance(statement, dict):
            self.collection.insert_one(statement)
        elif isinstance(statement, list):
            self.collection.insert_many(statement)
        else:
            return DBResult(status=DBStatus.FAIL, detail=f"statement type {type(statement)} not supported", result=None)
        return DBResult()

    def update(self, filter: dict, value: dict, is_many: bool = True):
        """Apply the update document ``value`` to documents matching
        ``filter``; all matches when ``is_many`` is true, else only one."""
        if is_many:
            self.collection.update_many(filter, value)
        else:
            self.collection.update_one(filter, value)
        return DBResult()

    def delete(self, filter: dict, is_many: bool = True):
        """Delete documents matching ``filter``; all matches when ``is_many``
        is true, else only one."""
        if is_many:
            self.collection.delete_many(filter)
        else:
            self.collection.delete_one(filter)
        # Return a result like insert/update do (the original returned None).
        return DBResult()

    def select(self, query, projection: dict[str, bool] = None, limit: int = None):
        """Find documents matching ``query``.

        Args:
            query: MongoDB filter document.
            projection: optional field inclusion/exclusion map.
            limit: maximum number of documents; ``None`` means no limit.

        Returns:
            DBResult: ``result`` holds the list of matching documents.
        """
        cursor = self.collection.find(query, projection)
        # The original had these branches inverted, so a given limit was
        # ignored and limit=None was passed to Cursor.limit().
        if limit is not None:
            cursor = cursor.limit(limit)
        return DBResult(result=list(cursor))

    def execute(self, statement):
        """Run an aggregation pipeline given as a JSON string.

        Returns:
            DBResult: the aggregated documents, or a FAIL result carrying the
            error text if parsing or execution raises.
        """
        try:
            pipeline_list = json.loads(statement)
            result = self.collection.aggregate(pipeline_list)
            return DBResult(result=list(result))
        except Exception as e:
            return DBResult(status=DBStatus.FAIL, detail=str(e))

    def execute_to_json(self, statement) -> str:
        """Run an aggregation pipeline and return its result as a JSON string.

        On failure the error is recorded in ``self.status``/``self.detail``
        and an empty string is returned.
        """
        dbresult = self.execute(statement)
        # Compare against the SUCCESS member; the original compared against
        # the DBStatus class itself, which is never equal.
        if dbresult.status != DBStatus.SUCCESS:
            self.status, self.detail = dbresult.status, dbresult.detail
            return ""
        return json.dumps(dbresult.result, ensure_ascii=False, default=self._serialize_uncommon_type)

    @property
    def desc(self):
        """Description text of the collection (read-only)."""
        return self._desc

    def set_desc(self, schema_and_desc: dict) -> DBResult:
        """Validate a collection description and store its text rendering.

        Args:
            schema_and_desc: dict with the keys ``schema`` and ``schema_desc``;
                the values of each must be strs or nested str-valued dicts.

        Returns:
            DBResult: FAIL with the validation message if the input is
            malformed, otherwise a successful result. The description is
            cleared whenever validation fails.
        """
        self._desc = ""
        try:
            collection_desc = CollectionDesc.model_validate(schema_and_desc)
        except pydantic.ValidationError as e:
            return DBResult(status=DBStatus.FAIL, detail=str(e))
        if not self._validate_desc(collection_desc.schema) or not self._validate_desc(collection_desc.schema_desc):
            err_msg = "key and value in desc shoule be str or nested str dict"
            return DBResult(status=DBStatus.FAIL, detail=err_msg)
        # Write to the backing field: ``desc`` is a property without a setter.
        schema_text = json.dumps(collection_desc.schema, ensure_ascii=False, indent=4)
        schema_desc_text = json.dumps(collection_desc.schema_desc, ensure_ascii=False, indent=4)
        self._desc = (
            "Collection schema:\n"
            f"{schema_text}\n"
            "Collection schema description:\n"
            f"{schema_desc_text}"
        )
        return DBResult()
Loading

0 comments on commit 74f63cb

Please sign in to comment.