Skip to content

Commit

Permalink
Merge pull request zhayujie#2311 from cmgzn/master
Browse files Browse the repository at this point in the history
fix: gemini doesn't receive system messages...
  • Loading branch information
6vision authored Sep 26, 2024
2 parents 6af1930 + 0bf17f0 commit 6101c67
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions bot/gemini/google_gemini_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from bridge.reply import Reply, ReplyType
from common.log import logger
from config import conf
from bot.chatgpt.chat_gpt_session import ChatGPTSession
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
from google.generativeai.types import HarmCategory, HarmBlockThreshold

Expand All @@ -23,8 +24,8 @@ class GoogleGeminiBot(Bot):
def __init__(self):
super().__init__()
self.api_key = conf().get("gemini_api_key")
# 复用文心的token计算方式
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo")
# 复用chatGPT的token计算方式
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
self.model = conf().get("model") or "gemini-pro"
if self.model == "gemini":
self.model = "gemini-pro"
Expand All @@ -37,6 +38,7 @@ def reply(self, query, context: Context = None) -> Reply:
session_id = context["session_id"]
session = self.sessions.session_query(query, session_id)
gemini_messages = self._convert_to_gemini_messages(self.filter_messages(session.messages))
logger.debug(f"[Gemini] messages={gemini_messages}")
genai.configure(api_key=self.api_key)
model = genai.GenerativeModel(self.model)

Expand Down Expand Up @@ -81,6 +83,8 @@ def _convert_to_gemini_messages(self, messages: list):
role = "user"
elif msg.get("role") == "assistant":
role = "model"
elif msg.get("role") == "system":
role = "user"
else:
continue
res.append({
Expand All @@ -97,7 +101,11 @@ def filter_messages(messages: list):
return res
for i in range(len(messages) - 1, -1, -1):
message = messages[i]
if message.get("role") != turn:
role = message.get("role")
if role == "system":
res.insert(0, message)
continue
if role != turn:
continue
res.insert(0, message)
if turn == "user":
Expand Down

0 comments on commit 6101c67

Please sign in to comment.