Skip to content

Commit

Permalink
Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat
Browse files Browse the repository at this point in the history
  • Loading branch information
zhayujie committed Dec 15, 2023
2 parents e04a12a + b4dc382 commit cc1b14b
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 4 deletions.
5 changes: 5 additions & 0 deletions bot/bot_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,9 @@ def create_bot(bot_type):
elif bot_type == const.QWEN:
from bot.ali.ali_qwen_bot import AliQwenBot
return AliQwenBot()

elif bot_type == const.GEMINI:
from bot.gemini.google_gemini_bot import GoogleGeminiBot
return GoogleGeminiBot()

raise RuntimeError
2 changes: 1 addition & 1 deletion bot/chatgpt/chat_gpt_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def calc_tokens(self):
def num_tokens_from_messages(messages, model):
"""Returns the number of tokens used by a list of messages."""

if model in ["wenxin", "xunfei"]:
if model in ["wenxin", "xunfei", const.GEMINI]:
return num_tokens_by_character(messages)

import tiktoken
Expand Down
75 changes: 75 additions & 0 deletions bot/gemini/google_gemini_bot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
Google gemini bot
@author zhayujie
@Date 2023/12/15
"""
# encoding:utf-8

from bot.bot import Bot
import google.generativeai as genai
from bot.session_manager import SessionManager
from bridge.context import ContextType, Context
from bridge.reply import Reply, ReplyType
from common.log import logger
from config import conf
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession


# OpenAI对话模型API (可用)
class GoogleGeminiBot(Bot):

def __init__(self):
super().__init__()
self.api_key = conf().get("gemini_api_key")
# 复用文心的token计算方式
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo")

def reply(self, query, context: Context = None) -> Reply:
try:
if context.type != ContextType.TEXT:
logger.warn(f"[Gemini] Unsupported message type, type={context.type}")
return Reply(ReplyType.TEXT, None)
logger.info(f"[Gemini] query={query}")
session_id = context["session_id"]
session = self.sessions.session_query(query, session_id)
gemini_messages = self._convert_to_gemini_messages(self._filter_messages(session.messages))
genai.configure(api_key=self.api_key)
model = genai.GenerativeModel('gemini-pro')
response = model.generate_content(gemini_messages)
reply_text = response.text
self.sessions.session_reply(reply_text, session_id)
logger.info(f"[Gemini] reply={reply_text}")
return Reply(ReplyType.TEXT, reply_text)
except Exception as e:
logger.error("[Gemini] fetch reply error, may contain unsafe content")
logger.error(e)

def _convert_to_gemini_messages(self, messages: list):
res = []
for msg in messages:
if msg.get("role") == "user":
role = "user"
elif msg.get("role") == "assistant":
role = "model"
else:
continue
res.append({
"role": role,
"parts": [{"text": msg.get("content")}]
})
return res

def _filter_messages(self, messages: list):
res = []
turn = "user"
for i in range(len(messages) - 1, -1, -1):
message = messages[i]
if message.get("role") != turn:
continue
res.insert(0, message)
if turn == "user":
turn = "assistant"
elif turn == "assistant":
turn = "user"
return res
4 changes: 4 additions & 0 deletions bridge/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,16 @@ def __init__(self):
self.btype["chat"] = const.XUNFEI
if model_type in [const.QWEN]:
self.btype["chat"] = const.QWEN
if model_type in [const.GEMINI]:
self.btype["chat"] = const.GEMINI

if conf().get("use_linkai") and conf().get("linkai_api_key"):
self.btype["chat"] = const.LINKAI
if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]:
self.btype["voice_to_text"] = const.LINKAI
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
self.btype["text_to_voice"] = const.LINKAI

if model_type in ["claude"]:
self.btype["chat"] = const.CLAUDEAI
self.bots = {}
Expand Down
3 changes: 2 additions & 1 deletion common/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
LINKAI = "linkai"
CLAUDEAI = "claude"
QWEN = "qwen"
GEMINI = "gemini"

# model
GPT35 = "gpt-3.5-turbo"
Expand All @@ -17,7 +18,7 @@
TTS_1 = "tts-1"
TTS_1_HD = "tts-1-hd"

MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo", GPT4_TURBO_PREVIEW, QWEN]
MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo", GPT4_TURBO_PREVIEW, QWEN, GEMINI]

# channel
FEISHU = "feishu"
2 changes: 2 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
"qwen_agent_key": "",
"qwen_app_id": "",
"qwen_node_id": "", # 流程编排模型用到的id,如果没有用到qwen_node_id,请务必保持为空字符串
# Google Gemini Api Key
"gemini_api_key": "",
# wework的通用配置
"wework_smart": True, # 配置wework是否使用已登录的企业微信,False为多开
# 语音设置
Expand Down
4 changes: 2 additions & 2 deletions plugins/godcmd/godcmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def on_handle_context(self, e_context: EventContext):
except Exception as e:
ok, result = False, "你没有设置私有GPT模型"
elif cmd == "reset":
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI, const.QWEN]:
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI]:
bot.sessions.clear_session(session_id)
if Bridge().chat_bots.get(bottype):
Bridge().chat_bots.get(bottype).sessions.clear_session(session_id)
Expand All @@ -339,7 +339,7 @@ def on_handle_context(self, e_context: EventContext):
ok, result = True, "配置已重载"
elif cmd == "resetall":
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI,
const.BAIDU, const.XUNFEI, const.QWEN]:
const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI]:
channel.cancel_all_session()
bot.sessions.clear_all_session()
ok, result = True, "重置所有会话成功"
Expand Down

0 comments on commit cc1b14b

Please sign in to comment.