Skip to content

Commit

Permalink
1. Support local llm model token usage. 2. Add escape for database pa…
Browse files Browse the repository at this point in the history
…ssword
  • Loading branch information
zhangyongchao committed Nov 20, 2024
1 parent c2bb6ff commit d4d162f
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 13 deletions.
32 changes: 31 additions & 1 deletion lazyllm/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,31 @@ def _set_url(self, url):
LOG.debug(f'url: {url}')
self.__url = url

def _estimate_token_usage(self, text):
    """Roughly estimate how many tokens ``text`` would cost.

    The heuristic sums the lengths of every ASCII alphanumeric word (each
    comma counts as one character), charges one token per three of those
    characters, charges one token per non-ASCII character, and adds one.

    Args:
        text: The string to measure.

    Returns:
        The estimated token count as an ``int``.
    """
    # re.ASCII restricts the word-boundary assertion to ASCII word characters.
    # Without it, CJK or accented characters count as word characters, so an
    # English word written directly next to them has no boundary on that side
    # and is silently left out of the estimate.
    ascii_words = re.findall(r"\b[a-zA-Z0-9]+\b|,", text, flags=re.ASCII)
    ascii_ch_count = sum(len(word) for word in ascii_words)
    non_ascii_char_count = len(re.findall(r"[^\x00-\x7F]", text))
    return int(ascii_ch_count / 3.0 + non_ascii_char_count + 1)

def _record_usage(self, usage: dict):
    """Store this module's token usage and fold it into its parent's entry.

    ``usage`` is saved under ``self._module_id`` in ``globals["usage"]``.
    When ``self._used_by_moduleid`` is not None, the same counters are also
    accumulated into that module's entry. If ``prompt_tokens`` is -1 on
    either side, the parent entry is reset to -1 for both counters instead.

    Args:
        usage: Mapping holding ``"prompt_tokens"`` and ``"completion_tokens"``.
    """
    usage_store = globals["usage"]
    usage_store[self._module_id] = usage
    parent_id = self._used_by_moduleid
    if parent_id is None:
        return
    if parent_id not in usage_store:
        # Seed the parent with its own copy. Sharing one dict object between
        # the two keys would let the in-place "+=" below, when run for a later
        # sibling, also rewrite the entry stored for this module.
        usage_store[parent_id] = dict(usage)
        return
    parent_usage = usage_store[parent_id]
    if parent_usage["prompt_tokens"] == -1 or usage["prompt_tokens"] == -1:
        usage_store[parent_id] = {"prompt_tokens": -1, "completion_tokens": -1}
    else:
        for key in parent_usage:
            parent_usage[key] += usage[key]

# Cannot modify or add any attribute of self
# prompt keys (excluding history) are in __input (ATTENTION: dict, not kwargs)
# deploy parameters keys are in **kw
Expand All @@ -337,6 +362,8 @@ def forward(self, __input: Union[Tuple[Union[str, Dict], str], str, Dict] = pack
__input = self._prompt.generate_prompt(query, llm_chat_history, tools)
headers = {'Content-Type': 'application/json'}

usage = {"prompt_tokens": self._estimate_token_usage(__input), "completion_tokens": 0}

if isinstance(self, ServerModule):
assert llm_chat_history is None and tools is None
headers['Global-Parameters'] = encode_request(globals._pickle_data)
Expand Down Expand Up @@ -396,7 +423,10 @@ def forward(self, __input: Union[Tuple[Union[str, Dict], str], str, Dict] = pack
cache = ""
else:
raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)]))
return self._formatter.format(self._extract_and_format(messages))
temp_output = self._extract_and_format(messages)
usage["completion_tokens"] = self._estimate_token_usage(temp_output)
self._record_usage(usage)
return self._formatter.format(temp_output)

def prompt(self, prompt=None):
if prompt is None:
Expand Down
4 changes: 3 additions & 1 deletion lazyllm/tools/sql/sql_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sqlalchemy.exc import SQLAlchemyError, OperationalError
from sqlalchemy.orm import declarative_base
import pydantic
from urllib.parse import quote_plus


class ColumnInfo(pydantic.BaseModel):
Expand Down Expand Up @@ -55,6 +56,7 @@ def __init__(
) -> None:
super().__init__()
if db_type.lower() != "sqlite":
password = quote_plus(password)
conn_url = f"{db_type.lower()}://{user}:{password}@{host}:{port}/{db_name}"
self.reset_db(db_type, conn_url, tables_info_dict, options_str)

Expand Down Expand Up @@ -267,8 +269,8 @@ def _set_tables_desc_prompt(self, tables_info_dict: dict) -> str:


class SQLiteManger(SqlManager):

def __init__(self, db_file, tables_info_dict: dict):
assert Path(db_file).is_file()
super().__init__("SQLite", "", "", "", 0, "", {}, "")
super().reset_db("SQLite", f"sqlite:///{db_file}", tables_info_dict)

Expand Down
11 changes: 1 addition & 10 deletions tests/charge_tests/test_sql_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,7 @@ def clean_obsolete_tables(cls, sql_manager):

@classmethod
def setUpClass(cls):
cls.sql_managers: list[SqlManager] = []

filepath = str(Path(tempfile.gettempdir()) / f"{str(uuid.uuid4().hex)}.db")
cls.db_filepath = filepath
with open(filepath, "w") as _:
pass
cls.sql_managers.append(SQLiteManger(filepath, SqlEgsData.TEST_TABLES_INFO))
cls.sql_managers: list[SqlManager] = [SQLiteManger(":memory:", SqlEgsData.TEST_TABLES_INFO)]
for db_type in ["PostgreSQL"]:
username, password, host, port, database = get_sql_init_keywords(db_type)
cls.sql_managers.append(
Expand All @@ -60,9 +54,6 @@ def tearDownClass(cls):
for table_name in SqlEgsData.TEST_TABLES:
rt, err_msg = sql_manager._drop_table_by_name(table_name)
assert rt, f"sql_manager table {table_name} error: {err_msg}"
db_path = Path(cls.db_filepath)
if db_path.is_file():
db_path.unlink()

def test_manager_status(self):
for sql_manager in self.sql_managers:
Expand Down
2 changes: 1 addition & 1 deletion tests/charge_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_sql_init_keywords(db_type):
env_key = f"LAZYLLM_{db_type.replace(' ', '_')}_URL"
conn_url = os.environ.get(env_key, None)
assert conn_url is not None
pattern = r"postgresql://(?P<username>[^:]+):(?P<password>[^@]+)@(?P<host>[^:]+):(?P<port>\d+)/(?P<database>.+)"
pattern = r"postgresql://(?P<username>[^:]+):(?P<password>.+)@(?P<host>[^:]+):(?P<port>\d+)/(?P<database>.+)"
match = re.search(pattern, conn_url)
assert match
username = match.group("username")
Expand Down

0 comments on commit d4d162f

Please sign in to comment.