Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support local llm model token usage and add escape for database password #353

Merged
merged 4 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion lazyllm/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,33 @@ def _set_url(self, url):
LOG.debug(f'url: {url}')
self.__url = url

def _estimate_token_usage(self, text):
if not isinstance(text, str):
return 0
# extract english words, number and comma
pattern = r"\b[a-zA-Z0-9]+\b|,"
ascii_words = re.findall(pattern, text)
ascii_ch_count = sum(len(ele) for ele in ascii_words)
non_ascii_pattern = r"[^\x00-\x7F]"
non_ascii_chars = re.findall(non_ascii_pattern, text)
non_ascii_char_count = len(non_ascii_chars)
return int(ascii_ch_count / 3.0 + non_ascii_char_count + 1)

def _record_usage(self, usage: dict):
    """Record this module's token usage and fold it into its parent's total.

    ``usage`` is stored in ``globals["usage"]`` under ``self._module_id``.
    If ``self._used_by_moduleid`` is not None, the same numbers are also
    accumulated under that parent id:

    * the first report for a parent seeds the parent's entry;
    * if either the parent's or the new ``prompt_tokens`` is -1, the parent's
      entry is reset to -1 for both counters (presumably -1 marks usage
      as unknown -- confirm against the producers of ``usage``);
    * otherwise every key already in the parent's entry is incremented by
      the matching value from ``usage``.

    Args:
        usage: mapping that contains at least ``"prompt_tokens"`` and the
            keys already present in the parent's entry.
    """
    globals["usage"][self._module_id] = usage
    parent_id = self._used_by_moduleid
    if parent_id is None:
        return
    if parent_id not in globals["usage"]:
        # Store a copy: the parent's entry is later updated in place with
        # ``+=``, and sharing the child's dict would silently corrupt the
        # child's own record (and the caller's ``usage`` object).
        globals["usage"][parent_id] = dict(usage)
        return
    parent_usage = globals["usage"][parent_id]
    if parent_usage["prompt_tokens"] == -1 or usage["prompt_tokens"] == -1:
        globals["usage"][parent_id] = {"prompt_tokens": -1, "completion_tokens": -1}
    else:
        for key in globals["usage"][parent_id]:
            globals["usage"][parent_id][key] += usage[key]

# Cannot modify or add any attribute of self
# prompt keys (excluding history) are in __input (ATTENTION: dict, not kwargs)
# deploy parameters keys are in **kw
Expand All @@ -336,6 +363,7 @@ def forward(self, __input: Union[Tuple[Union[str, Dict], str], str, Dict] = pack
query = __input
__input = self._prompt.generate_prompt(query, llm_chat_history, tools)
headers = {'Content-Type': 'application/json'}
text_input_for_token_usage = __input

if isinstance(self, ServerModule):
assert llm_chat_history is None and tools is None
Expand Down Expand Up @@ -396,7 +424,12 @@ def forward(self, __input: Union[Tuple[Union[str, Dict], str], str, Dict] = pack
cache = ""
else:
raise requests.RequestException('\n'.join([c.decode('utf-8') for c in r.iter_content(None)]))
return self._formatter.format(self._extract_and_format(messages))
temp_output = self._extract_and_format(messages)
if isinstance(self, TrainableModule):
usage = {"prompt_tokens": self._estimate_token_usage(text_input_for_token_usage)}
usage["completion_tokens"] = self._estimate_token_usage(temp_output)
self._record_usage(usage)
return self._formatter.format(temp_output)

def prompt(self, prompt=None):
if prompt is None:
Expand Down
5 changes: 3 additions & 2 deletions lazyllm/tools/sql/sql_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from lazyllm import pipeline, globals, bind, _0, switch
import json
from typing import List, Any, Dict, Union
from pathlib import Path
import datetime
import re
import sqlalchemy
from sqlalchemy.exc import SQLAlchemyError, OperationalError
from sqlalchemy.orm import declarative_base
import pydantic
from urllib.parse import quote_plus


class ColumnInfo(pydantic.BaseModel):
Expand Down Expand Up @@ -55,6 +55,7 @@ def __init__(
) -> None:
super().__init__()
if db_type.lower() != "sqlite":
password = quote_plus(password)
conn_url = f"{db_type.lower()}://{user}:{password}@{host}:{port}/{db_name}"
self.reset_db(db_type, conn_url, tables_info_dict, options_str)

Expand Down Expand Up @@ -267,8 +268,8 @@ def _set_tables_desc_prompt(self, tables_info_dict: dict) -> str:


class SQLiteManger(SqlManager):

def __init__(self, db_file, tables_info_dict: dict):
assert Path(db_file).is_file()
super().__init__("SQLite", "", "", "", 0, "", {}, "")
super().reset_db("SQLite", f"sqlite:///{db_file}", tables_info_dict)

Expand Down
14 changes: 1 addition & 13 deletions tests/charge_tests/test_sql_tool.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import unittest
from lazyllm.tools import SQLiteManger, SqlCall, SqlManager
import lazyllm
import tempfile
from pathlib import Path
import uuid
from .utils import SqlEgsData, get_sql_init_keywords
import datetime
import re
Expand All @@ -27,13 +24,7 @@ def clean_obsolete_tables(cls, sql_manager):

@classmethod
def setUpClass(cls):
cls.sql_managers: list[SqlManager] = []

filepath = str(Path(tempfile.gettempdir()) / f"{str(uuid.uuid4().hex)}.db")
cls.db_filepath = filepath
with open(filepath, "w") as _:
pass
cls.sql_managers.append(SQLiteManger(filepath, SqlEgsData.TEST_TABLES_INFO))
cls.sql_managers: list[SqlManager] = [SQLiteManger(":memory:", SqlEgsData.TEST_TABLES_INFO)]
for db_type in ["PostgreSQL"]:
username, password, host, port, database = get_sql_init_keywords(db_type)
cls.sql_managers.append(
Expand All @@ -60,9 +51,6 @@ def tearDownClass(cls):
for table_name in SqlEgsData.TEST_TABLES:
rt, err_msg = sql_manager._drop_table_by_name(table_name)
assert rt, f"sql_manager table {table_name} error: {err_msg}"
db_path = Path(cls.db_filepath)
if db_path.is_file():
db_path.unlink()

def test_manager_status(self):
for sql_manager in self.sql_managers:
Expand Down
2 changes: 1 addition & 1 deletion tests/charge_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_sql_init_keywords(db_type):
env_key = f"LAZYLLM_{db_type.replace(' ', '_')}_URL"
conn_url = os.environ.get(env_key, None)
assert conn_url is not None
pattern = r"postgresql://(?P<username>[^:]+):(?P<password>[^@]+)@(?P<host>[^:]+):(?P<port>\d+)/(?P<database>.+)"
pattern = r"postgresql://(?P<username>[^:]+):(?P<password>.+)@(?P<host>[^:]+):(?P<port>\d+)/(?P<database>.+)"
match = re.search(pattern, conn_url)
assert match
username = match.group("username")
Expand Down