Skip to content

Commit

Permalink
Support local llm model token usage and add escape for database passw…
Browse files Browse the repository at this point in the history
…ord (#353)

Co-authored-by: zhangyongchao <[email protected]>
  • Loading branch information
SuperEver and zhangyongchao authored Nov 21, 2024
1 parent ce54d52 commit e311e8d
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 17 deletions.
35 changes: 34 additions & 1 deletion lazyllm/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,33 @@ def _set_url(self, url):
LOG.debug(f'url: {url}')
self.__url = url

def _estimate_token_usage(self, text):
    """Return a rough token-count estimate for ``text``.

    The estimate is a heuristic, not a real tokenizer:

    * standalone ASCII alphanumeric words and commas are counted by
      character and weighted at one token per three characters;
    * every non-ASCII character counts as one token;
    * a constant 1 is added, so an empty string yields 1.

    Args:
        text: The text to measure. Anything that is not a ``str`` is
            reported as 0 tokens.

    Returns:
        The estimated token count as an ``int`` (truncated).
    """
    if not isinstance(text, str):
        return 0

    # Total length of all ASCII words / digits runs / commas.
    ascii_length = 0
    for match in re.finditer(r"\b[a-zA-Z0-9]+\b|,", text):
        ascii_length += len(match.group(0))

    # Each character outside the 7-bit ASCII range counts once.
    wide_char_total = sum(1 for _ in re.finditer(r"[^\x00-\x7F]", text))

    return int(ascii_length / 3.0 + wide_char_total + 1)

def _record_usage(self, usage: dict):
    """Store this module's token usage and fold it into its parent's entry.

    The usage is written to ``globals["usage"]`` under ``self._module_id``.
    If ``self._used_by_moduleid`` is not ``None``, the same numbers are also
    accumulated under that id:

    * first report for the parent: the parent entry starts as a *copy* of
      ``usage``;
    * if either the existing parent entry or ``usage`` has
      ``prompt_tokens == -1`` (presumably "usage unknown" -- confirm with
      the producers of this value), the parent entry is reset to -1 for
      both counters;
    * otherwise every counter in the parent entry is increased by the
      matching value from ``usage``.

    Args:
        usage: Mapping with at least ``"prompt_tokens"`` and
            ``"completion_tokens"``.

    Raises:
        KeyError: If ``usage`` lacks a key present in the parent entry.
    """
    usage_table = globals["usage"]
    usage_table[self._module_id] = usage

    parent_id = self._used_by_moduleid
    if parent_id is None:
        return

    if parent_id not in usage_table:
        # Store a copy: the parent entry is mutated in place on later calls,
        # and sharing the dict would silently rewrite this module's own
        # record (and the caller's dict) with the aggregated totals.
        usage_table[parent_id] = dict(usage)
        return

    parent_usage = usage_table[parent_id]
    if parent_usage["prompt_tokens"] == -1 or usage["prompt_tokens"] == -1:
        usage_table[parent_id] = {"prompt_tokens": -1, "completion_tokens": -1}
        return

    for key in parent_usage:
        parent_usage[key] += usage[key]

# Cannot modify or add any attribute of self
# prompt keys (excluding history) are in __input (ATTENTION: dict, not kwargs)
# deploy parameters keys are in **kw
Expand All @@ -336,6 +363,7 @@ def forward(self, __input: Union[Tuple[Union[str, Dict], str], str, Dict] = pack
query = __input
__input = self._prompt.generate_prompt(query, llm_chat_history, tools)
headers = {'Content-Type': 'application/json'}
text_input_for_token_usage = __input

if isinstance(self, ServerModule):
assert llm_chat_history is None and tools is None
Expand Down Expand Up @@ -396,7 +424,12 @@ def forward(self, __input: Union[Tuple[Union[str, Dict], str], str, Dict] = pack
cache = ""
else:
raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)]))
return self._formatter.format(self._extract_and_format(messages))
temp_output = self._extract_and_format(messages)
if isinstance(self, TrainableModule):
usage = {"prompt_tokens": self._estimate_token_usage(text_input_for_token_usage)}
usage["completion_tokens"] = self._estimate_token_usage(temp_output)
self._record_usage(usage)
return self._formatter.format(temp_output)

