Skip to content

Commit

Permalink
Merge pull request zhayujie#565 from zhayujie/dev
Browse files Browse the repository at this point in the history
feat: support plugins
  • Loading branch information
zhayujie authored Mar 24, 2023
2 parents 6aaa463 + 71353be commit 1f73f00
Show file tree
Hide file tree
Showing 48 changed files with 2,298 additions and 257 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ config.json
QR.png
nohup.out
tmp
plugins.json
7 changes: 5 additions & 2 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
from channel import channel_factory
from common.log import logger


from plugins import *
if __name__ == '__main__':
try:
# load config
config.load_config()

# create channel
channel = channel_factory.create_channel("wx")
channel_name='wx'
channel = channel_factory.create_channel(channel_name)
if channel_name=='wx':
PluginManager().load_plugins()

# startup channel
channel.startup()
Expand Down
4 changes: 3 additions & 1 deletion bot/baidu/baidu_unit_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import requests
from bot.bot import Bot
from bridge.reply import Reply, ReplyType


# Baidu Unit对话接口 (可用, 但能力较弱)
Expand All @@ -14,7 +15,8 @@ def reply(self, query, context=None):
headers = {'content-type': 'application/x-www-form-urlencoded'}
response = requests.post(url, data=post_data.encode(), headers=headers)
if response:
return response.json()['result']['context']['SYS_PRESUMED_HIST'][1]
reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1])
return reply

def get_token(self):
access_key = 'YOUR_ACCESS_KEY'
Expand Down
6 changes: 5 additions & 1 deletion bot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
"""


from bridge.context import Context
from bridge.reply import Reply


class Bot(object):
def reply(self, query, context=None):
def reply(self, query, context : Context =None) -> Reply:
"""
bot auto-reply content
:param req: received message
Expand Down
130 changes: 77 additions & 53 deletions bot/chatgpt/chat_gpt_bot.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
# encoding:utf-8

from bot.bot import Bot
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from config import conf, load_config
from common.log import logger
from common.token_bucket import TokenBucket
from common.expired_dict import ExpiredDict
import openai
import time

if conf().get('expires_in_seconds'):
all_sessions = ExpiredDict(conf().get('expires_in_seconds'))
else:
all_sessions = dict()

# OpenAI对话模型API (可用)
class ChatGPTBot(Bot):
Expand All @@ -20,6 +18,7 @@ def __init__(self):
if conf().get('open_ai_api_base'):
openai.api_base = conf().get('open_ai_api_base')
proxy = conf().get('proxy')
self.sessions = SessionManager()
if proxy:
openai.proxy = proxy
if conf().get('rate_limit_chatgpt'):
Expand All @@ -29,21 +28,24 @@ def __init__(self):

def reply(self, query, context=None):
# acquire reply content
if not context or not context.get('type') or context.get('type') == 'TEXT':
if context.type == ContextType.TEXT:
logger.info("[OPEN_AI] query={}".format(query))
session_id = context.get('session_id') or context.get('from_user_id')

session_id = context['session_id']
reply = None
clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆'])
if query in clear_memory_commands:
Session.clear_session(session_id)
return '记忆已清除'
self.sessions.clear_session(session_id)
reply = Reply(ReplyType.INFO, '记忆已清除')
elif query == '#清除所有':
Session.clear_all_session()
return '所有人记忆已清除'
self.sessions.clear_all_session()
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
elif query == '#更新配置':
load_config()
return '配置已更新'

session = Session.build_session_query(query, session_id)
reply = Reply(ReplyType.INFO, '配置已更新')
if reply:
return reply
session = self.sessions.build_session_query(query, session_id)
logger.debug("[OPEN_AI] session query={}".format(session))

# if context.get('stream'):
Expand All @@ -52,14 +54,29 @@ def reply(self, query, context=None):

reply_content = self.reply_text(session, session_id, 0)
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"]))
if reply_content["completion_tokens"] > 0:
Session.save_session(reply_content["content"], session_id, reply_content["total_tokens"])
return reply_content["content"]

