From 74f63cb8270f6c60bbd039dd536583e57ba1f178 Mon Sep 17 00:00:00 2001 From: zhangyongchao Date: Thu, 24 Oct 2024 18:38:34 +0800 Subject: [PATCH] add sqlalchemy support and mongodb support --- lazyllm/docs/tools.py | 70 ++--- lazyllm/tools/__init__.py | 7 +- lazyllm/tools/sql/__init__.py | 6 +- lazyllm/tools/sql/db_manager.py | 73 +++++ lazyllm/tools/sql/mongodb_manager.py | 127 ++++++++ lazyllm/tools/sql/sql_call.py | 151 ++++++++++ lazyllm/tools/sql/sql_manager.py | 424 +++++++++++++++++++++++++++ lazyllm/tools/sql/sql_tool.py | 382 ------------------------ tests/charge_tests/test_engine.py | 11 +- tests/charge_tests/test_sql_tool.py | 69 ++--- 10 files changed, 865 insertions(+), 455 deletions(-) create mode 100644 lazyllm/tools/sql/db_manager.py create mode 100644 lazyllm/tools/sql/mongodb_manager.py create mode 100644 lazyllm/tools/sql/sql_call.py create mode 100644 lazyllm/tools/sql/sql_manager.py delete mode 100644 lazyllm/tools/sql/sql_tool.py diff --git a/lazyllm/docs/tools.py b/lazyllm/docs/tools.py index 6e335870..c581354d 100644 --- a/lazyllm/docs/tools.py +++ b/lazyllm/docs/tools.py @@ -1002,7 +1002,7 @@ ) add_chinese_doc( - "SqlManager.reset_tables", + "SqlManager.reset_table_info_dict", """\ 根据描述表结构的字典设置SqlManager所使用的数据表。注意:若表在数据库中不存在将会自动创建,若存在则会校验所有字段的一致性。 字典格式关键字示例如下。 @@ -1034,7 +1034,7 @@ ) add_english_doc( - "SqlManager.reset_tables", + "SqlManager.reset_table_info_dict", """\ Set the data tables used by SqlManager according to the dictionary describing the table structure. Note that if the table does not exist in the database, it will be automatically created, and if it exists, all field consistencies will be checked. @@ -1068,81 +1068,85 @@ ) add_chinese_doc( - "SqlManager.check_connection", + "SqlManagerBase", """\ -检查当前SqlManager的连接状态。 +SqlManagerBase是与数据库进行交互的专用工具。它提供了连接数据库,设置、创建、检查数据表,插入数据,执行查询的方法。 -**Returns:**\n -- bool: 连接成功(True), 连接失败(False) -- str: 连接成功为"Success" 否则为具体的失败信息. +Arguments: + db_type (str): 目前仅支持"PostgreSQL",后续会增加"MySQL", "MS SQL" + user (str): username + password (str): password + host (str): 主机名或IP + port (int): 端口号 + db_name (str): 数据仓库名 + tables_info_dict (dict): 数据表的描述 + options_str (str): k1=v1&k2=v2形式表示的选项设置 """, ) add_english_doc( - "SqlManager.check_connection", + "SqlManagerBase", """\ -Check the current connection status of the SqlManager. +SqlManagerBase is a specialized tool for interacting with databases. +It provides methods for creating tables, executing queries, and performing updates on databases. -**Returns:**\n -- bool: True if the connection is successful, False if it fails. -- str: "Success" if the connection is successful; otherwise, it provides specific failure information. +Arguments: + db_type (str): Currently only "PostgreSQL" is supported, with "MySQL" and "MS SQL" to be added later. + user (str): Username for connection + password (str): Password for connection + host (str): Hostname or IP + port (int): Port number + db_name (str): Name of the database + tables_info_dict (dict): Description of the data tables + options_str (str): Options represented in the format k1=v1&k2=v2 """, ) add_chinese_doc( - "SqlManager.reset_tables", + "SqlManagerBase.check_connection", """\ -根据提供的表结构设置数据库链接。 -若数据库中已存在表项则检查一致性,否则创建数据表 - -Args: - tables_info_dict (dict): 数据表的描述 +检查当前SqlManagerBase的连接状态。 **Returns:**\n -- bool: 设置成功(True), 设置失败(False) -- str: 设置成功为"Success" 否则为具体的失败信息. +- bool: 连接成功(True), 连接失败(False) +- str: 连接成功为"Success" 否则为具体的失败信息. """, ) add_english_doc( - "SqlManager.reset_tables", + "SqlManagerBase.check_connection", """\ -Set database connection based on the provided table structure. -Check consistency if the table items already exist in the database, otherwise create the data table. - -Args: - tables_info_dict (dict): Description of the data tables +Check the current connection status of the SqlManagerBase. **Returns:**\n -- bool: True if set successfully, False if set failed -- str: "Success" if set successfully, otherwise specific failure information. - +- bool: True if the connection is successful, False if it fails. +- str: "Success" if the connection is successful; otherwise, it provides specific failure information. """, ) add_chinese_doc( - "SqlManager.get_query_result_in_json", + "SqlManagerBase.execute_to_json", """\ 执行SQL查询并返回JSON格式的结果。 """, ) add_english_doc( - "SqlManager.get_query_result_in_json", + "SqlManagerBase.execute_to_json", """\ Executes a SQL query and returns the result in JSON format. """, ) add_chinese_doc( - "SqlManager.execute_sql_update", + "SqlManagerBase.execute", """\ 在SQLite数据库上执行SQL插入或更新脚本。 """, ) add_english_doc( - "SqlManager.execute_sql_update", + "SqlManagerBase.execute", """\ Execute insert or update script. """, diff --git a/lazyllm/tools/__init__.py b/lazyllm/tools/__init__.py index 0df3c274..cca15c1f 100644 --- a/lazyllm/tools/__init__.py +++ b/lazyllm/tools/__init__.py @@ -10,7 +10,8 @@ ReWOOAgent, ) from .classifier import IntentClassifier -from .sql import SQLiteManger, SqlManager, SqlCall +from .sql import SqlManagerBase, SQLiteManger, SqlManager, MonogDBManager, DBResult, DBStatus, SqlCall + from .tools.http_tool import HttpTool __all__ = [ @@ -28,8 +29,12 @@ "ReWOOAgent", "IntentClassifier", "SentenceSplitter", + "SqlManagerBase", "SQLiteManger", "SqlManager", + "MonogDBManager", + "DBResult", + "DBStatus", "SqlCall", "HttpTool", ] diff --git a/lazyllm/tools/sql/__init__.py b/lazyllm/tools/sql/__init__.py index d6396096..fcc0f306 100644 --- a/lazyllm/tools/sql/__init__.py +++ b/lazyllm/tools/sql/__init__.py @@ -1,3 +1,5 @@ -from .sql_tool import SQLiteManger, SqlCall, SqlManager +from .sql_manager import SqlManager, SqlManagerBase, SQLiteManger, DBResult, DBStatus +from .mongodb_manager import MonogDBManager +from .sql_call import SqlCall -__all__ = ["SqlCall", "SQLiteManger", "SqlManager"] +__all__ = ["SqlCall", "SqlManagerBase", "SQLiteManger", "SqlManager", "MonogDBManager", "DBResult", "DBStatus"] diff --git a/lazyllm/tools/sql/db_manager.py b/lazyllm/tools/sql/db_manager.py new file mode 100644 index 00000000..eefc3764 --- /dev/null +++ b/lazyllm/tools/sql/db_manager.py @@ -0,0 +1,73 @@ +from enum import Enum, unique +from typing import List, Union +from pydantic import BaseModel +from abc import ABC, abstractmethod + + +@unique +class DBStatus(Enum): + SUCCESS = 0 + FAIL = 1 + + +class DBResult(BaseModel): + status: DBStatus = DBStatus.SUCCESS + detail: str = "Success" + result: Union[List, None] = None + + +class DBManager(ABC): + DB_TYPE_SUPPORTED = set(["postgresql", "mysql", "mssql", "sqlite", "mongodb"]) + DB_DRIVER_MAP = {"mysql": "pymysql"} + + def __init__( + self, + db_type: str, + user: str, + password: str, + host: str, + port: int, + db_name: str, + options_str: str = "", + ) -> None: + db_result = self.reset_engine(db_type, user, password, host, port, db_name, options_str) + if db_result.status != DBStatus.SUCCESS: + raise ValueError(db_result.detail) + + def reset_engine(self, db_type, user, password, host, port, db_name, options_str): + db_type_lower = db_type.lower() + self.status = DBStatus.SUCCESS + self.detail = "" + self.db_type = db_type_lower + if db_type_lower not in self.DB_TYPE_SUPPORTED: + return DBResult(status=DBStatus.FAIL, detail=f"{db_type} not supported") + if db_type_lower in self.DB_DRIVER_MAP: + conn_url = ( + f"{db_type_lower}+{self.DB_DRIVER_MAP[db_type_lower]}://{user}:{password}@{host}:{port}/{db_name}" + ) + else: + conn_url = f"{db_type_lower}://{user}:{password}@{host}:{port}/{db_name}" + self._conn_url = conn_url + self._desc = "" + + @abstractmethod + def execute_to_json(self, statement): + pass + + @property + def desc(self): + return self._desc + + def _is_str_or_nested_dict(self, value): + if isinstance(value, str): + return True + elif isinstance(value, dict): + return all(self._is_str_or_nested_dict(v) for v in value.values()) + return False + + def _validate_desc(self, d): + return isinstance(d, dict) and all(self._is_str_or_nested_dict(v) for v in d.values()) + + def _serialize_uncommon_type(self, obj): + if not isinstance(obj, int, str, float, bool, tuple, list, dict): + return str(obj) diff --git a/lazyllm/tools/sql/mongodb_manager.py b/lazyllm/tools/sql/mongodb_manager.py new file mode 100644 index 00000000..860d4599 --- /dev/null +++ b/lazyllm/tools/sql/mongodb_manager.py @@ -0,0 +1,127 @@ +import json +import pydantic +from pymongo import MongoClient +from .db_manager import DBManager, DBStatus, DBResult + + +class CollectionDesc(pydantic.BaseModel): + schema: dict + schema_desc: dict + + +class MonogDBManager(DBManager): + def __init__(self, user, password, host, port, db_name, collection_name, options_str=""): + result = self.reset_client(user, password, host, port, db_name, collection_name, options_str) + self.status, self.detail = result.status, result.detail + if self.status != DBStatus.SUCCESS: + raise ValueError(self.detail) + + def reset_client(self, user, password, host, port, db_name, collection_name, options_str="") -> DBResult: + db_type_lower = "mongodb" + self.status = DBStatus.SUCCESS + self.detail = "" + conn_url = f"{db_type_lower}://{user}:{password}@{host}:{port}/" + self.conn_url = conn_url + self.db_name = db_name + self.collection_name = collection_name + if options_str: + extra_fields = { + key: value for key_value in options_str.split("&") for key, value in (key_value.split("="),) + } + self.extra_fields = extra_fields + self.client = MongoClient(self.conn_url) + result = self.check_connection() + self.collection = self.client[self.db_name][self.collection_name] + self._desc = {} + if result.status != DBStatus.SUCCESS: + return result + """ + if db_name not in self.client.list_database_names(): + return DBResult(status=DBStatus.FAIL, detail=f"Database {db_name} not found") + if collection_name not in self.client[db_name].list_collection_names(): + return DBResult(status=DBStatus.FAIL, detail=f"Collection {collection_name} not found") + """ + return DBResult() + + def check_connection(self) -> DBResult: + try: + # check connection status + _ = self.client.server_info() + return DBResult() + except Exception as e: + return DBResult(status=DBStatus.FAIL, detail=str(e)) + + def drop_database(self) -> DBResult: + if self.status != DBStatus.SUCCESS: + return DBResult(status=self.status, detail=self.detail, result=None) + self.client.drop_database(self.db_name) + return DBResult() + + def drop_collection(self) -> DBResult: + db = self.client[self.db_name] + db[self.collection_name].drop() + return DBResult() + + def insert(self, statement): + if isinstance(statement, dict): + self.collection.insert_one(statement) + elif isinstance(statement, list): + self.collection.insert_many(statement) + else: + return DBResult(status=DBStatus.FAIL, detail=f"statement type {type(statement)} not supported", result=None) + return DBResult() + + def update(self, filter: dict, value: dict, is_many: bool = True): + if is_many: + self.collection.update_many(filter, value) + else: + self.collection.update_one(filter, value) + return DBResult() + + def delete(self, filter: dict, is_many: bool = True): + if is_many: + self.collection.delete_many(filter) + else: + self.collection.delete_one(filter) + + def select(self, query, projection: dict[str, bool] = None, limit: int = None): + if limit is not None: + result = self.collection.find(query, projection) + else: + result = self.collection.find(query, projection).limit(limit) + return DBResult(result=list(result)) + + def execute(self, statement): + try: + pipeline_list = json.loads(statement) + result = self.collection.aggregate(pipeline_list) + return DBResult(result=list(result)) + except Exception as e: + return DBResult(status=DBStatus.FAIL, detail=str(e)) + + def execute_to_json(self, statement) -> str: + dbresult = self.execute(statement) + if dbresult.status != DBStatus: + self.status, self.detail = dbresult.status, dbresult.detail + return "" + str_result = json.dumps(dbresult.result, ensure_ascii=False, default=self._serialize_uncommon_type) + return str_result + + @property + def desc(self): + return self._desc + + def set_desc(self, schema_and_desc: dict) -> DBResult: + self._desc = "" + try: + collection_desc = CollectionDesc.model_validate(schema_and_desc) + except pydantic.ValidationError as e: + return DBResult(status=DBStatus.FAIL, detail=str(e)) + if not self._validate_desc(collection_desc.schema) or not self._validate_desc(collection_desc.schema_desc): + err_msg = "key and value in desc shoule be str or nested str dict" + return DBResult(status=DBStatus.FAIL, detail=err_msg) + self.desc = "Collection schema:\n" + self.desc += json.dumps(collection_desc.schema, ensure_ascii=False, indent=4) + self.desc += "Collection schema description:\n" + self.desc += json.dumps(collection_desc.schema, ensure_ascii=False, indent=4) + return DBResult() diff --git a/lazyllm/tools/sql/sql_call.py b/lazyllm/tools/sql/sql_call.py new file mode 100644 index 00000000..59945b69 --- /dev/null +++ b/lazyllm/tools/sql/sql_call.py @@ -0,0 +1,151 @@ +from lazyllm.module import ModuleBase +from lazyllm.components import ChatPrompter +from lazyllm.tools.utils import chat_history_to_str +from lazyllm import pipeline, globals, bind, _0, switch +from typing import List, Any, Dict, Union, Callable +import datetime +import re +from .db_manager import DBManager + +sql_query_instruct_template = """ +Given the following SQL tables and current date {current_date}, your job is to write sql queries in {db_type} given a user’s request. + +{desc} + +Alert: Just replay the sql query in a code block start with keyword "sql" +""" # noqa E501 + + +mongodb_query_instruct_template = """ +Current date is {current_date}. +You are a seasoned expert with 10 years of experience in crafting NoSQL queries for {db_type}. +I will provide a collection description in a specified format. +Your task is to analyze the user_question, which follows certain guidelines, and generate a NoSQL MongoDB aggregation pipeline accordingly. + +{desc} + +Note: Please return the query in a code block and start with the keyword "mongodb". +""" # noqa E501 + +db_explain_instruct_template = """ +According to chat history +``` +{{history_info}} +``` + +bellowing {statement_type} is executed + +``` +{{query}} +``` +the result is +``` +{{result}} +``` +""" + + +class SqlCall(ModuleBase): + EXAMPLE_TITLE = "Here are some example: " + + def __init__( + self, + llm, + sql_manager: DBManager, + sql_examples: str = "", + sql_post_func: Callable = None, + use_llm_for_sql_result=True, + return_trace: bool = False, + ) -> None: + super().__init__(return_trace=return_trace) + self._sql_tool = sql_manager + self.sql_post_func = sql_post_func + + if sql_manager.db_type == "mongodb": + self._query_prompter = ChatPrompter(instruction=mongodb_query_instruct_template).pre_hook( + self.sql_query_promt_hook + ) + statement_type = "mongodb pipeline" + else: + self._query_prompter = ChatPrompter(instruction=sql_query_instruct_template).pre_hook( + self.sql_query_promt_hook + ) + statement_type = "sql query" + + self._llm_query = llm.share(prompt=self._query_prompter) + self._answer_prompter = ChatPrompter( + instruction=db_explain_instruct_template.format(statement_type=statement_type) + ).pre_hook(self.sql_explain_prompt_hook) + self._llm_answer = llm.share(prompt=self._answer_prompter) + self._pattern = re.compile(r"```sql(.+?)```", re.DOTALL) + self.example = sql_examples + with pipeline() as sql_execute_ppl: + sql_execute_ppl.exec = self._sql_tool.execute_to_json + if use_llm_for_sql_result: + sql_execute_ppl.concate = (lambda q, r: [q, r]) | bind(sql_execute_ppl.input, _0) + sql_execute_ppl.llm_answer = self._llm_answer + with pipeline() as ppl: + ppl.llm_query = self._llm_query + ppl.sql_extractor = self.extract_sql_from_response + with switch(judge_on_full_input=False) as ppl.sw: + ppl.sw.case[False, lambda x: x] + ppl.sw.case[True, sql_execute_ppl] + self._impl = ppl + + def sql_query_promt_hook( + self, + input: Union[str, List, Dict[str, str], None] = None, + history: List[Union[List[str], Dict[str, Any]]] = [], + tools: Union[List[Dict[str, Any]], None] = None, + label: Union[str, None] = None, + ): + current_date = datetime.datetime.now().strftime("%Y-%m-%d") + schema_desc = self._sql_tool.desc + if self.example: + schema_desc += f"\n{self.EXAMPLE_TITLE}\n{self.example}\n" + if not isinstance(input, str): + raise ValueError(f"Unexpected type for input: {type(input)}") + return ( + dict(current_date=current_date, db_type=self._sql_tool.db_type, desc=schema_desc, user_query=input), + history, + tools, + label, + ) + + def sql_explain_prompt_hook( + self, + input: Union[str, List, Dict[str, str], None] = None, + history: List[Union[List[str], Dict[str, Any]]] = [], + tools: Union[List[Dict[str, Any]], None] = None, + label: Union[str, None] = None, + ): + explain_query = "Tell the user based on the execution results, making sure to keep the language consistent \ + with the user's input and don't translate original result." + if not isinstance(input, list) and len(input) != 2: + raise ValueError(f"Unexpected type for input: {type(input)}") + assert "root_input" in globals and self._llm_answer._module_id in globals["root_input"] + user_query = globals["root_input"][self._llm_answer._module_id] + globals.pop("root_input") + history_info = chat_history_to_str(history, user_query) + return ( + dict(history_info=history_info, query=input[0], result=input[1], explain_query=explain_query), + history, + tools, + label, + ) + + def extract_sql_from_response(self, str_response: str) -> tuple[bool, str]: + # Remove the triple backticks if present + matches = self._pattern.findall(str_response) + if matches: + # Return the first match + extracted_content = matches[0].strip() + return True, extracted_content if not self.sql_post_func else self.sql_post_func(extracted_content) + else: + return False, str_response + + def forward(self, input: str, llm_chat_history: List[Dict[str, Any]] = None): + globals["root_input"] = {self._llm_answer._module_id: input} + if self._module_id in globals["chat_history"]: + globals["chat_history"][self._llm_query._module_id] = globals["chat_history"][self._module_id] + return self._impl(input) diff --git a/lazyllm/tools/sql/sql_manager.py b/lazyllm/tools/sql/sql_manager.py new file mode 100644 index 00000000..fb5745c1 --- /dev/null +++ b/lazyllm/tools/sql/sql_manager.py @@ -0,0 +1,424 @@ +import json +from typing import Union +import sqlalchemy +from sqlalchemy.exc import SQLAlchemyError, OperationalError, ProgrammingError +from sqlalchemy.orm import declarative_base, DeclarativeMeta +import pydantic +from .db_manager import DBManager, DBStatus, DBResult +from pathlib import Path + + +class SqlManagerBase(DBManager): + def __init__( + self, + db_type: str, + user: str, + password: str, + host: str, + port: int, + db_name: str, + options_str: str = "", + ) -> None: + db_result = self.reset_engine(db_type, user, password, host, port, db_name, options_str) + if db_result.status != DBStatus.SUCCESS: + raise ValueError(self.detail) + + def reset_engine( + self, + db_type: str, + user: str, + password: str, + host: str, + port: int, + db_name: str, + options_str: str = "", + ): + super().reset_engine(db_type, user, password, host, port, db_name, options_str) + self._engine = sqlalchemy.create_engine(self._conn_url) + self._desc = "" + extra_fields = {} + if options_str: + extra_fields = { + key: value for key_value in options_str.split("&") for key, value in (key_value.split("="),) + } + self._extra_fields = extra_fields + result = self.check_connection() + if result.status != DBStatus.SUCCESS: + return result + db_result = self.get_all_tables() + if result.status != DBStatus.SUCCESS: + return result + self._visible_tables = db_result.result + return self.set_desc() + + def get_all_tables(self) -> DBResult: + inspector = sqlalchemy.inspect(self._engine) + table_names = inspector.get_table_names(schema=self._extra_fields.get("schema", None)) + if self.status != DBStatus.SUCCESS: + return DBResult(status=self.status, detail=self.detail, result=None) + return DBResult(result=table_names) + + def check_connection(self) -> DBResult: + try: + with self._engine.connect() as _: + return DBResult() + except SQLAlchemyError as e: + return DBResult(status=DBStatus.FAIL, detail=str(e)) + + @property + def visible_tables(self): + return self._visible_tables + + def set_visible_tables(self, tables: list[str]) -> DBResult: + db_result = self.get_all_tables() + if db_result.status != DBStatus.SUCCESS: + return db_result + all_tables_in_db = set(db_result.result) + visible_tables = [] + failed_tables = [] + for ele in tables: + if ele in all_tables_in_db: + visible_tables.append(ele) + else: + failed_tables.append(ele) + if len(tables) != len(visible_tables): + db_result = DBResult(status=DBStatus.FAIL, detail=f"{failed_tables} missing in database") + else: + db_result = DBResult() + self._visible_tables = visible_tables + return db_result + + def _get_table_columns(self, table_name: str): + inspector = sqlalchemy.inspect(self._engine) + columns = inspector.get_columns(table_name, schema=self._extra_fields.get("schema", None)) + return columns + + def set_desc(self, tables_desc: dict = {}) -> DBResult: + self._desc = "" + if not isinstance(tables_desc, dict): + return DBResult(status=DBStatus.FAIL, detail=f"desc type {type(tables_desc)} not supported") + if len(tables_desc) == 0: + return DBResult(status=DBStatus.FAIL, detail="Empty desc") + if len(self.visible_tables) == 0: + return DBResult() + self._desc = "The tables description is as follows\n```\n" + for table_name in self.visible_tables: + self._desc += f"Table {table_name}\n(\n" + table_columns = self._get_table_columns(table_name) + for i, column in enumerate(table_columns): + self._desc += f" {column['name']} {column['type']}" + if i != len(table_columns) - 1: + self._desc += "," + self._desc += "\n" + self._desc += ");\n" + if table_name in tables_desc: + self._desc += tables_desc[table_name] + "\n\n" + self._desc += "```\n" + return DBResult() + + @property + def desc(self) -> str: + return self._desc + + def execute(self, statement) -> DBResult: + if isinstance(statement, str): + statement = sqlalchemy.text(statement) + if isinstance( + statement, + (sqlalchemy.TextClause, sqlalchemy.Select, sqlalchemy.Insert, sqlalchemy.Update, sqlalchemy.Delete), + ): + status = DBStatus.SUCCESS + detail = "" + result = None + try: + with self._engine.connect() as conn: + cursor_result = conn.execute(statement) + conn.commit() + if cursor_result.returns_rows: + columns = list(cursor_result.keys()) + result = [dict(zip(columns, row)) for row in cursor_result] + except OperationalError as e: + status = DBStatus.FAIL + detail = f"ERROR: {str(e)}" + finally: + if "conn" in locals(): + conn.close() + return DBResult(status=status, detail=detail, result=result) + else: + return DBResult(status=DBStatus.FAIL, detail="statement type not supported") + + def execute_to_json(self, statement) -> str: + dbresult = self.execute(statement) + if dbresult.status != DBStatus.SUCCESS: + self.status, self.detail = dbresult.status, dbresult.detail + return "" + if dbresult.result is None: + return "" + str_result = json.dumps(dbresult.result, ensure_ascii=False, default=self._serialize_uncommon_type) + return str_result + + def _create_by_script(self, table: str) -> DBResult: + status = DBStatus.SUCCESS + detail = "Success" + try: + with self._engine.connect() as conn: + conn.execute(sqlalchemy.text(table)) + conn.commit() + except OperationalError as e: + status = DBStatus.FAIL + detail = f"ERROR: {str(e)}" + finally: + if "conn" in locals(): + conn.close() + return DBResult(status=status, detail=detail) + + def _match_exist_table(self, table: DeclarativeMeta) -> DBResult: + status = DBStatus.SUCCESS + detail = f"Table {table.__tablename__} already exists." + metadata = sqlalchemy.MetaData() + exist_table = sqlalchemy.Table(table.__tablename__, metadata, autoload_with=self._engine) + if len(table.__table__.columns) != len(exist_table.columns): + status = DBStatus.FAIL + detail += ( + f"\n Column number mismatch: {len(table.__table__.columns)} VS " f"{len(exist_table.columns)}(exists)" + ) + return DBResult(status=status, detail=detail) + for exist_column in exist_table.columns: + target_column = getattr(table, exist_column.name) + exist_type = type(exist_column.type).__visit_name__.lower() + target_type = type(target_column.type).__visit_name__.lower() + if target_type is not sqlalchemy.types.TypeEngine and exist_type != target_type: + detail += f"type mismatch {exist_type} vs {target_type}" + return DBResult(status=DBStatus.FAIL, detail=detail) + for attr in ["primary_key", "nullable"]: + if getattr(exist_column, attr) != getattr(target_column, attr): + detail += f"{attr} mismatch {getattr(exist_column, attr)} vs {getattr(target_column, attr)}" + return DBResult(status=DBStatus.FAIL, detail=detail) + return DBResult() + + def _create_by_api(self, table: DeclarativeMeta) -> DBResult: + try: + table.__table__.create(bind=self._engine) + return DBResult() + except ProgrammingError as e: + if "already exists" in str(e): + return self._match_exist_table(table) + + def create(self, table: Union[str, DeclarativeMeta]) -> DBResult: + status = DBStatus.SUCCESS + detail = "Success" + if isinstance(table, str): + return self._create_by_script(table) + elif isinstance(table, DeclarativeMeta): + return self._create_by_api(table) + else: + status = DBStatus.FAIL + detail += "\n Unsupported Type: {table}" + return DBResult(status=status, detail=detail) + + def drop(self, table) -> DBResult: + metadata = sqlalchemy.MetaData() + if isinstance(table, str): + tablename = table + elif isinstance(table, DeclarativeMeta): + tablename = table.__tablename__ + else: + return DBResult(status=DBStatus.FAIL, detail=f"{table} type unsupported") + Table = sqlalchemy.Table(tablename, metadata, autoload_with=self._engine) + Table.drop(self._engine, checkfirst=True) + return DBResult() + + def insert(self, statement) -> DBResult: + if isinstance(statement, (str, sqlalchemy.Select)): + return self.execute(statement) + elif isinstance(statement, dict): + table_name = statement.get("table_name", None) + table_data = statement.get("table_data", []) + returning = statement.get("returning", []) + if not table_name: + return DBResult(status=DBStatus.FAIL, detail="No table_name found") + if not table_data: + return DBResult(status=DBStatus.FAIL, detail="No table_data found") + metadata = sqlalchemy.MetaData() + table = sqlalchemy.Table(table_name, metadata, autoload_with=self._engine) + if not returning: + statement = sqlalchemy.insert(table).values(table_data) + else: + return_columns = [sqlalchemy.column(ele) for ele in returning] + statement = (sqlalchemy.insert(table).values(table_data)).returning(*return_columns) + return self.execute(statement) + else: + return DBResult(status=DBStatus.FAIL, detail="statement type not supported") + + def update(self, statement) -> DBResult: + if isinstance(statement, (str, sqlalchemy.Update)): + return self.execute(statement) + else: + return DBResult(status=DBStatus.FAIL, detail="statement type not supported") + + def delete(self, statement) -> DBResult: + if isinstance(statement, (str, sqlalchemy.Delete)): + if isinstance(statement, str): + tmp = statement.rstrip() + if len(tmp.split()) == 1: + statement = f"DELETE FROM {tmp}" + return self.execute(statement) + else: + return DBResult(status=DBStatus.FAIL, detail="statement type not supported") + + def select(self, statement) -> DBResult: + if isinstance(statement, (str, sqlalchemy.Select)): + return self.execute(statement) + else: + return DBResult(status=DBStatus.FAIL, detail="statement type not supported") + + +class ColumnInfo(pydantic.BaseModel): + name: str + data_type: str + comment: str = "" + # At least one column should be True + is_primary_key: bool = False + nullable: bool = True + + +class TableInfo(pydantic.BaseModel): + name: str + comment: str = "" + columns: list[ColumnInfo] + + +class TablesInfo(pydantic.BaseModel): + tables: list[TableInfo] + + +class SqlManager(SqlManagerBase): + PYTYPE_TO_SQL_MAP = { + "integer": sqlalchemy.Integer, + "string": sqlalchemy.Text, + "text": sqlalchemy.Text, + "boolean": sqlalchemy.Boolean, + "float": sqlalchemy.DOUBLE_PRECISION, + "datetime": sqlalchemy.DateTime, + "bytes": sqlalchemy.LargeBinary, + "bool": sqlalchemy.Boolean, + "date": sqlalchemy.Date, + "time": sqlalchemy.Time, + "list": sqlalchemy.ARRAY, + "dict": sqlalchemy.JSON, + "uuid": sqlalchemy.Uuid, + "any": sqlalchemy.types.TypeEngine, + } + + def __init__( + self, + db_type: str, + user: str, + password: str, + host: str, + port: int, + db_name: str, + tables_info_dict: dict, + options_str: str = "", + ) -> None: + self.reset_engine(db_type, user, password, host, port, db_name, tables_info_dict, options_str) + + def reset_engine(self, db_type, user, password, host, port, db_name, tables_info_dict, options_str): + self._tables_info_dict = {} + super().reset_engine(db_type, user, password, host, port, db_name, options_str) + db_result = self.reset_table_info_dict(tables_info_dict) + self.status = db_result.status + self.detail = db_result.detail + if self.status != DBStatus.SUCCESS: + raise ValueError(self.detail) + + def reset_table_info_dict(self, tables_info_dict: dict) -> DBResult: + self.status = DBStatus.SUCCESS + self.detail = "Success" + self._tables_info_dict = tables_info_dict + try: + tables_info = TablesInfo.model_validate(tables_info_dict) + except pydantic.ValidationError as e: + self.status, self.detail = DBStatus.FAIL, str(e) + return DBResult(status=DBStatus.FAIL, detail=str(e)) + # Create or Check tables + created_tables = [] + for table_info in tables_info.tables: + TableClass = self._create_table_cls(table_info) + db_result = self.create(TableClass) + if db_result.status != DBStatus.SUCCESS: + # drop partial created table + for created_table in created_tables: + self.drop(created_table) + return db_result + created_tables.append(TableClass) + + db_result = self.set_visible_tables([ele.__tablename__ for ele in created_tables]) + if db_result.status != DBStatus.SUCCESS: + return db_result + return self.set_desc() + + def _create_table_cls(self, table_info: TableInfo) -> DeclarativeMeta: + Base = declarative_base() + attrs = {"__tablename__": table_info.name} + for column_info in table_info.columns: + column_type = column_info.data_type.lower() + is_nullable = column_info.nullable + column_name = column_info.name + is_primary = column_info.is_primary_key + real_type = self.PYTYPE_TO_SQL_MAP[column_type] + attrs[column_name] = sqlalchemy.Column(real_type, nullable=is_nullable, primary_key=is_primary) + TableClass = type(table_info.name.capitalize(), (Base,), attrs) + return TableClass + + def set_desc(self) -> DBResult: + self._desc = "" + if not self._tables_info_dict: + return DBResult() + try: + tables_info = TablesInfo.model_validate(self._tables_info_dict) + except pydantic.ValidationError as e: + self.status, self.detail = DBStatus.FAIL, str(e) + return DBResult(status=DBStatus.FAIL, detail=str(e)) + self._desc = "The tables description is as follows\n```\n" + for table_info in tables_info.tables: + self._desc += f'Table "{table_info.name}"' + if table_info.comment: + self._desc += f' comment "{table_info.comment}"' + self._desc += "\n(\n" + real_columns = self._get_table_columns(table_info.name) + column_type_dict = {} + for real_column in real_columns: + column_type_dict[real_column["name"]] = real_column["type"] + for i, column_info in enumerate(table_info.columns): + self._desc += f"{column_info.name} {column_type_dict[column_info.name]}" + if column_info.comment: + self._desc += f' comment "{column_info.comment}"' + if i != len(table_info.columns) - 1: + self._desc += "," + self._desc += "\n" + self._desc += ");\n" + self._desc += "```\n" + return DBResult() + + +class SQLiteManger(SqlManager): + def __init__(self, db_file: str, tables_info_dict: dict = {}): + result = self.reset_engine(db_file, tables_info_dict) + self.status, self.detail = result.status, result.detail + if self.status != DBStatus.SUCCESS: + raise ValueError(self.detail) + + def reset_engine(self, db_file: str, tables_info_dict: dict): + self.db_type = "sqlite" + self.status = DBStatus.SUCCESS + self.detail = "" + if not Path(db_file).is_file(): + with Path(db_file).open("w") as _: + pass + if not Path(db_file).is_file(): + return DBResult(status=DBStatus.FAIL, detail=f"Create file {db_file} failed") + self._conn_url = f"sqlite:///{db_file}" + self._extra_fields = {} + self._engine = sqlalchemy.create_engine(self._conn_url) + return self.reset_table_info_dict(tables_info_dict) diff --git a/lazyllm/tools/sql/sql_tool.py b/lazyllm/tools/sql/sql_tool.py deleted file mode 100644 index ebb94de4..00000000 --- a/lazyllm/tools/sql/sql_tool.py +++ /dev/null @@ -1,382 +0,0 @@ -from lazyllm.module import ModuleBase -import lazyllm -from lazyllm.components import ChatPrompter -from lazyllm.tools.utils import chat_history_to_str -from lazyllm import pipeline, globals, bind, _0, switch -import json -from typing import List, Any, Dict, Union -from pathlib import Path -import datetime -import re -import sqlalchemy -from sqlalchemy.exc import SQLAlchemyError, OperationalError -from sqlalchemy.orm import declarative_base -import pydantic - - -class ColumnInfo(pydantic.BaseModel): - name: str - data_type: str - comment: str = "" - # At least one column should be True - is_primary_key: bool = False - nullable: bool = True - - -class TableInfo(pydantic.BaseModel): - name: str - comment: str = "" - columns: list[ColumnInfo] - - -class TablesInfo(pydantic.BaseModel): - tables: list[TableInfo] - - -class SqlManager: - DB_TYPE_SUPPORTED = set(["PostgreSQL", "MySQL", "MSSQL", "SQLite"]) - SUPPORTED_DATA_TYPES = { - "integer": sqlalchemy.Integer, - "string": sqlalchemy.String, - "boolean": sqlalchemy.Boolean, - "float": sqlalchemy.Float, - } - - def __init__( - self, - db_type: str, - user: str, - password: str, - host: str, - port: int, - db_name: str, - tables_info_dict: dict, - options_str: str = "", - ) -> None: - conn_url = f"{db_type.lower()}://{user}:{password}@{host}:{port}/{db_name}" - self.reset_db(db_type, conn_url, tables_info_dict, options_str) - - def reset_tables(self, tables_info_dict: dict) -> tuple[bool, str]: - existing_tables = set(self.get_all_tables()) - try: - tables_info = TablesInfo.model_validate(tables_info_dict) - except pydantic.ValidationError as e: - lazyllm.LOG.warning(str(e)) - return False, str(e) - for table_info in tables_info.tables: - if table_info.name not in existing_tables: - # create table - cur_rt, cur_err_msg = self._create_table(table_info.model_dump()) - else: - # check table - cur_rt, cur_err_msg = self._check_columns_match(table_info.model_dump()) - if not cur_rt: - lazyllm.LOG.warning(f"cur_err_msg: {cur_err_msg}") - return cur_rt, cur_err_msg - rt, err_msg = self._set_tables_desc_prompt(tables_info_dict) - if not rt: - lazyllm.LOG.warning(err_msg) - return True, "Success" - - def reset_db(self, db_type: str, conn_url: str, tables_info_dict: dict, options_str=""): - assert db_type in self.DB_TYPE_SUPPORTED - extra_fields = {} - if options_str: - extra_fields = { - key: value for key_value in options_str.split("&") for key, value in (key_value.split("="),) - } - self.db_type = db_type - self.conn_url = conn_url - self.extra_fields = extra_fields - self.engine = sqlalchemy.create_engine(conn_url) - self.tables_prompt = "" - rt, err_msg = self.reset_tables(tables_info_dict) - if not rt: - self.err_msg = err_msg - self.err_code = 1001 - else: - self.err_code = 0 - - def get_tables_desc(self): - return self.tables_prompt - - def check_connection(self) -> tuple[bool, str]: - try: - with self.engine.connect() as _: - return True, "Success" - except SQLAlchemyError as e: - return False, str(e) - - def get_query_result_in_json(self, sql_script) -> str: - str_result = "" - try: - with self.engine.connect() as conn: - result = conn.execute(sqlalchemy.text(sql_script)) - columns = list(result.keys()) - result_dict = [dict(zip(columns, row)) for row in result] - str_result = json.dumps(result_dict, ensure_ascii=False) - except OperationalError as e: - str_result = f"ERROR: {str(e)}" - finally: - if "conn" in locals(): - conn.close() - return str_result - - def get_all_tables(self) -> list: - inspector = sqlalchemy.inspect(self.engine) - table_names = inspector.get_table_names(schema=self.extra_fields.get("schema", None)) - return table_names - - def execute_sql_update(self, sql_script): - rt, err_msg = True, "Success" - try: - with self.engine.connect() as conn: - conn.execute(sqlalchemy.text(sql_script)) - conn.commit() - except OperationalError as e: - lazyllm.LOG.warning(f"sql error: {str(e)}") - rt, err_msg = False, str(e) - finally: - if "conn" in locals(): - conn.close() - return rt, err_msg - - def _get_table_columns(self, table_name: str): - inspector = sqlalchemy.inspect(self.engine) - columns = inspector.get_columns(table_name, schema=self.extra_fields.get("schema", None)) - return columns - - def _create_table(self, table_info_dict: dict) -> tuple[bool, str]: - rt, err_msg = True, "Success" - try: - table_info = TableInfo.model_validate(table_info_dict) - except pydantic.ValidationError as e: - return False, str(e) - try: - with self.engine.connect() as conn: - Base = declarative_base() - # build table dynamically - attrs = {"__tablename__": table_info.name} - for column_info in table_info.columns: - column_type = column_info.data_type.lower() - is_nullable = column_info.nullable - column_name = column_info.name - is_primary = column_info.is_primary_key - if column_type not in self.SUPPORTED_DATA_TYPES: - return False, f"Unsupported column type: {column_type}" - real_type = self.SUPPORTED_DATA_TYPES[column_type] - attrs[column_name] = sqlalchemy.Column(real_type, nullable=is_nullable, primary_key=is_primary) - TableClass = type(table_info.name.capitalize(), (Base,), attrs) - Base.metadata.create_all(self.engine) - except OperationalError as e: - rt, err_msg = False, f"ERROR: {str(e)}" - finally: - if "conn" in locals(): - conn.close() - return rt, err_msg - - def _delete_rows_by_name(self, table_name): - metadata = sqlalchemy.MetaData() - metadata.reflect(bind=self.engine) - rt, err_msg = True, "Success" - try: - with self.engine.connect() as conn: - table = sqlalchemy.Table(table_name, metadata, autoload_with=self.engine) - delete = table.delete() - conn.execute(delete) - conn.commit() - except SQLAlchemyError as e: - rt, err_msg = False, str(e) - return rt, err_msg - - def _drop_table_by_name(self, table_name): - metadata = sqlalchemy.MetaData() - metadata.reflect(bind=self.engine) - rt, err_msg = True, "Success" - try: - table = sqlalchemy.Table(table_name, metadata, autoload_with=self.engine) - table.drop(bind=self.engine, checkfirst=True) - except SQLAlchemyError as e: - lazyllm.LOG.warning("GET SQLAlchemyError") - rt, err_msg = False, str(e) - return rt, err_msg - - def _check_columns_match(self, table_info_dict: dict): - try: - table_info = TableInfo.model_validate(table_info_dict) - except pydantic.ValidationError as e: - return False, str(e) - real_columns = self._get_table_columns(table_info.name) - tmp_dict = {} - for real_column in real_columns: - tmp_dict[real_column["name"]] = (real_column["type"], real_column["nullable"]) - for column_info in table_info.columns: - if column_info.name not in tmp_dict: - return False, f"Table {table_info.name} exists but column {column_info.name} does not." - real_column = tmp_dict[column_info.name] - column_type = column_info.data_type.lower() - if column_type not in self.SUPPORTED_DATA_TYPES: - return False, f"Unsupported column type: {column_type}" - # 1. check data type - # string type sometimes changes to other type (such as varchar) - real_type_cls = real_column[0].__class__ - if column_type != real_type_cls.__name__.lower() and not issubclass( - real_type_cls, self.SUPPORTED_DATA_TYPES[column_type] - ): - return ( - False, - f"Table {table_info.name} exists but column {column_info.name} data_type mismatch" - f": {column_info.data_type} vs {real_column[0].__class__.__name__}", - ) - # 2. check nullable - if column_info.nullable != real_column[1]: - return False, f"Table {table_info.name} exists but column {column_info.name} nullable mismatch" - if len(tmp_dict) > len(table_info.columns): - return ( - False, - f"Table {table_info.name} exists but has more columns. {len(tmp_dict)} vs {len(table_info.columns)}", - ) - return True, "Match" - - def _set_tables_desc_prompt(self, tables_info_dict: dict) -> str: - try: - tables_info = TablesInfo.model_validate(tables_info_dict) - except pydantic.ValidationError as e: - return False, str(e) - self.tables_prompt = "The tables description is as follows\n```\n" - for table_info in tables_info.tables: - self.tables_prompt += f'Table "{table_info.name}"' - if table_info.comment: - self.tables_prompt += f' comment "{table_info.comment}"' - self.tables_prompt += "\n(\n" - for i, column_info in enumerate(table_info.columns): - self.tables_prompt += f"{column_info.name} {column_info.data_type}" - if column_info.comment: - self.tables_prompt += f' comment "{column_info.comment}"' - if i != len(table_info.columns) - 1: - self.tables_prompt += "," - self.tables_prompt += "\n" - self.tables_prompt += ");\n" - self.tables_prompt += "```\n" - return True, "Success" - - -class SQLiteManger(SqlManager): - def __init__(self, db_file, tables_info_dict: dict): - assert Path(db_file).is_file() - super().reset_db("SQLite", f"sqlite:///{db_file}", tables_info_dict) - - -sql_query_instruct_template = """ -Given the following SQL tables and current date {current_date}, your job is to write sql queries in {db_type} given a user’s request. -Alert: Just replay the sql query in a code block. - -{sql_tables} -""" # noqa E501 - -sql_explain_instruct_template = """ -According to chat history -``` -{history_info} -``` - -bellowing sql query is executed - -``` -{sql_query} -``` -the sql result is -``` -{sql_result} -``` -""" - - -class SqlCall(ModuleBase): - def __init__( - self, - llm, - sql_manager: SqlManager, - sql_examples: str = "", - use_llm_for_sql_result=True, - return_trace: bool = False, - ) -> None: - super().__init__(return_trace=return_trace) - self._sql_tool = sql_manager - self._query_prompter = ChatPrompter(instruction=sql_query_instruct_template).pre_hook(self.sql_query_promt_hook) - self._llm_query = llm.share(prompt=self._query_prompter) - self._answer_prompter = ChatPrompter(instruction=sql_explain_instruct_template).pre_hook( - self.sql_explain_prompt_hook - ) - self._llm_answer = llm.share(prompt=self._answer_prompter) - self._pattern = re.compile(r"```sql(.+?)```", re.DOTALL) - with pipeline() as sql_execute_ppl: - sql_execute_ppl.exec = self._sql_tool.get_query_result_in_json - if use_llm_for_sql_result: - sql_execute_ppl.concate = (lambda q, r: [q, r]) | bind(sql_execute_ppl.input, _0) - sql_execute_ppl.llm_answer = self._llm_answer - with pipeline() as ppl: - ppl.llm_query = self._llm_query - ppl.sql_extractor = self.extract_sql_from_response - with switch(judge_on_full_input=False) as ppl.sw: - ppl.sw.case[False, lambda x: x] - ppl.sw.case[True, sql_execute_ppl] - self._impl = ppl - - def sql_query_promt_hook( - self, - input: Union[str, List, Dict[str, str], None] = None, - history: List[Union[List[str], Dict[str, Any]]] = [], - tools: Union[List[Dict[str, Any]], None] = None, - label: Union[str, None] = None, - ): - current_date = datetime.datetime.now().strftime("%Y-%m-%d") - sql_tables_info = self._sql_tool.get_tables_desc() - if not isinstance(input, str): - raise ValueError(f"Unexpected type for input: {type(input)}") - return ( - dict( - current_date=current_date, db_type=self._sql_tool.db_type, sql_tables=sql_tables_info, user_query=input - ), - history, - tools, - label, - ) - - def sql_explain_prompt_hook( - self, - input: Union[str, List, Dict[str, str], None] = None, - history: List[Union[List[str], Dict[str, Any]]] = [], - tools: Union[List[Dict[str, Any]], None] = None, - label: Union[str, None] = None, - ): - explain_query = "Tell the user based on the sql execution results, making sure to keep the language consistent \ - with the user's input and don't translate original result." - if not isinstance(input, list) and len(input) != 2: - raise ValueError(f"Unexpected type for input: {type(input)}") - assert "root_input" in globals and self._llm_answer._module_id in globals["root_input"] - user_query = globals["root_input"][self._llm_answer._module_id] - globals.pop("root_input") - history_info = chat_history_to_str(history, user_query) - return ( - dict(history_info=history_info, sql_query=input[0], sql_result=input[1], explain_query=explain_query), - history, - tools, - label, - ) - - def extract_sql_from_response(self, str_response: str) -> tuple[bool, str]: - # Remove the triple backticks if present - matches = self._pattern.findall(str_response) - if matches: - # Return the first match - extracted_content = matches[0].strip() - return True, extracted_content - else: - return False, str_response - - def forward(self, input: str, llm_chat_history: List[Dict[str, Any]] = None): - globals["root_input"] = {self._llm_answer._module_id: input} - if self._module_id in globals["chat_history"]: - globals["chat_history"][self._llm_query._module_id] = globals["chat_history"][self._module_id] - return self._impl(input) diff --git a/tests/charge_tests/test_engine.py b/tests/charge_tests/test_engine.py index 5fcee548..e9b502cb 100644 --- a/tests/charge_tests/test_engine.py +++ b/tests/charge_tests/test_engine.py @@ -2,7 +2,7 @@ from lazyllm.engine import LightEngine import pytest from .utils import SqlEgsData, get_sql_init_keywords -from lazyllm.tools import SqlManager +from lazyllm.tools import SqlManager, DBStatus from .tools import (get_current_weather_code, get_current_weather_vars, get_current_weather_doc, get_n_day_weather_forecast_code, multiply_tool_code, add_tool_code, dummy_code) @@ -128,9 +128,11 @@ def test_sql_call(self): # 1. Init: insert data to database tmp_sql_manager = SqlManager(db_type, username, password, host, port, database, SqlEgsData.TEST_TABLES_INFO) for table_name in SqlEgsData.TEST_TABLES: - rt, err_msg = tmp_sql_manager._delete_rows_by_name(table_name) + db_result = tmp_sql_manager.delete(table_name) + assert db_result.status == DBStatus.SUCCESS for insert_script in SqlEgsData.TEST_INSERT_SCRIPTS: - tmp_sql_manager.execute_sql_update(insert_script) + db_result = tmp_sql_manager.execute(insert_script) + assert db_result.status == DBStatus.SUCCESS # 2. Engine: build and chat resources = [ @@ -160,7 +162,8 @@ def test_sql_call(self): # 3. Release: delete data and table from database for table_name in SqlEgsData.TEST_TABLES: - rt, err_msg = tmp_sql_manager._drop_table_by_name(table_name) + db_result = tmp_sql_manager.drop(table_name) + assert db_result.status == DBStatus.SUCCESS def test_register_tools(self): resources = [ diff --git a/tests/charge_tests/test_sql_tool.py b/tests/charge_tests/test_sql_tool.py index a6dc0d51..dca131ff 100644 --- a/tests/charge_tests/test_sql_tool.py +++ b/tests/charge_tests/test_sql_tool.py @@ -1,5 +1,5 @@ import unittest -from lazyllm.tools import SQLiteManger, SqlCall, SqlManager +from lazyllm.tools import SQLiteManger, SqlCall, SqlManager, DBStatus import lazyllm import tempfile from pathlib import Path @@ -11,11 +11,13 @@ class TestSqlManager(unittest.TestCase): @classmethod - def clean_obsolete_tables(cls, sql_manager): + def clean_obsolete_tables(cls, sql_manager: SqlManager): today = datetime.datetime.now() pattern = r"^(?:employee|sales)_(\d{8})_(\w+)" OBSOLETE_DAYS = 2 - existing_tables = sql_manager.get_all_tables() + db_result = sql_manager.get_all_tables() + assert db_result.status == DBStatus.SUCCESS, db_result.detail + existing_tables = db_result.result for table_name in existing_tables: match = re.match(pattern, table_name) if not match: @@ -23,7 +25,7 @@ def clean_obsolete_tables(cls, sql_manager): table_create_date = datetime.datetime.strptime(match.group(1), "%Y%m%d") delta = (today - table_create_date).days if delta >= OBSOLETE_DAYS: - sql_manager._drop_table_by_name(table_name) + sql_manager.drop(table_name) @classmethod def setUpClass(cls): @@ -31,10 +33,8 @@ def setUpClass(cls): filepath = str(Path(tempfile.gettempdir()) / f"{str(uuid.uuid4().hex)}.db") cls.db_filepath = filepath - with open(filepath, "w") as _: - pass cls.sql_managers.append(SQLiteManger(filepath, SqlEgsData.TEST_TABLES_INFO)) - for db_type in ["PostgreSQL"]: + for db_type in []: # ["PostgreSQL"]: username, password, host, port, database = get_sql_init_keywords(db_type) cls.sql_managers.append( SqlManager(db_type, username, password, host, port, database, SqlEgsData.TEST_TABLES_INFO) @@ -42,10 +42,11 @@ def setUpClass(cls): for sql_manager in cls.sql_managers: cls.clean_obsolete_tables(sql_manager) for table_name in SqlEgsData.TEST_TABLES: - rt, err_msg = sql_manager._delete_rows_by_name(table_name) - assert rt, err_msg + db_result = sql_manager.delete(table_name) + assert db_result.status == DBStatus.SUCCESS, db_result.detail for insert_script in SqlEgsData.TEST_INSERT_SCRIPTS: - sql_manager.execute_sql_update(insert_script) + db_result = sql_manager.execute(insert_script) + assert db_result.status == DBStatus.SUCCESS, db_result.detail # Recommend to use sensenova, gpt-4o, qwen online model sql_llm = lazyllm.OnlineChatModule(source="sensenova") @@ -58,58 +59,60 @@ def tearDownClass(cls): # restore to clean database for sql_manager in cls.sql_managers: for table_name in SqlEgsData.TEST_TABLES: - rt, err_msg = sql_manager._drop_table_by_name(table_name) - assert rt, f"sql_manager table {table_name} error: {err_msg}" + db_result = sql_manager.drop(table_name) + assert db_result.status == DBStatus.SUCCESS, db_result.detail db_path = Path(cls.db_filepath) if db_path.is_file(): db_path.unlink() def test_manager_status(self): for sql_manager in self.sql_managers: - rt, err_msg = sql_manager.check_connection() - assert rt, err_msg - assert sql_manager.err_code == 0 + db_result = sql_manager.check_connection() + assert db_result.status == DBStatus.SUCCESS, db_result.detail def test_manager_table_create_drop(self): for sql_manager in self.sql_managers: # 1. drop tables for table_name in SqlEgsData.TEST_TABLES: - rt, err_msg = sql_manager._drop_table_by_name(table_name) - assert rt, err_msg - existing_tables = set(sql_manager.get_all_tables()) + db_result = sql_manager.drop(table_name) + assert db_result.status == DBStatus.SUCCESS, db_result.detail + db_result = sql_manager.get_all_tables() + assert db_result.status == DBStatus.SUCCESS, db_result.detail + existing_tables = set(db_result.result) for table_name in SqlEgsData.TEST_TABLES: assert table_name not in existing_tables # 2. create table - rt, err_msg = sql_manager.reset_tables(SqlEgsData.TEST_TABLES_INFO) - assert rt, err_msg + db_result = sql_manager.reset_table_info_dict(SqlEgsData.TEST_TABLES_INFO) + assert db_result.status == DBStatus.SUCCESS, db_result.detail # 3. restore rows for insert_script in SqlEgsData.TEST_INSERT_SCRIPTS: - rt, err_msg = sql_manager.execute_sql_update(insert_script) - assert rt, err_msg + db_result = sql_manager.execute(insert_script) + assert db_result.status == DBStatus.SUCCESS, db_result.detail def test_manager_table_delete_insert_query(self): # 1. Delete, as rows already exists during setUp for sql_manager in self.sql_managers: for table_name in SqlEgsData.TEST_TABLES: - rt, err_msg = sql_manager._delete_rows_by_name(table_name) - assert rt, err_msg - str_results = sql_manager.get_query_result_in_json(SqlEgsData.TEST_QUERY_SCRIPTS) + db_result = sql_manager.delete(table_name) + assert db_result.status == DBStatus.SUCCESS, db_result.detail + str_results = sql_manager.execute_to_json(SqlEgsData.TEST_QUERY_SCRIPTS) self.assertNotIn("销售一部", str_results) # 2. Insert, restore rows for sql_manager in self.sql_managers: for insert_script in SqlEgsData.TEST_INSERT_SCRIPTS: - rt, err_msg = sql_manager.execute_sql_update(insert_script) - assert rt, err_msg - str_results = sql_manager.get_query_result_in_json(SqlEgsData.TEST_QUERY_SCRIPTS) - self.assertIn("销售一部", str_results) + db_result = sql_manager.execute(insert_script) + assert db_result.status == DBStatus.SUCCESS, db_result.detail + str_results = sql_manager.execute_to_json(SqlEgsData.TEST_QUERY_SCRIPTS) + self.assertIn("销售一部", f"Query: {SqlEgsData.TEST_QUERY_SCRIPTS}; result: {str_results}") - def test_get_talbes(self): + def test_get_tables(self): for sql_manager in self.sql_managers: - tables_desc = sql_manager.get_tables_desc() - self.assertIn("employee", tables_desc) - self.assertIn("sales", tables_desc) + db_result = sql_manager.get_all_tables() + assert db_result.status == DBStatus.SUCCESS, db_result.detail + for table_name in SqlEgsData.TEST_TABLES: + self.assertIn(table_name, db_result.result) def test_llm_query_online(self): for sql_call in self.sql_calls: