Skip to content

Commit

Permalink
feat: ChatGPT 支持流式输出,优化响应速度
Browse files Browse the repository at this point in the history
  • Loading branch information
wzpan committed Apr 8, 2023
1 parent 4a5524d commit 53e969e
Show file tree
Hide file tree
Showing 9 changed files with 362 additions and 133 deletions.
9 changes: 8 additions & 1 deletion plugins/Gossip.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

logger = logging.getLogger(__name__)

ENTRY_WORDS = ["进入", "打开", "激活", "开启", "一下"]
CLOSE_WORDS = ["退出", "结束", "停止"]

class Plugin(AbstractPlugin):

Expand All @@ -20,6 +22,11 @@ def handle(self, text, parsed):
else:
self.clearImmersive() # 去掉沉浸式
self.say("结束闲聊", cache=True)

def isValidImmersive(self, text, parsed):
    # While gossip (immersive) mode is active, only an explicit request to
    # close it is handled here: the utterance must mention "闲聊" AND
    # contain one of the closing keywords.
    if "闲聊" not in text:
        return False
    for word in CLOSE_WORDS:
        if word in text:
            return True
    return False

def isValid(self, text, parsed):
    # True when the user asks to enter gossip mode: the utterance must
    # mention "闲聊" and contain one of the entry keywords.
    # NOTE(review): the previous hard-coded phrase list
    # ("闲聊一下", "进入闲聊", ...) was left in as a first `return`,
    # making the keyword-based line below unreachable dead code; the
    # stale line is removed so this method mirrors isValidImmersive's
    # use of the module-level word lists.
    return "闲聊" in text and any(word in text for word in ENTRY_WORDS)
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ jieba
pvporcupine
pvrecorder
openai
apscheduler
apscheduler
# NOTE: "asyncio" is part of the Python 3 standard library and must not be
# installed from PyPI — the same-named PyPI package is an obsolete backport
# that shadows the stdlib module and breaks on modern Python versions.
129 changes: 97 additions & 32 deletions robot/AI.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import os
import json
import random
import requests
Expand Down Expand Up @@ -28,6 +29,10 @@ def __init__(self, **kwargs):
def chat(self, texts, parsed):
pass

@abstractmethod
def stream_chat(self, texts):
pass


class TulingRobot(AbstractRobot):

Expand Down Expand Up @@ -189,14 +194,15 @@ def __init__(
):
"""
OpenAI机器人
openai.api_key = os.getenv("OPENAI_API_KEY")
"""
super(self.__class__, self).__init__()
self.openai = None
try:
import openai

self.openai = openai
if not openai_api_key:
openai_api_key = os.getenv("OPENAI_API_KEY")
self.openai.api_key = openai_api_key
if proxy:
logger.info(f"{self.SLUG} 使用代理:{proxy}")
Expand All @@ -220,6 +226,79 @@ def get_config(cls):
# Try to get anyq config from config
return config.get("openai", {})

def stream_chat(self, texts):
    """
    Stream a reply for *texts* from the ChatGPT chat-completions API.

    :param texts: iterable of text fragments making up the user's utterance
    :return: a zero-argument callable; calling it yields the reply text
             piece by piece (a generator factory). If the HTTP request
             fails, the returned callable yields a single error message.
    """
    msg = "".join(texts)
    msg = utils.stripPunctuation(msg)
    msg = self.prefix + msg  # prepend the configured personality prefix
    logger.info("msg: " + msg)
    self.context.append({"role": "user", "content": msg})

    header = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + self.openai.api_key,
    }

    data = {"model": "gpt-3.5-turbo", "messages": self.context, "stream": True}
    logger.info("开始流式请求")
    url = "https://api.openai.com/v1/chat/completions"
    # Issue the request with stream=True so the SSE body can be consumed
    # incrementally.
    try:
        response = requests.request(
            "POST",
            url,
            headers=header,
            json=data,
            stream=True,
            proxies={"https": self.openai.proxy},
        )

        def generate():
            # Append the assistant message up front and mutate it in place,
            # so the conversation context stays consistent even if the
            # stream is interrupted mid-reply.
            one_message = {"role": "assistant", "content": ""}
            self.context.append(one_message)
            chunk_count = 0
            for line in response.iter_lines():
                line_str = line.decode("utf-8")
                if line_str.startswith("data:") and line_str[5:]:
                    if line_str.startswith("data: [DONE]"):
                        break
                    line_json = json.loads(line_str[5:])
                    choices = line_json.get("choices")
                    if not choices:
                        continue
                    delta = choices[0].get("delta", {})
                    if "content" in delta:
                        delta_content = delta["content"]
                        chunk_count += 1
                        # BUGFIX: the old code called
                        # logger.debug(delta_content, end="") — Logger.debug
                        # accepts no `end` kwarg and raised TypeError on
                        # every chunk. Log only the first chunks to keep
                        # the log readable.
                        if chunk_count < 40:
                            logger.debug(delta_content)
                        elif chunk_count == 40:
                            logger.debug("......")
                        one_message["content"] += delta_content
                        yield delta_content
                elif line_str.strip():
                    # Non-SSE payload (e.g. an error body): forward verbatim.
                    logger.debug(line_str)
                    yield line_str

    except Exception as e:
        # `e` is unbound once the except block exits, so keep a reference
        # for the closure below; log instead of swallowing silently.
        ee = e
        logger.exception("stream_chat request failed")

        def generate():
            yield "request error:\n" + str(ee)

    return generate

def chat(self, texts, parsed):
"""
使用OpenAI机器人聊天
Expand All @@ -233,37 +312,23 @@ def chat(self, texts, parsed):
logger.info("msg: " + msg)
try:
respond = ""
if "-turbo" in self.model:
self.context.append({"role": "user", "content": msg})
response = self.openai.Completion.create(
model=self.model,
messages=self.context,
temperature=self.temperature,
max_tokens=self.max_tokens,
top_p=self.top_p,
frequency_penalty=self.frequency_penalty,
presence_penalty=self.presence_penalty,
stop=self.stop_ai,
api_base=self.api_base
if self.api_base
else "https://api.openai.com/v1/chat",
)
message = response.choices[0].message
respond = message.content
self.context.append(message)
else:
response = self.openai.Completion.create(
model=self.model,
prompt=msg,
temperature=self.temperature,
max_tokens=self.max_tokens,
top_p=self.top_p,
frequency_penalty=self.frequency_penalty,
presence_penalty=self.presence_penalty,
stop=self.stop_ai,
)
respond = response.choices[0].text
logger.info(f"openai response: {respond}")
self.context.append({"role": "user", "content": msg})
response = self.openai.Completion.create(
model=self.model,
messages=self.context,
temperature=self.temperature,
max_tokens=self.max_tokens,
top_p=self.top_p,
frequency_penalty=self.frequency_penalty,
presence_penalty=self.presence_penalty,
stop=self.stop_ai,
api_base=self.api_base
if self.api_base
else "https://api.openai.com/v1/chat",
)
message = response.choices[0].message
respond = message.content
self.context.append(message)
return respond
except self.openai.error.InvalidRequestError:
logger.warning("token超出长度限制,丢弃历史会话")
Expand Down
Loading

0 comments on commit 53e969e

Please sign in to comment.