feat: support calc tokens precisely #614

Merged 2 commits on Mar 26, 2023
README.md: 5 changes (4 additions & 1 deletion)

```diff
@@ -81,7 +81,10 @@ pip3 install --upgrade openai
 **(3) Extended dependencies (optional):**

 Dependencies for speech recognition and voice replies: [#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)

+For more precise counting of conversation tokens:
+```bash
+pip3 install --upgrade tiktoken
+```

 ## Configuration
```
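Note on the new dependency: tiktoken tokenizes text with the same byte-pair encodings the OpenAI models use, so counts are exact rather than estimated from character length. A minimal sketch (model name and sample text are illustrative):

```python
import tiktoken

# Resolve the tokenizer for a chat model; fall back to cl100k_base
# when the installed tiktoken doesn't know the model name.
try:
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
except KeyError:
    encoding = tiktoken.get_encoding("cl100k_base")

tokens = encoding.encode("How many tokens is this sentence?")
print(len(tokens))  # exact token count for this string
```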
bot/chatgpt/chat_gpt_bot.py: 74 changes (61 additions & 13 deletions)

```diff
@@ -18,7 +18,7 @@ def __init__(self):
         if conf().get('open_ai_api_base'):
             openai.api_base = conf().get('open_ai_api_base')
         proxy = conf().get('proxy')
-        self.sessions = SessionManager()
+        self.sessions = SessionManager(model=conf().get("model") or "gpt-3.5-turbo")
         if proxy:
             openai.proxy = proxy
         if conf().get('rate_limit_chatgpt'):
```
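Note: the `or` fallback matters because `conf().get("model")` returns None when the key is absent (and the config may hold an empty string); both are falsy, so the bot defaults to gpt-3.5-turbo. A quick illustration with made-up config values:

```python
# Hypothetical configured values, showing how the fallback resolves.
for configured in (None, "", "gpt-4"):
    model = configured or "gpt-3.5-turbo"
    print(repr(configured), "->", model)
# None -> gpt-3.5-turbo, '' -> gpt-3.5-turbo, 'gpt-4' -> gpt-4
```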
```diff
@@ -53,7 +53,7 @@ def reply(self, query, context=None):
         # return self.reply_text_stream(query, new_query, session_id)

         reply_content = self.reply_text(session, session_id, 0)
-        logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"]))
+        logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session, session_id, reply_content["content"], reply_content["completion_tokens"]))
         if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
             reply = Reply(ReplyType.ERROR, reply_content['content'])
         elif reply_content["completion_tokens"] > 0:
```
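Note: this branch assumes reply_text returns a dict carrying the model's usage numbers alongside the text; zero completion_tokens with non-empty content is treated as an upstream error message rather than a model reply. Illustrative shapes (field values are made up):

```python
# What the branch above expects reply_content to look like (illustrative).
ok = {"completion_tokens": 42, "total_tokens": 180, "content": "model reply"}
err = {"completion_tokens": 0, "total_tokens": 0, "content": "provider error text"}
```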
```diff
@@ -166,14 +166,14 @@ def compose_args(self):
             del(args["model"])
         return args


 class SessionManager(object):
-    def __init__(self):
+    def __init__(self, model="gpt-3.5-turbo-0301"):
         if conf().get('expires_in_seconds'):
             sessions = ExpiredDict(conf().get('expires_in_seconds'))
         else:
             sessions = dict()
         self.sessions = sessions
+        self.model = model

     def build_session(self, session_id, system_prompt=None):
         session = self.sessions.get(session_id, [])
```
```diff
@@ -201,38 +201,86 @@ def build_session_query(self, query, session_id):
         session = self.build_session(session_id)
         user_item = {'role': 'user', 'content': query}
         session.append(user_item)
+        try:
+            total_tokens = num_tokens_from_messages(session, self.model)
+            max_tokens = conf().get("conversation_max_tokens", 1000)
+            total_tokens = self.discard_exceed_conversation(session, max_tokens, total_tokens)
+            logger.debug("prompt tokens used={}".format(total_tokens))
+        except Exception as e:
+            logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))

         return session

     def save_session(self, answer, session_id, total_tokens):
-        max_tokens = conf().get("conversation_max_tokens")
-        if not max_tokens:
-            # default 1000
-            max_tokens = 1000
-        max_tokens = int(max_tokens)
-
+        max_tokens = conf().get("conversation_max_tokens", 1000)
         session = self.sessions.get(session_id)
         if session:
             # append conversation
             gpt_item = {'role': 'assistant', 'content': answer}
             session.append(gpt_item)

             # discard conversation that exceeds the limit
-            self.discard_exceed_conversation(session, max_tokens, total_tokens)
+            tokens_cnt = self.discard_exceed_conversation(session, max_tokens, total_tokens)
+            logger.debug("raw total_tokens={}, save_session tokens={}".format(total_tokens, tokens_cnt))

     def discard_exceed_conversation(self, session, max_tokens, total_tokens):
         dec_tokens = int(total_tokens)
         # logger.info("prompt tokens used={}, max_tokens={}".format(used_tokens, max_tokens))
         while dec_tokens > max_tokens:
             # pop the oldest conversation turn first
-            if len(session) > 3:
+            if len(session) > 2:
                 session.pop(1)
+            elif len(session) == 2 and session[1]["role"] == "assistant":
+                session.pop(1)
+                break
+            elif len(session) == 2 and session[1]["role"] == "user":
+                logger.warn("user message exceed max_tokens. total_tokens={}".format(dec_tokens))
+                break
             else:
                 logger.debug("max_tokens={}, total_tokens={}, len(sessions)={}".format(max_tokens, dec_tokens, len(session)))
                 break
-            dec_tokens = dec_tokens - max_tokens
+            try:
+                cur_tokens = num_tokens_from_messages(session, self.model)
+                dec_tokens = cur_tokens
+            except Exception as e:
+                logger.debug("Exception when counting tokens precisely for query: {}".format(e))
+                dec_tokens = dec_tokens - max_tokens
+        return dec_tokens

     def clear_session(self, session_id):
         self.sessions[session_id] = []

     def clear_all_session(self):
         self.sessions.clear()
```
```diff
+# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+def num_tokens_from_messages(messages, model):
+    """Returns the number of tokens used by a list of messages."""
+    import tiktoken
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        logger.debug("Warning: model not found. Using cl100k_base encoding.")
+        encoding = tiktoken.get_encoding("cl100k_base")
+    if model == "gpt-3.5-turbo":
+        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
+    elif model == "gpt-4":
+        return num_tokens_from_messages(messages, model="gpt-4-0314")
+    elif model == "gpt-3.5-turbo-0301":
+        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
+        tokens_per_name = -1  # if there's a name, the role is omitted
+    elif model == "gpt-4-0314":
+        tokens_per_message = 3
+        tokens_per_name = 1
+    else:
+        logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.")
+        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
+    num_tokens = 0
+    for message in messages:
+        num_tokens += tokens_per_message
+        for key, value in message.items():
+            num_tokens += len(encoding.encode(value))
+            if key == "name":
+                num_tokens += tokens_per_name
+    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+    return num_tokens
```
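A hedged usage sketch of the counter above: for gpt-3.5-turbo-0301 every message costs 4 framing tokens plus the encoded length of each field value (role and content alike), and 3 more tokens prime the assistant's reply; per the cookbook, the result should match the prompt_tokens the API reports. Exact numbers depend on the installed tiktoken version:

```python
messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'Hello!'},
]
# 2 messages * 4 framing tokens + encoded role/content values
# + 3 reply-priming tokens; should line up with the API's prompt_tokens.
print(num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301"))
```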