def prompt(self, prompt=None):
if prompt is None:
Expand Down
5 changes: 3 additions & 2 deletions lazyllm/tools/sql/sql_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from lazyllm import pipeline, globals, bind, _0, switch
import json
from typing import List, Any, Dict, Union
from pathlib import Path
import datetime
import re
import sqlalchemy
from sqlalchemy.exc import SQLAlchemyError, OperationalError
from sqlalchemy.orm import declarative_base
import pydantic
from urllib.parse import quote_plus


class ColumnInfo(pydantic.BaseModel):
Expand Down Expand Up @@ -55,6 +55,7 @@ def __init__(
) -> None:
super().__init__()
if db_type.lower() != "sqlite":
password = quote_plus(password)
conn_url = f"{db_type.lower()}://{user}:{password}@{host}:{port}/{db_name}"
self.reset_db(db_type, conn_url, tables_info_dict, options_str)

Expand Down Expand Up @@ -267,8 +268,8 @@ def _set_tables_desc_prompt(self, tables_info_dict: dict) -> str:


class SQLiteManger(SqlManager):

def __init__(self, db_file, tables_info_dict: dict):
assert Path(db_file).is_file()
super().__init__("SQLite", "", "", "", 0, "", {}, "")
super().reset_db("SQLite", f"sqlite:///{db_file}", tables_info_dict)

Expand Down
14 changes: 1 addition & 13 deletions tests/charge_tests/test_sql_tool.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import unittest
from lazyllm.tools import SQLiteManger, SqlCall, SqlManager
import lazyllm
import tempfile
from pathlib import Path
import uuid
from .utils import SqlEgsData, get_sql_init_keywords
import datetime
import re
Expand All @@ -27,13 +24,7 @@ def clean_obsolete_tables(cls, sql_manager):

@classmethod
def setUpClass(cls):
cls.sql_managers: list[SqlManager] = []

filepath = str(Path(tempfile.gettempdir()) / f"{str(uuid.uuid4().hex)}.db")
cls.db_filepath = filepath
with open(filepath, "w") as _:
pass
cls.sql_managers.append(SQLiteManger(filepath, SqlEgsData.TEST_TABLES_INFO))
cls.sql_managers: list[SqlManager] = [SQLiteManger(":memory:", SqlEgsData.TEST_TABLES_INFO)]
for db_type in ["PostgreSQL"]:
username, password, host, port, database = get_sql_init_keywords(db_type)
cls.sql_managers.append(
Expand All @@ -60,9 +51,6 @@ def tearDownClass(cls):
for table_name in SqlEgsData.TEST_TABLES:
rt, err_msg = sql_manager._drop_table_by_name(table_name)
assert rt, f"sql_manager table {table_name} error: {err_msg}"
db_path = Path(cls.db_filepath)
if db_path.is_file():
db_path.unlink()

def test_manager_status(self):
for sql_manager in self.sql_managers:
Expand Down
2 changes: 1 addition & 1 deletion tests/charge_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_sql_init_keywords(db_type):
env_key = f"LAZYLLM_{db_type.replace(' ', '_')}_URL"
conn_url = os.environ.get(env_key, None)
assert conn_url is not None
pattern = r"postgresql://(?P<username>[^:]+):(?P<password>[^@]+)@(?P<host>[^:]+):(?P<port>\d+)/(?P<database>.+)"
pattern = r"postgresql://(?P<username>[^:]+):(?P<password>.+)@(?P<host>[^:]+):(?P<port>\d+)/(?P<database>.+)"
match = re.search(pattern, conn_url)
assert match
username = match.group("username")
Expand Down

0 comments on commit e311e8d

Please sign in to comment.