fix sequence2txt error and usage total token issue
KevinHuSh committed Oct 22, 2024
1 parent bfc07fe commit b57be25
Showing 5 changed files with 21 additions and 13 deletions.
api/apps/conversation_app.py (2 changes: 1 addition & 1 deletion)
@@ -26,7 +26,6 @@
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
 from api.settings import RetCode, retrievaler
-from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from graphrag.mind_map_extractor import MindMapExtractor
@@ -187,6 +186,7 @@ def stream():
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
ConversationService.update_by_id(conv.id, conv.to_dict())
except Exception as e:
traceback.print_exc()
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
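Why the added traceback.print_exc() matters: once a response is being streamed, errors can only be reported to the client as SSE events, so without an explicit print the server-side stack trace is lost and only str(e) survives. A minimal sketch of the pattern (stream_answer and chat_fn are illustrative names, not RAGFlow's actual code):

import json
import traceback

def stream_answer(chat_fn, question):
    # Yield SSE events; on failure, log the full traceback server-side
    # and still emit a structured error event to the client.
    try:
        for ans in chat_fn(question):
            yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans},
                                       ensure_ascii=False) + "\n\n"
    except Exception as e:
        traceback.print_exc()  # keep the stack trace in the server log
        yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
                                    "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
                                   ensure_ascii=False) + "\n\n"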
api/db/services/llm_service.py (3 changes: 2 additions & 1 deletion)
@@ -133,7 +133,8 @@ def model_instance(cls, tenant_id, llm_type,
            if model_config["llm_factory"] not in Seq2txtModel:
                return
            return Seq2txtModel[model_config["llm_factory"]](
-                model_config["api_key"], model_config["llm_name"], lang,
+                key=model_config["api_key"], model_name=model_config["llm_name"],
+                lang=lang,
                base_url=model_config["api_base"]
            )
        if llm_type == LLMType.TTS:
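The keyword-argument switch is the sequence2txt fix itself: XinferenceSeq2txt (see the last file below) accepts only key and model_name positionally, so the old positional call passed lang as an unexpected third positional argument and raised a TypeError. A standalone reproduction with a hypothetical stand-in class:

class Seq2TxtLike:
    # Mirrors XinferenceSeq2txt's shape: only key and model_name are positional.
    def __init__(self, key, model_name="whisper-small", **kwargs):
        self.base_url = kwargs.get('base_url', None)

try:
    Seq2TxtLike("sk-xxx", "whisper-small", "Chinese", base_url="http://host")
except TypeError as e:
    print(e)  # takes from 2 to 3 positional arguments but 4 were given

# Keyword arguments bind correctly whatever each class's signature looks like;
# lang simply lands in **kwargs for classes that do not declare it.
Seq2TxtLike(key="sk-xxx", model_name="whisper-small",
            lang="Chinese", base_url="http://host")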
api/utils/file_utils.py (9 changes: 7 additions & 2 deletions)
@@ -197,10 +197,15 @@ def thumbnail_img(filename, blob):
            pass
    return None
 
+
 def thumbnail(filename, blob):
     img = thumbnail_img(filename, blob)
-    return IMG_BASE64_PREFIX + \
-        base64.b64encode(img).decode("utf-8")
+    if img is not None:
+        return IMG_BASE64_PREFIX + \
+            base64.b64encode(img).decode("utf-8")
+    else:
+        return ''
 
+
 def traversal_files(base):
     for root, ds, fs in os.walk(base):
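The guard is needed because thumbnail_img returns None for file types it cannot render, and base64.b64encode(None) raises a TypeError rather than producing anything usable. A quick standalone check of the failure the fix prevents (the filename below is made up):

import base64

img = None  # what thumbnail_img yields for an unsupported extension
try:
    base64.b64encode(img)
except TypeError as e:
    print(type(e).__name__)  # TypeError

# With the fix, thumbnail("notes.xyz", b"...") returns '' instead of raising.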
rag/llm/chat_model.py (18 changes: 10 additions & 8 deletions)
@@ -67,14 +67,16 @@ def chat_streamly(self, system, history, gen_conf):
                 if not resp.choices[0].delta.content:
                     resp.choices[0].delta.content = ""
                 ans += resp.choices[0].delta.content
-                total_tokens = (
-                    (
-                        total_tokens
-                        + num_tokens_from_string(resp.choices[0].delta.content)
-                    )
-                    if not hasattr(resp, "usage") or not resp.usage
-                    else resp.usage.get("total_tokens", total_tokens)
-                )
+                total_tokens += 1
+                if not hasattr(resp, "usage") or not resp.usage:
+                    total_tokens = (
+                        total_tokens
+                        + num_tokens_from_string(resp.choices[0].delta.content)
+                    )
+                elif isinstance(resp.usage, dict):
+                    total_tokens = resp.usage.get("total_tokens", total_tokens)
+                else: total_tokens = resp.usage.total_tokens
+
                 if resp.choices[0].finish_reason == "length":
                     ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                         [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
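This rewrite is the "usage total token" fix: OpenAI-compatible streaming backends report usage in different shapes, and the old ternary assumed resp.usage was always a dict, so backends returning an object-style usage crashed on .get(). The new branches cover no usage on the chunk (estimate locally), usage as a dict, and usage as an object. The core dispatch as a standalone sketch (omitting the per-chunk increment; update_total_tokens and count_fn are illustrative names):

def update_total_tokens(resp, total_tokens, count_fn):
    # Prefer the backend-reported total; otherwise accumulate a local estimate.
    usage = getattr(resp, "usage", None)
    if not usage:
        # No usage on this chunk: estimate from the delta text.
        return total_tokens + count_fn(resp.choices[0].delta.content)
    if isinstance(usage, dict):
        return usage.get("total_tokens", total_tokens)
    return usage.total_tokens  # object-style usage (e.g. OpenAI SDK models)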
rag/llm/sequence2txt_model.py (2 changes: 1 addition & 1 deletion)
@@ -87,7 +87,7 @@ def __init__(self, key, model_name, lang="Chinese", **kwargs):


 class XinferenceSeq2txt(Base):
-    def __init__(self,key,model_name="whisper-small",**kwargs):
+    def __init__(self, key, model_name="whisper-small", **kwargs):
         self.base_url = kwargs.get('base_url', None)
         self.model_name = model_name
         self.key = key
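With the keyword call in llm_service.py above, base_url now flows into **kwargs here and is picked up by kwargs.get('base_url', None). A hypothetical instantiation against a local Xinference endpoint (the key and URL values are placeholders):

from rag.llm.sequence2txt_model import XinferenceSeq2txt

s2t = XinferenceSeq2txt(key="placeholder",
                        model_name="whisper-small",
                        base_url="http://localhost:9997/v1")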
