diff --git a/lazyllm/module/module.py b/lazyllm/module/module.py index 64154268..7742dfb3 100644 --- a/lazyllm/module/module.py +++ b/lazyllm/module/module.py @@ -313,6 +313,33 @@ def _set_url(self, url): LOG.debug(f'url: {url}') self.__url = url + def _estimate_token_usage(self, text): + if not isinstance(text, str): + return 0 + # extract english words, number and comma + pattern = r"\b[a-zA-Z0-9]+\b|," + ascii_words = re.findall(pattern, text) + ascii_ch_count = sum(len(ele) for ele in ascii_words) + non_ascii_pattern = r"[^\x00-\x7F]" + non_ascii_chars = re.findall(non_ascii_pattern, text) + non_ascii_char_count = len(non_ascii_chars) + return int(ascii_ch_count / 3.0 + non_ascii_char_count + 1) + + def _record_usage(self, usage: dict): + globals["usage"][self._module_id] = usage + par_muduleid = self._used_by_moduleid + if par_muduleid is None: + return + if par_muduleid not in globals["usage"]: + globals["usage"][par_muduleid] = usage + return + existing_usage = globals["usage"][par_muduleid] + if existing_usage["prompt_tokens"] == -1 or usage["prompt_tokens"] == -1: + globals["usage"][par_muduleid] = {"prompt_tokens": -1, "completion_tokens": -1} + else: + for k in globals["usage"][par_muduleid]: + globals["usage"][par_muduleid][k] += usage[k] + # Cannot modify or add any attrubute of self # prompt keys (excluding history) are in __input (ATTENTION: dict, not kwargs) # deploy parameters keys are in **kw @@ -336,6 +363,7 @@ def forward(self, __input: Union[Tuple[Union[str, Dict], str], str, Dict] = pack query = __input __input = self._prompt.generate_prompt(query, llm_chat_history, tools) headers = {'Content-Type': 'application/json'} + text_input_for_token_usage = __input if isinstance(self, ServerModule): assert llm_chat_history is None and tools is None @@ -396,7 +424,12 @@ def forward(self, __input: Union[Tuple[Union[str, Dict], str], str, Dict] = pack cache = "" else: raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)])) - return self._formatter.format(self._extract_and_format(messages)) + temp_output = self._extract_and_format(messages) + if isinstance(self, TrainableModule): + usage = {"prompt_tokens": self._estimate_token_usage(text_input_for_token_usage)} + usage["completion_tokens"] = self._estimate_token_usage(temp_output) + self._record_usage(usage) + return self._formatter.format(temp_output) def prompt(self, prompt=None): if prompt is None: diff --git a/lazyllm/tools/sql/sql_tool.py b/lazyllm/tools/sql/sql_tool.py index ef2157fd..7a31280d 100644 --- a/lazyllm/tools/sql/sql_tool.py +++ b/lazyllm/tools/sql/sql_tool.py @@ -5,13 +5,13 @@ from lazyllm import pipeline, globals, bind, _0, switch import json from typing import List, Any, Dict, Union -from pathlib import Path import datetime import re import sqlalchemy from sqlalchemy.exc import SQLAlchemyError, OperationalError from sqlalchemy.orm import declarative_base import pydantic +from urllib.parse import quote_plus class ColumnInfo(pydantic.BaseModel): @@ -55,6 +55,7 @@ def __init__( ) -> None: super().__init__() if db_type.lower() != "sqlite": + password = quote_plus(password) conn_url = f"{db_type.lower()}://{user}:{password}@{host}:{port}/{db_name}" self.reset_db(db_type, conn_url, tables_info_dict, options_str) @@ -267,8 +268,8 @@ def _set_tables_desc_prompt(self, tables_info_dict: dict) -> str: class SQLiteManger(SqlManager): + def __init__(self, db_file, tables_info_dict: dict): - assert Path(db_file).is_file() super().__init__("SQLite", "", "", "", 0, "", {}, "") super().reset_db("SQLite", f"sqlite:///{db_file}", tables_info_dict) diff --git a/tests/charge_tests/test_sql_tool.py b/tests/charge_tests/test_sql_tool.py index 799cefd0..423327a7 100644 --- a/tests/charge_tests/test_sql_tool.py +++ b/tests/charge_tests/test_sql_tool.py @@ -1,9 +1,6 @@ import unittest from lazyllm.tools import SQLiteManger, SqlCall, SqlManager import lazyllm -import tempfile -from pathlib import Path -import uuid from .utils import SqlEgsData, get_sql_init_keywords import datetime import re @@ -27,13 +24,7 @@ def clean_obsolete_tables(cls, sql_manager): @classmethod def setUpClass(cls): - cls.sql_managers: list[SqlManager] = [] - - filepath = str(Path(tempfile.gettempdir()) / f"{str(uuid.uuid4().hex)}.db") - cls.db_filepath = filepath - with open(filepath, "w") as _: - pass - cls.sql_managers.append(SQLiteManger(filepath, SqlEgsData.TEST_TABLES_INFO)) + cls.sql_managers: list[SqlManager] = [SQLiteManger(":memory:", SqlEgsData.TEST_TABLES_INFO)] for db_type in ["PostgreSQL"]: username, password, host, port, database = get_sql_init_keywords(db_type) cls.sql_managers.append( @@ -60,9 +51,6 @@ def tearDownClass(cls): for table_name in SqlEgsData.TEST_TABLES: rt, err_msg = sql_manager._drop_table_by_name(table_name) assert rt, f"sql_manager table {table_name} error: {err_msg}" - db_path = Path(cls.db_filepath) - if db_path.is_file(): - db_path.unlink() def test_manager_status(self): for sql_manager in self.sql_managers: diff --git a/tests/charge_tests/utils.py b/tests/charge_tests/utils.py index 7823b6fd..ee3ecb98 100644 --- a/tests/charge_tests/utils.py +++ b/tests/charge_tests/utils.py @@ -61,7 +61,7 @@ def get_sql_init_keywords(db_type): env_key = f"LAZYLLM_{db_type.replace(' ', '_')}_URL" conn_url = os.environ.get(env_key, None) assert conn_url is not None - pattern = r"postgresql://(?P[^:]+):(?P[^@]+)@(?P[^:]+):(?P\d+)/(?P.+)" + pattern = r"postgresql://(?P[^:]+):(?P.+)@(?P[^:]+):(?P\d+)/(?P.+)" match = re.search(pattern, conn_url) assert match username = match.group("username")