elif context.get('type', None) == 'IMAGE_CREATE':
return self.create_img(query, 0)
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
reply = Reply(ReplyType.ERROR, reply_content['content'])
elif reply_content["completion_tokens"] > 0:
self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"])
reply = Reply(ReplyType.TEXT, reply_content["content"])
else:
reply = Reply(ReplyType.ERROR, reply_content['content'])
logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content))
return reply

elif context.type == ContextType.IMAGE_CREATE:
ok, retstring = self.create_img(query, 0)
reply = None
if ok:
reply = Reply(ReplyType.IMAGE_URL, retstring)
else:
reply = Reply(ReplyType.ERROR, retstring)
return reply
else:
reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type))
return reply

def reply_text(self, session, session_id, retry_count=0) ->dict:
def reply_text(self, session, session_id, retry_count=0) -> dict:
'''
call openai's ChatCompletion to get the answer
:param session: a conversation session
Expand All @@ -80,8 +97,8 @@ def reply_text(self, session, session_id, retry_count=0) ->dict:
presence_penalty=conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
)
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
return {"total_tokens": response["usage"]["total_tokens"],
"completion_tokens": response["usage"]["completion_tokens"],
return {"total_tokens": response["usage"]["total_tokens"],
"completion_tokens": response["usage"]["completion_tokens"],
"content": response.choices[0]['message']['content']}
except openai.error.RateLimitError as e:
# rate limit exception
Expand All @@ -96,21 +113,21 @@ def reply_text(self, session, session_id, retry_count=0) ->dict:
# api connection exception
logger.warn(e)
logger.warn("[OPEN_AI] APIConnection failed")
return {"completion_tokens": 0, "content":"我连接不到你的网络"}
return {"completion_tokens": 0, "content": "我连接不到你的网络"}
except openai.error.Timeout as e:
logger.warn(e)
logger.warn("[OPEN_AI] Timeout")
return {"completion_tokens": 0, "content":"我没有收到你的消息"}
return {"completion_tokens": 0, "content": "我没有收到你的消息"}
except Exception as e:
# unknown exception
logger.exception(e)
Session.clear_session(session_id)
self.sessions.clear_session(session_id)
return {"completion_tokens": 0, "content": "请再问我一次吧"}

def create_img(self, query, retry_count=0):
try:
if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token():
return "请求太快了,请休息一下再问我吧"
return False, "请求太快了,请休息一下再问我吧"
logger.info("[OPEN_AI] image_query={}".format(query))
response = openai.Image.create(
prompt=query, #图片描述
Expand All @@ -119,22 +136,39 @@ def create_img(self, query, retry_count=0):
)
image_url = response['data'][0]['url']
logger.info("[OPEN_AI] image_url={}".format(image_url))
return image_url
return True, image_url
except openai.error.RateLimitError as e:
logger.warn(e)
if retry_count < 1:
time.sleep(5)
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
return self.create_img(query, retry_count+1)
else:
return "请求太快啦,请休息一下再问我吧"
return False, "提问太快啦,请休息一下再问我吧"
except Exception as e:
logger.exception(e)
return None
return False, str(e)


class Session(object):
@staticmethod
def build_session_query(query, session_id):
class SessionManager(object):
def __init__(self):
if conf().get('expires_in_seconds'):
sessions = ExpiredDict(conf().get('expires_in_seconds'))
else:
sessions = dict()
self.sessions = sessions

def build_session(self, session_id, system_prompt=None):
session = self.sessions.get(session_id, [])
if len(session) == 0:
if system_prompt is None:
system_prompt = conf().get("character_desc", "")
system_item = {'role': 'system', 'content': system_prompt}
session.append(system_item)
self.sessions[session_id] = session
return session

def build_session_query(self, query, session_id):
'''
build query with conversation history
e.g. [
Expand All @@ -147,36 +181,28 @@ def build_session_query(query, session_id):
:param session_id: session id
:return: query content with conversaction
'''
session = all_sessions.get(session_id, [])
if len(session) == 0:
system_prompt = conf().get("character_desc", "")
system_item = {'role': 'system', 'content': system_prompt}
session.append(system_item)
all_sessions[session_id] = session
session = self.build_session(session_id)
user_item = {'role': 'user', 'content': query}
session.append(user_item)
return session

@staticmethod
def save_session(answer, session_id, total_tokens):
def save_session(self, answer, session_id, total_tokens):
max_tokens = conf().get("conversation_max_tokens")
if not max_tokens:
# default 3000
max_tokens = 1000
max_tokens=int(max_tokens)
max_tokens = int(max_tokens)

session = all_sessions.get(session_id)
session = self.sessions.get(session_id)
if session:
# append conversation
gpt_item = {'role': 'assistant', 'content': answer}
session.append(gpt_item)

# discard exceed limit conversation
Session.discard_exceed_conversation(session, max_tokens, total_tokens)

self.discard_exceed_conversation(session, max_tokens, total_tokens)

@staticmethod
def discard_exceed_conversation(session, max_tokens, total_tokens):
def discard_exceed_conversation(self, session, max_tokens, total_tokens):
dec_tokens = int(total_tokens)
# logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens))
while dec_tokens > max_tokens:
Expand All @@ -185,13 +211,11 @@ def discard_exceed_conversation(session, max_tokens, total_tokens):
session.pop(1)
session.pop(1)
else:
break
break
dec_tokens = dec_tokens - max_tokens

@staticmethod
def clear_session(session_id):
all_sessions[session_id] = []
def clear_session(self, session_id):
self.sessions[session_id] = []

@staticmethod
def clear_all_session():
all_sessions.clear()
def clear_all_session(self):
self.sessions.clear()
47 changes: 26 additions & 21 deletions bot/openai/open_ai_bot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# encoding:utf-8

from bot.bot import Bot
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from config import conf
from common.log import logger
import openai
Expand All @@ -18,29 +20,32 @@ def __init__(self):
if proxy:
openai.proxy = proxy


def reply(self, query, context=None):
# acquire reply content
if not context or not context.get('type') or context.get('type') == 'TEXT':
logger.info("[OPEN_AI] query={}".format(query))
from_user_id = context.get('from_user_id') or context.get('session_id')
if query == '#清除记忆':
Session.clear_session(from_user_id)
return '记忆已清除'
elif query == '#清除所有':
Session.clear_all_session()
return '所有人记忆已清除'

new_query = Session.build_session_query(query, from_user_id)
logger.debug("[OPEN_AI] session query={}".format(new_query))

reply_content = self.reply_text(new_query, from_user_id, 0)
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
if reply_content and query:
Session.save_session(query, reply_content, from_user_id)
return reply_content

elif context.get('type', None) == 'IMAGE_CREATE':
return self.create_img(query, 0)
if context and context.type:
if context.type == ContextType.TEXT:
logger.info("[OPEN_AI] query={}".format(query))
from_user_id = context['session_id']
reply = None
if query == '#清除记忆':
Session.clear_session(from_user_id)
reply = Reply(ReplyType.INFO, '记忆已清除')
elif query == '#清除所有':
Session.clear_all_session()
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
else:
new_query = Session.build_session_query(query, from_user_id)
logger.debug("[OPEN_AI] session query={}".format(new_query))

reply_content = self.reply_text(new_query, from_user_id, 0)
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
if reply_content and query:
Session.save_session(query, reply_content, from_user_id)
reply = Reply(ReplyType.TEXT, reply_content)
return reply
elif context.type == ContextType.IMAGE_CREATE:
return self.create_img(query, 0)

def reply_text(self, query, user_id, retry_count=0):
try:
Expand Down
Loading

0 comments on commit 1f73f00

Please sign in to comment.