diff --git a/.flake8 b/.flake8
index c614a6302..0af19c5f1 100644
--- a/.flake8
+++ b/.flake8
@@ -1,5 +1,5 @@
[flake8]
-max-line-length = 88
+max-line-length = 176
select = E303,W293,W291,W292,E305,E231,E302
exclude =
.tox,
diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md
deleted file mode 100644
index 7d697ec53..000000000
--- a/.github/ISSUE_TEMPLATE.md
+++ /dev/null
@@ -1,31 +0,0 @@
-### 前置确认
-
-1. 网络能够访问openai接口
-2. python 已安装:版本在 3.7 ~ 3.10 之间
-3. `git pull` 拉取最新代码
-4. 执行`pip3 install -r requirements.txt`,检查依赖是否满足
-5. 拓展功能请执行`pip3 install -r requirements-optional.txt`,检查依赖是否满足
-6. 在已有 issue 中未搜索到类似问题
-7. [FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) 中无类似问题
-
-
-### 问题描述
-
-> 简要说明、截图、复现步骤等,也可以是需求或想法
-
-
-
-
-### 终端日志 (如有报错)
-
-```
-[在此处粘贴终端日志, 可在主目录下`run.log`文件中找到]
-```
-
-
-
-### 环境
-
- - 操作系统类型 (Mac/Windows/Linux):
- - Python版本 ( 执行 `python3 -V` ):
- - pip版本 ( 依赖问题此项必填,执行 `pip3 -V`):
diff --git a/.github/ISSUE_TEMPLATE/1.bug.yml b/.github/ISSUE_TEMPLATE/1.bug.yml
new file mode 100644
index 000000000..2f762c076
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/1.bug.yml
@@ -0,0 +1,133 @@
+name: Bug report 🐛
+description: 项目运行中遇到的Bug或问题。
+labels: ['status: needs check']
+body:
+ - type: markdown
+ attributes:
+ value: |
+ ### ⚠️ 前置确认
+ 1. 网络能够访问openai接口
+ 2. python 已安装:版本在 3.7 ~ 3.10 之间
+ 3. `git pull` 拉取最新代码
+ 4. 执行`pip3 install -r requirements.txt`,检查依赖是否满足
+ 5. 拓展功能请执行`pip3 install -r requirements-optional.txt`,检查依赖是否满足
+ 6. [FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) 中无类似问题
+ - type: checkboxes
+ attributes:
+ label: 前置确认
+ options:
+ - label: 我确认我运行的是最新版本的代码,并且安装了所需的依赖,在[FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs)中也未找到类似问题。
+ required: true
+ - type: checkboxes
+ attributes:
+ label: ⚠️ 搜索issues中是否已存在类似问题
+ description: >
+ 请在 [历史issue](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中清空输入框,搜索你的问题
+ 或相关日志的关键词来查找是否存在类似问题。
+ options:
+        - label: 我已经搜索过issues和discussions,没有跟我遇到的问题相关的issue
+ required: true
+ - type: markdown
+ attributes:
+ value: |
+ 请在上方的`title`中填写你对你所遇到问题的简略总结,这将帮助其他人更好的找到相似问题,谢谢❤️。
+ - type: dropdown
+ attributes:
+ label: 操作系统类型?
+ description: >
+ 请选择你运行程序的操作系统类型。
+ options:
+ - Windows
+ - Linux
+ - MacOS
+ - Docker
+ - Railway
+ - Windows Subsystem for Linux (WSL)
+ - Other (请在问题中说明)
+ validations:
+ required: true
+ - type: dropdown
+ attributes:
+ label: 运行的python版本是?
+ description: |
+ 请选择你运行程序的`python`版本。
+ 注意:在`python 3.7`中,有部分可选依赖无法安装。
+ 经过长时间的观察,我们认为`python 3.8`是兼容性最好的版本。
+ `python 3.7`~`python 3.10`以外版本的issue,将视情况直接关闭。
+ options:
+ - python 3.7
+ - python 3.8
+ - python 3.9
+ - python 3.10
+ - other
+ validations:
+ required: true
+ - type: dropdown
+ attributes:
+ label: 使用的chatgpt-on-wechat版本是?
+ description: |
+ 请确保你使用的是 [releases](https://github.com/zhayujie/chatgpt-on-wechat/releases) 中的最新版本。
+ 如果你使用git, 请使用`git branch`命令来查看分支。
+ options:
+ - Latest Release
+ - Master (branch)
+ validations:
+ required: true
+ - type: dropdown
+ attributes:
+ label: 运行的`channel`类型是?
+ description: |
+ 请确保你正确配置了该`channel`所需的配置项,所有可选的配置项都写在了[该文件中](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py),请将所需配置项填写在根目录下的`config.json`文件中。
+ options:
+ - wx(个人微信, itchat)
+ - wxy(个人微信, wechaty)
+ - wechatmp(公众号, 订阅号)
+ - wechatmp_service(公众号, 服务号)
+ - terminal
+ - other
+ validations:
+ required: true
+ - type: textarea
+ attributes:
+ label: 复现步骤 🕹
+ description: |
+ **⚠️ 不能复现将会关闭issue.**
+ - type: textarea
+ attributes:
+ label: 问题描述 😯
+ description: 详细描述出现的问题,或提供有关截图。
+ - type: textarea
+ attributes:
+ label: 终端日志 📒
+ description: |
+ 在此处粘贴终端日志,可在主目录下`run.log`文件中找到,这会帮助我们更好的分析问题,注意隐去你的API key。
+ 如果在配置文件中加入`"debug": true`,打印出的日志会更有帮助。
+
+
+ 示例
+ ```log
+ [DEBUG][2023-04-16 00:23:22][plugin_manager.py:157] - Plugin SUMMARY triggered by event Event.ON_HANDLE_CONTEXT
+ [DEBUG][2023-04-16 00:23:22][main.py:221] - [Summary] on_handle_context. content: $总结前100条消息
+ [DEBUG][2023-04-16 00:23:24][main.py:240] - [Summary] limit: 100, duration: -1 seconds
+ [ERROR][2023-04-16 00:23:24][chat_channel.py:244] - Worker return exception: name 'start_date' is not defined
+ Traceback (most recent call last):
+ File "C:\ProgramData\Anaconda3\lib\concurrent\futures\thread.py", line 57, in run
+ result = self.fn(*self.args, **self.kwargs)
+ File "D:\project\chatgpt-on-wechat\channel\chat_channel.py", line 132, in _handle
+ reply = self._generate_reply(context)
+ File "D:\project\chatgpt-on-wechat\channel\chat_channel.py", line 142, in _generate_reply
+ e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
+ File "D:\project\chatgpt-on-wechat\plugins\plugin_manager.py", line 159, in emit_event
+ instance.handlers[e_context.event](e_context, *args, **kwargs)
+ File "D:\project\chatgpt-on-wechat\plugins\summary\main.py", line 255, in on_handle_context
+ records = self._get_records(session_id, start_time, limit)
+ File "D:\project\chatgpt-on-wechat\plugins\summary\main.py", line 96, in _get_records
+ c.execute("SELECT * FROM chat_records WHERE sessionid=? and timestamp>? ORDER BY timestamp DESC LIMIT ?", (session_id, start_date, limit))
+ NameError: name 'start_date' is not defined
+ [INFO][2023-04-16 00:23:36][app.py:14] - signal 2 received, exiting...
+ ```
+
+ value: |
+ ```log
+ <此处粘贴终端日志>
+ ```
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/2.feature.yml b/.github/ISSUE_TEMPLATE/2.feature.yml
new file mode 100644
index 000000000..bbf0888a2
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/2.feature.yml
@@ -0,0 +1,28 @@
+name: Feature request 🚀
+description: 提出你对项目的新想法或建议。
+labels: ['status: needs check']
+body:
+ - type: markdown
+ attributes:
+ value: |
+ 请在上方的`title`中填写简略总结,谢谢❤️。
+ - type: checkboxes
+ attributes:
+ label: ⚠️ 搜索是否存在类似issue
+ description: >
+ 请在 [历史issue](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中清空输入框,搜索关键词查找是否存在相似issue。
+ options:
+        - label: 我已经搜索过issues和discussions,没有发现相似issue
+ required: true
+ - type: textarea
+ attributes:
+ label: 总结
+ description: 描述feature的功能。
+ - type: textarea
+ attributes:
+ label: 举例
+ description: 提供聊天示例,草图或相关网址。
+ - type: textarea
+ attributes:
+ label: 动机
+ description: 描述你提出该feature的动机,比如没有这项feature对你的使用造成了怎样的影响。 请提供更详细的场景描述,这可能会帮助我们发现并提出更好的解决方案。
\ No newline at end of file
diff --git a/README.md b/README.md
index bb86b4130..2fad81bc0 100644
--- a/README.md
+++ b/README.md
@@ -2,27 +2,29 @@
> ChatGPT近期以强大的对话和信息整合能力风靡全网,可以写代码、改论文、讲故事,几乎无所不能,这让人不禁有个大胆的想法,能否用他的对话模型把我们的微信打造成一个智能机器人,可以在与好友对话中给出意想不到的回应,而且再也不用担心女朋友影响我们 ~~打游戏~~ 工作了。
+最新版本支持的功能如下:
-基于ChatGPT的微信聊天机器人,通过 [ChatGPT](https://github.com/openai/openai-python) 接口生成对话内容,使用 [itchat](https://github.com/littlecodersh/ItChat) 实现微信消息的接收和自动回复。已实现的特性如下:
-
-- [x] **文本对话:** 接收私聊及群组中的微信消息,使用ChatGPT生成回复内容,完成自动回复
-- [x] **规则定制化:** 支持私聊中按指定规则触发自动回复,支持对群组设置自动回复白名单
-- [x] **图片生成:** 支持根据描述生成图片,支持图片修复
-- [x] **上下文记忆**:支持多轮对话记忆,且为每个好友维护独立的上下会话
-- [x] **语音识别:** 支持接收和处理语音消息,通过文字或语音回复
-- [x] **插件化:** 支持个性化插件,提供角色扮演、文字冒险、与操作系统交互、访问网络数据等能力
-
-> 目前支持微信和微信公众号部署,欢迎接入更多应用,参考 [Terminal代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py)实现接收和发送消息逻辑即可接入。 同时欢迎增加新的插件,参考 [插件说明文档](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins)。
+- [x] **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3,GPT-3.5,GPT-4模型
+- [x] **语音识别:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai等多种语音模型
+- [x] **图片生成:** 支持图片生成和照片修复,可选择 DALL-E, Stable Diffusion, Replicate 模型
+- [x] **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结等插件
+- [X] **Tool工具:** 与操作系统和互联网交互,支持最新信息搜索、数学计算、天气和资讯查询、网页总结,基于 [chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub) 实现
+> 目前已支持 个人微信 和 微信公众号,欢迎接入更多应用,参考 [Terminal代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py)实现接收和发送消息逻辑即可接入。 同时欢迎增加新的插件,参考 [插件说明文档](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins)。
**一键部署:**
[![Deploy on Railway](https://railway.app/button.svg)](https://railway.app/template/qApznZ?referralCode=RC3znh)
+# 演示
+
+https://user-images.githubusercontent.com/26161723/233777277-e3b9928e-b88f-43e2-b0e0-3cbc923bc799.mp4
+
+Demo made by [Visionn](https://www.wangpc.cc/)
# 更新日志
->**2023.04.05:** 支持微信个人号部署,兼容角色扮演等预设插件,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatmp/README.md)。(contributed by [@JS00000](https://github.com/JS00000) in [#686](https://github.com/zhayujie/chatgpt-on-wechat/pull/686))
+>**2023.04.05:** 支持微信公众号部署,兼容角色扮演等预设插件,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatmp/README.md)。(contributed by [@JS00000](https://github.com/JS00000) in [#686](https://github.com/zhayujie/chatgpt-on-wechat/pull/686))
>**2023.04.05:** 增加能让ChatGPT使用工具的`tool`插件,[使用文档](https://github.com/goldfishh/chatgpt-on-wechat/blob/master/plugins/tool/README.md)。工具相关issue可反馈至[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)。(contributed by [@goldfishh](https://github.com/goldfishh) in [#663](https://github.com/zhayujie/chatgpt-on-wechat/pull/663))
@@ -32,28 +34,7 @@
>**2023.03.02:** 接入[ChatGPT API](https://platform.openai.com/docs/guides/chat) (gpt-3.5-turbo),默认使用该模型进行对话,需升级openai依赖 (`pip3 install --upgrade openai`)。网络问题参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
->**2023.02.09:** 扫码登录存在封号风险,请谨慎使用,参考[#58](https://github.com/AutumnWhj/ChatGPT-wechat-bot/issues/158)
-
->**2023.02.05:** 在openai官方接口方案中 (GPT-3模型) 实现上下文对话
-
->**2022.12.18:** 支持根据描述生成图片并发送,openai版本需大于0.25.0
-
->**2022.12.17:** 原来的方案是从 [ChatGPT页面](https://chat.openai.com/chat) 获取session_token,使用 [revChatGPT](https://github.com/acheong08/ChatGPT) 直接访问web接口,但随着ChatGPT接入Cloudflare人机验证,这一方案难以在服务器顺利运行。 所以目前使用的方案是调用 OpenAI 官方提供的 [API](https://beta.openai.com/docs/api-reference/introduction),回复质量上基本接近于ChatGPT的内容,劣势是暂不支持有上下文记忆的对话,优势是稳定性和响应速度较好。
-
-# 使用效果
-
-### 个人聊天
-
-![single-chat-sample.jpg](docs/images/single-chat-sample.jpg)
-
-### 群组聊天
-
-![group-chat-sample.jpg](docs/images/group-chat-sample.jpg)
-
-### 图片生成
-
-![group-chat-sample.jpg](docs/images/image-create-sample.jpg)
-
+>**2023.02.09:** 扫码登录存在账号限制风险,请谨慎使用,参考[#58](https://github.com/AutumnWhj/ChatGPT-wechat-bot/issues/158)
# 快速开始
@@ -63,7 +44,7 @@
前往 [OpenAI注册页面](https://beta.openai.com/signup) 创建账号,参考这篇 [教程](https://www.pythonthree.com/register-openai-chatgpt/) 可以通过虚拟手机号来接收验证码。创建完账号则前往 [API管理页面](https://beta.openai.com/account/api-keys) 创建一个 API Key 并保存下来,后面需要在项目中配置这个key。
-> 项目中使用的对话模型是 davinci,计费方式是约每 750 字 (包含请求和回复) 消耗 $0.02,图片生成是每张消耗 $0.016,账号创建有免费的 $18 额度 (更新3.25: 最新注册的已经无免费额度了),使用完可以更换邮箱重新注册。
+> 项目中默认使用的对话模型是 gpt-3.5-turbo,计费方式是约每 500 汉字 (包含请求和回复) 消耗 $0.002,图片生成是每张消耗 $0.016。
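
Since billing is per token, it can help to check how many tokens a given chunk of Chinese text actually consumes. A minimal sketch, assuming the tiktoken package is available (the project's token-counting helper appears to be based on it); the sample text is illustrative:

```python
# Rough, illustrative token count for a block of Chinese text; billing is based
# on the total prompt + completion tokens reported by the API.
import tiktoken

enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
text = "你好,很高兴见到你。" * 50  # roughly 500 Chinese characters
print(len(enc.encode(text)))  # number of tokens this text would consume
```
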
### 2.运行环境
@@ -203,7 +184,7 @@ nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通
参考文档 [Docker部署](https://github.com/limccn/chatgpt-on-wechat/wiki/Docker%E9%83%A8%E7%BD%B2) (Contributed by [limccn](https://github.com/limccn))。
-### 4. Railway部署(✅推荐)
+### 4. Railway部署 (✅推荐)
> Railway每月提供5刀和最多500小时的免费额度。
1. 进入 [Railway](https://railway.app/template/qApznZ?referralCode=RC3znh)。
2. 点击 `Deploy Now` 按钮。
diff --git a/app.py b/app.py
index e11f46c2a..637b6e462 100644
--- a/app.py
+++ b/app.py
@@ -19,6 +19,7 @@ def func(_signo, _stack_frame):
if callable(old_handler): # check old_handler
return old_handler(_signo, _stack_frame)
sys.exit(0)
+
signal.signal(_signo, func)
diff --git a/bot/baidu/baidu_unit_bot.py b/bot/baidu/baidu_unit_bot.py
index d8a0aca11..f7714e4f4 100644
--- a/bot/baidu/baidu_unit_bot.py
+++ b/bot/baidu/baidu_unit_bot.py
@@ -10,10 +10,7 @@
class BaiduUnitBot(Bot):
def reply(self, query, context=None):
token = self.get_token()
- url = (
- "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
- + token
- )
+ url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + token
post_data = (
'{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"'
+ query
@@ -32,12 +29,7 @@ def reply(self, query, context=None):
def get_token(self):
access_key = "YOUR_ACCESS_KEY"
secret_key = "YOUR_SECRET_KEY"
- host = (
- "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id="
- + access_key
- + "&client_secret="
- + secret_key
- )
+ host = "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + access_key + "&client_secret=" + secret_key
response = requests.get(host)
if response:
print(response.json())
diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py
index d8e4b0e26..b045311a3 100644
--- a/bot/chatgpt/chat_gpt_bot.py
+++ b/bot/chatgpt/chat_gpt_bot.py
@@ -30,23 +30,15 @@ def __init__(self):
if conf().get("rate_limit_chatgpt"):
self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))
- self.sessions = SessionManager(
- ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo"
- )
+ self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
self.args = {
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
# "max_tokens":4096, # 回复最大的字符数
"top_p": 1,
- "frequency_penalty": conf().get(
- "frequency_penalty", 0.0
- ), # [-2,2]之间,该值越大则更倾向于产生不同的内容
- "presence_penalty": conf().get(
- "presence_penalty", 0.0
- ), # [-2,2]之间,该值越大则更倾向于产生不同的内容
- "request_timeout": conf().get(
- "request_timeout", None
- ), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
+ "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
+ "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
+ "request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
}
@@ -87,15 +79,10 @@ def reply(self, query, context=None):
reply_content["completion_tokens"],
)
)
- if (
- reply_content["completion_tokens"] == 0
- and len(reply_content["content"]) > 0
- ):
+ if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
reply = Reply(ReplyType.ERROR, reply_content["content"])
elif reply_content["completion_tokens"] > 0:
- self.sessions.session_reply(
- reply_content["content"], session_id, reply_content["total_tokens"]
- )
+ self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
reply = Reply(ReplyType.TEXT, reply_content["content"])
else:
reply = Reply(ReplyType.ERROR, reply_content["content"])
@@ -126,9 +113,7 @@ def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> di
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
# if api_key == None, the default openai.api_key will be used
- response = openai.ChatCompletion.create(
- api_key=api_key, messages=session.messages, **self.args
- )
+ response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **self.args)
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
return {
"total_tokens": response["usage"]["total_tokens"],
diff --git a/bot/chatgpt/chat_gpt_session.py b/bot/chatgpt/chat_gpt_session.py
index 525793ffa..e6c319b36 100644
--- a/bot/chatgpt/chat_gpt_session.py
+++ b/bot/chatgpt/chat_gpt_session.py
@@ -25,9 +25,7 @@ def discard_exceeding(self, max_tokens, cur_tokens=None):
precise = False
if cur_tokens is None:
raise e
- logger.debug(
- "Exception when counting tokens precisely for query: {}".format(e)
- )
+ logger.debug("Exception when counting tokens precisely for query: {}".format(e))
while cur_tokens > max_tokens:
if len(self.messages) > 2:
self.messages.pop(1)
@@ -39,16 +37,10 @@ def discard_exceeding(self, max_tokens, cur_tokens=None):
cur_tokens = cur_tokens - max_tokens
break
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
- logger.warn(
- "user message exceed max_tokens. total_tokens={}".format(cur_tokens)
- )
+ logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
break
else:
- logger.debug(
- "max_tokens={}, total_tokens={}, len(messages)={}".format(
- max_tokens, cur_tokens, len(self.messages)
- )
- )
+ logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
break
if precise:
cur_tokens = self.calc_tokens()
@@ -75,17 +67,13 @@ def num_tokens_from_messages(messages, model):
elif model == "gpt-4":
return num_tokens_from_messages(messages, model="gpt-4-0314")
elif model == "gpt-3.5-turbo-0301":
- tokens_per_message = (
- 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
- )
+ tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted
elif model == "gpt-4-0314":
tokens_per_message = 3
tokens_per_name = 1
else:
- logger.warn(
- f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301."
- )
+ logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.")
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
num_tokens = 0
for message in messages:
diff --git a/bot/openai/open_ai_bot.py b/bot/openai/open_ai_bot.py
index 1cfbf10d8..160562526 100644
--- a/bot/openai/open_ai_bot.py
+++ b/bot/openai/open_ai_bot.py
@@ -28,23 +28,15 @@ def __init__(self):
if proxy:
openai.proxy = proxy
- self.sessions = SessionManager(
- OpenAISession, model=conf().get("model") or "text-davinci-003"
- )
+ self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003")
self.args = {
"model": conf().get("model") or "text-davinci-003", # 对话模型的名称
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
"max_tokens": 1200, # 回复最大的字符数
"top_p": 1,
- "frequency_penalty": conf().get(
- "frequency_penalty", 0.0
- ), # [-2,2]之间,该值越大则更倾向于产生不同的内容
- "presence_penalty": conf().get(
- "presence_penalty", 0.0
- ), # [-2,2]之间,该值越大则更倾向于产生不同的内容
- "request_timeout": conf().get(
- "request_timeout", None
- ), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
+ "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
+ "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
+ "request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
"stop": ["\n\n\n"],
}
@@ -71,17 +63,13 @@ def reply(self, query, context=None):
result["content"],
)
logger.debug(
- "[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
- str(session), session_id, reply_content, completion_tokens
- )
+ "[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
)
if total_tokens == 0:
reply = Reply(ReplyType.ERROR, reply_content)
else:
- self.sessions.session_reply(
- reply_content, session_id, total_tokens
- )
+ self.sessions.session_reply(reply_content, session_id, total_tokens)
reply = Reply(ReplyType.TEXT, reply_content)
return reply
elif context.type == ContextType.IMAGE_CREATE:
@@ -96,9 +84,7 @@ def reply(self, query, context=None):
def reply_text(self, session: OpenAISession, retry_count=0):
try:
response = openai.Completion.create(prompt=str(session), **self.args)
- res_content = (
- response.choices[0]["text"].strip().replace("<|endoftext|>", "")
- )
+ res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "")
total_tokens = response["usage"]["total_tokens"]
completion_tokens = response["usage"]["completion_tokens"]
logger.info("[OPEN_AI] reply={}".format(res_content))
diff --git a/bot/openai/open_ai_image.py b/bot/openai/open_ai_image.py
index 5dbbd23ed..b188557f3 100644
--- a/bot/openai/open_ai_image.py
+++ b/bot/openai/open_ai_image.py
@@ -23,9 +23,7 @@ def create_img(self, query, retry_count=0):
response = openai.Image.create(
prompt=query, # 图片描述
n=1, # 每次生成图片的数量
- size=conf().get(
- "image_create_size", "256x256"
- ), # 图片大小,可选有 256x256, 512x512, 1024x1024
+ size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024
)
image_url = response["data"][0]["url"]
logger.info("[OPEN_AI] image_url={}".format(image_url))
@@ -34,11 +32,7 @@ def create_img(self, query, retry_count=0):
logger.warn(e)
if retry_count < 1:
time.sleep(5)
- logger.warn(
- "[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(
- retry_count + 1
- )
- )
+ logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
return self.create_img(query, retry_count + 1)
else:
return False, "提问太快啦,请休息一下再问我吧"
diff --git a/bot/openai/open_ai_session.py b/bot/openai/open_ai_session.py
index 78cf43900..8f6aa4f5b 100644
--- a/bot/openai/open_ai_session.py
+++ b/bot/openai/open_ai_session.py
@@ -36,9 +36,7 @@ def discard_exceeding(self, max_tokens, cur_tokens=None):
precise = False
if cur_tokens is None:
raise e
- logger.debug(
- "Exception when counting tokens precisely for query: {}".format(e)
- )
+ logger.debug("Exception when counting tokens precisely for query: {}".format(e))
while cur_tokens > max_tokens:
if len(self.messages) > 1:
self.messages.pop(0)
@@ -50,18 +48,10 @@ def discard_exceeding(self, max_tokens, cur_tokens=None):
cur_tokens = len(str(self))
break
elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
- logger.warn(
- "user question exceed max_tokens. total_tokens={}".format(
- cur_tokens
- )
- )
+ logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
break
else:
- logger.debug(
- "max_tokens={}, total_tokens={}, len(conversation)={}".format(
- max_tokens, cur_tokens, len(self.messages)
- )
- )
+ logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
break
if precise:
cur_tokens = self.calc_tokens()
diff --git a/bot/session_manager.py b/bot/session_manager.py
index 1aff647b5..8d70886e0 100644
--- a/bot/session_manager.py
+++ b/bot/session_manager.py
@@ -55,9 +55,7 @@ def build_session(self, session_id, system_prompt=None):
return self.sessioncls(session_id, system_prompt, **self.session_args)
if session_id not in self.sessions:
- self.sessions[session_id] = self.sessioncls(
- session_id, system_prompt, **self.session_args
- )
+ self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session
self.sessions[session_id].set_system_prompt(system_prompt)
session = self.sessions[session_id]
@@ -71,9 +69,7 @@ def session_query(self, query, session_id):
total_tokens = session.discard_exceeding(max_tokens, None)
logger.debug("prompt tokens used={}".format(total_tokens))
except Exception as e:
- logger.debug(
- "Exception when counting tokens precisely for prompt: {}".format(str(e))
- )
+ logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
return session
def session_reply(self, reply, session_id, total_tokens=None):
@@ -82,17 +78,9 @@ def session_reply(self, reply, session_id, total_tokens=None):
try:
max_tokens = conf().get("conversation_max_tokens", 1000)
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
- logger.debug(
- "raw total_tokens={}, savesession tokens={}".format(
- total_tokens, tokens_cnt
- )
- )
+ logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
except Exception as e:
- logger.debug(
- "Exception when counting tokens precisely for session: {}".format(
- str(e)
- )
- )
+ logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
return session
def clear_session(self, session_id):
diff --git a/bridge/bridge.py b/bridge/bridge.py
index dcf6e7e07..78fe23458 100644
--- a/bridge/bridge.py
+++ b/bridge/bridge.py
@@ -1,11 +1,12 @@
-from bot import bot_factory
+from bot.bot_factory import create_bot
from bridge.context import Context
from bridge.reply import Reply
from common import const
from common.log import logger
from common.singleton import singleton
from config import conf
-from voice import voice_factory
+from translate.factory import create_translator
+from voice.factory import create_voice
@singleton
@@ -15,6 +16,7 @@ def __init__(self):
"chat": const.CHATGPT,
"voice_to_text": conf().get("voice_to_text", "openai"),
"text_to_voice": conf().get("text_to_voice", "google"),
+ "translate": conf().get("translate", "baidu"),
}
model_type = conf().get("model")
if model_type in ["text-davinci-003"]:
@@ -27,11 +29,13 @@ def get_bot(self, typename):
if self.bots.get(typename) is None:
logger.info("create bot {} for {}".format(self.btype[typename], typename))
if typename == "text_to_voice":
- self.bots[typename] = voice_factory.create_voice(self.btype[typename])
+ self.bots[typename] = create_voice(self.btype[typename])
elif typename == "voice_to_text":
- self.bots[typename] = voice_factory.create_voice(self.btype[typename])
+ self.bots[typename] = create_voice(self.btype[typename])
elif typename == "chat":
- self.bots[typename] = bot_factory.create_bot(self.btype[typename])
+ self.bots[typename] = create_bot(self.btype[typename])
+ elif typename == "translate":
+ self.bots[typename] = create_translator(self.btype[typename])
return self.bots[typename]
def get_bot_type(self, typename):
@@ -45,3 +49,6 @@ def fetch_voice_to_text(self, voiceFile) -> Reply:
def fetch_text_to_voice(self, text) -> Reply:
return self.get_bot("text_to_voice").textToVoice(text)
+
+ def fetch_translate(self, text, from_lang="", to_lang="en") -> Reply:
+ return self.get_bot("translate").translate(text, from_lang, to_lang)
diff --git a/bridge/context.py b/bridge/context.py
index c1eb10c1f..ab004c0f6 100644
--- a/bridge/context.py
+++ b/bridge/context.py
@@ -60,6 +60,4 @@ def __delitem__(self, key):
del self.kwargs[key]
def __str__(self):
- return "Context(type={}, content={}, kwargs={})".format(
- self.type, self.content, self.kwargs
- )
+ return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)
diff --git a/channel/chat_channel.py b/channel/chat_channel.py
index 89dfe145a..d57cdb569 100644
--- a/channel/chat_channel.py
+++ b/channel/chat_channel.py
@@ -53,9 +53,7 @@ def _compose_context(self, ctype: ContextType, content, **kwargs):
group_id = cmsg.other_user_id
group_name_white_list = config.get("group_name_white_list", [])
- group_name_keyword_white_list = config.get(
- "group_name_keyword_white_list", []
- )
+ group_name_keyword_white_list = config.get("group_name_keyword_white_list", [])
if any(
[
group_name in group_name_white_list,
@@ -63,9 +61,7 @@ def _compose_context(self, ctype: ContextType, content, **kwargs):
check_contain(group_name, group_name_keyword_white_list),
]
):
- group_chat_in_one_session = conf().get(
- "group_chat_in_one_session", []
- )
+ group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
session_id = cmsg.actual_user_id
if any(
[
@@ -81,17 +77,11 @@ def _compose_context(self, ctype: ContextType, content, **kwargs):
else:
context["session_id"] = cmsg.other_user_id
context["receiver"] = cmsg.other_user_id
- e_context = PluginManager().emit_event(
- EventContext(
- Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}
- )
- )
+ e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}))
context = e_context["context"]
if e_context.is_pass() or context is None:
return context
- if cmsg.from_user_id == self.user_id and not config.get(
- "trigger_by_self", True
- ):
+ if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True):
logger.debug("[WX]self message skipped")
return None
@@ -114,28 +104,22 @@ def _compose_context(self, ctype: ContextType, content, **kwargs):
logger.info("[WX]receive group at")
if not conf().get("group_at_off", False):
flag = True
- pattern = f"@{self.name}(\u2005|\u0020)"
+ pattern = f"@{re.escape(self.name)}(\u2005|\u0020)"
content = re.sub(pattern, r"", content)
if not flag:
if context["origin_ctype"] == ContextType.VOICE:
- logger.info(
- "[WX]receive group voice, but checkprefix didn't match"
- )
+ logger.info("[WX]receive group voice, but checkprefix didn't match")
return None
else: # 单聊
- match_prefix = check_prefix(
- content, conf().get("single_chat_prefix", [""])
- )
+ match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""]))
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
content = content.replace(match_prefix, "", 1).strip()
- elif (
- context["origin_ctype"] == ContextType.VOICE
- ): # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
+ elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
pass
else:
return None
-
+ content = content.strip()
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
if img_match_prefix:
content = content.replace(img_match_prefix, "", 1)
@@ -143,18 +127,10 @@ def _compose_context(self, ctype: ContextType, content, **kwargs):
else:
context.type = ContextType.TEXT
context.content = content.strip()
- if (
- "desire_rtype" not in context
- and conf().get("always_reply_voice")
- and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
- ):
+ if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
context["desire_rtype"] = ReplyType.VOICE
elif context.type == ContextType.VOICE:
- if (
- "desire_rtype" not in context
- and conf().get("voice_reply_voice")
- and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
- ):
+ if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
context["desire_rtype"] = ReplyType.VOICE
return context
@@ -182,15 +158,8 @@ def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
)
reply = e_context["reply"]
if not e_context.is_pass():
- logger.debug(
- "[WX] ready to handle context: type={}, content={}".format(
- context.type, context.content
- )
- )
- if (
- context.type == ContextType.TEXT
- or context.type == ContextType.IMAGE_CREATE
- ): # 文字和图片消息
+ logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content))
+ if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
reply = super().build_reply_content(context.content, context)
elif context.type == ContextType.VOICE: # 语音消息
cmsg = context["msg"]
@@ -214,9 +183,7 @@ def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
# logger.warning("[WX]delete temp file error: " + str(e))
if reply.type == ReplyType.TEXT:
- new_context = self._compose_context(
- ContextType.TEXT, reply.content, **context.kwargs
- )
+ new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs)
if new_context:
reply = self._generate_reply(new_context)
else:
@@ -246,48 +213,24 @@ def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
if reply.type == ReplyType.TEXT:
reply_text = reply.content
- if (
- desire_rtype == ReplyType.VOICE
- and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
- ):
+ if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
reply = super().build_text_to_voice(reply.content)
return self._decorate_reply(context, reply)
if context.get("isgroup", False):
- reply_text = (
- "@"
- + context["msg"].actual_user_nickname
- + " "
- + reply_text.strip()
- )
- reply_text = (
- conf().get("group_chat_reply_prefix", "") + reply_text
- )
+ reply_text = "@" + context["msg"].actual_user_nickname + " " + reply_text.strip()
+ reply_text = conf().get("group_chat_reply_prefix", "") + reply_text
else:
- reply_text = (
- conf().get("single_chat_reply_prefix", "") + reply_text
- )
+ reply_text = conf().get("single_chat_reply_prefix", "") + reply_text
reply.content = reply_text
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
reply.content = "[" + str(reply.type) + "]\n" + reply.content
- elif (
- reply.type == ReplyType.IMAGE_URL
- or reply.type == ReplyType.VOICE
- or reply.type == ReplyType.IMAGE
- ):
+ elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
pass
else:
logger.error("[WX] unknown reply type: {}".format(reply.type))
return
- if (
- desire_rtype
- and desire_rtype != reply.type
- and reply.type not in [ReplyType.ERROR, ReplyType.INFO]
- ):
- logger.warning(
- "[WX] desire_rtype: {}, but reply type: {}".format(
- context.get("desire_rtype"), reply.type
- )
- )
+ if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
+ logger.warning("[WX] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
return reply
def _send_reply(self, context: Context, reply: Reply):
@@ -300,9 +243,7 @@ def _send_reply(self, context: Context, reply: Reply):
)
reply = e_context["reply"]
if not e_context.is_pass() and reply and reply.type:
- logger.debug(
- "[WX] ready to send reply: {}, context: {}".format(reply, context)
- )
+ logger.debug("[WX] ready to send reply: {}, context: {}".format(reply, context))
self._send(reply, context)
def _send(self, reply: Reply, context: Context, retry_cnt=0):
@@ -328,9 +269,7 @@ def func(worker: Future):
try:
worker_exception = worker.exception()
if worker_exception:
- self._fail_callback(
- session_id, exception=worker_exception, **kwargs
- )
+ self._fail_callback(session_id, exception=worker_exception, **kwargs)
else:
self._success_callback(session_id, **kwargs)
except CancelledError as e:
@@ -366,24 +305,14 @@ def consume(self):
if not context_queue.empty():
context = context_queue.get()
logger.debug("[WX] consume context: {}".format(context))
- future: Future = self.handler_pool.submit(
- self._handle, context
- )
- future.add_done_callback(
- self._thread_pool_callback(session_id, context=context)
- )
+ future: Future = self.handler_pool.submit(self._handle, context)
+ future.add_done_callback(self._thread_pool_callback(session_id, context=context))
if session_id not in self.futures:
self.futures[session_id] = []
self.futures[session_id].append(future)
- elif (
- semaphore._initial_value == semaphore._value + 1
- ): # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
- self.futures[session_id] = [
- t for t in self.futures[session_id] if not t.done()
- ]
- assert (
- len(self.futures[session_id]) == 0
- ), "thread pool error"
+ elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
+ self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
+ assert len(self.futures[session_id]) == 0, "thread pool error"
del self.sessions[session_id]
else:
semaphore.release()
@@ -397,9 +326,7 @@ def cancel_session(self, session_id):
future.cancel()
cnt = self.sessions[session_id][0].qsize()
if cnt > 0:
- logger.info(
- "Cancel {} messages in session {}".format(cnt, session_id)
- )
+ logger.info("Cancel {} messages in session {}".format(cnt, session_id))
self.sessions[session_id][0] = Dequeue()
def cancel_all_session(self):
@@ -409,9 +336,7 @@ def cancel_all_session(self):
future.cancel()
cnt = self.sessions[session_id][0].qsize()
if cnt > 0:
- logger.info(
- "Cancel {} messages in session {}".format(cnt, session_id)
- )
+ logger.info("Cancel {} messages in session {}".format(cnt, session_id))
self.sessions[session_id][0] = Dequeue()
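
One easy-to-miss behavioral fix in this file: the @-mention pattern now wraps the bot's nickname in `re.escape`. A small, self-contained illustration of why (the nickname below is hypothetical):

```python
import re

name = "Bob (admin)"                        # hypothetical nickname containing regex metacharacters
content = "@Bob (admin)\u2005帮我总结一下"
pattern = f"@{re.escape(name)}(\u2005|\u0020)"
print(re.sub(pattern, "", content))         # -> "帮我总结一下"
# Without re.escape, "(admin)" would be parsed as a regex group, so the literal
# parentheses in the message would never match and the @-prefix would not be stripped.
```
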
diff --git a/channel/terminal/terminal_channel.py b/channel/terminal/terminal_channel.py
index e2060789c..9a413dcff 100644
--- a/channel/terminal/terminal_channel.py
+++ b/channel/terminal/terminal_channel.py
@@ -77,9 +77,7 @@ def startup(self):
if check_prefix(prompt, trigger_prefixs) is None:
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
- context = self._compose_context(
- ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt)
- )
+ context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt))
if context:
self.produce(context)
else:
diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py
index cf200b17e..16d788c44 100644
--- a/channel/wechat/wechat_channel.py
+++ b/channel/wechat/wechat_channel.py
@@ -56,10 +56,7 @@ def wrapper(self, cmsg: ChatMessage):
return
self.receivedMsgs[msgId] = cmsg
create_time = cmsg.create_time # 消息时间戳
- if (
- conf().get("hot_reload") == True
- and int(create_time) < int(time.time()) - 60
- ): # 跳过1分钟前的历史消息
+ if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
logger.debug("[WX]history message {} skipped".format(msgId))
return
return func(self, cmsg)
@@ -88,15 +85,9 @@ def qrCallback(uuid, status, qrcode):
url = f"https://login.weixin.qq.com/l/{uuid}"
qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
- qr_api2 = (
- "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(
- url
- )
- )
+ qr_api2 = "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url)
qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url)
- qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(
- url
- )
+ qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)
print("You can also scan QRCode in any website below:")
print(qr_api3)
print(qr_api4)
@@ -134,18 +125,12 @@ def startup(self):
logger.error("Hot reload failed, try to login without hot reload")
itchat.logout()
os.remove(status_path)
- itchat.auto_login(
- enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback
- )
+ itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback)
else:
raise e
self.user_id = itchat.instance.storageClass.userName
self.name = itchat.instance.storageClass.nickName
- logger.info(
- "Wechat login success, user_id: {}, nickname: {}".format(
- self.user_id, self.name
- )
- )
+ logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
# start message listener
itchat.run()
@@ -173,16 +158,10 @@ def handle_single(self, cmsg: ChatMessage):
elif cmsg.ctype == ContextType.PATPAT:
logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.TEXT:
- logger.debug(
- "[WX]receive text msg: {}, cmsg={}".format(
- json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg
- )
- )
+ logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
else:
logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
- context = self._compose_context(
- cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg
- )
+ context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
if context:
self.produce(context)
@@ -202,9 +181,7 @@ def handle_group(self, cmsg: ChatMessage):
pass
else:
logger.debug("[WX]receive group msg: {}".format(cmsg.content))
- context = self._compose_context(
- cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg
- )
+ context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
if context:
self.produce(context)
diff --git a/channel/wechat/wechat_message.py b/channel/wechat/wechat_message.py
index 18884259c..63c225471 100644
--- a/channel/wechat/wechat_message.py
+++ b/channel/wechat/wechat_message.py
@@ -27,37 +27,23 @@ def __init__(self, itchat_msg, is_group=False):
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
self._prepare_fn = lambda: itchat_msg.download(self.content)
elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000:
- if is_group and (
- "加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]
- ):
+ if is_group and ("加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]):
self.ctype = ContextType.JOIN_GROUP
self.content = itchat_msg["Content"]
# 这里只能得到nickname, actual_user_id还是机器人的id
if "加入了群聊" in itchat_msg["Content"]:
- self.actual_user_nickname = re.findall(
- r"\"(.*?)\"", itchat_msg["Content"]
- )[-1]
+ self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1]
elif "加入群聊" in itchat_msg["Content"]:
- self.actual_user_nickname = re.findall(
- r"\"(.*?)\"", itchat_msg["Content"]
- )[0]
+ self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
elif "拍了拍我" in itchat_msg["Content"]:
self.ctype = ContextType.PATPAT
self.content = itchat_msg["Content"]
if is_group:
- self.actual_user_nickname = re.findall(
- r"\"(.*?)\"", itchat_msg["Content"]
- )[0]
+ self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
else:
- raise NotImplementedError(
- "Unsupported note message: " + itchat_msg["Content"]
- )
+ raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
else:
- raise NotImplementedError(
- "Unsupported message type: Type:{} MsgType:{}".format(
- itchat_msg["Type"], itchat_msg["MsgType"]
- )
- )
+ raise NotImplementedError("Unsupported message type: Type:{} MsgType:{}".format(itchat_msg["Type"], itchat_msg["MsgType"]))
self.from_user_id = itchat_msg["FromUserName"]
self.to_user_id = itchat_msg["ToUserName"]
diff --git a/channel/wechat/wechaty_channel.py b/channel/wechat/wechaty_channel.py
index 7383a206c..051a9cf10 100644
--- a/channel/wechat/wechaty_channel.py
+++ b/channel/wechat/wechaty_channel.py
@@ -60,13 +60,9 @@ def send(self, reply: Reply, context: Context):
receiver_id = context["receiver"]
loop = asyncio.get_event_loop()
if context["isgroup"]:
- receiver = asyncio.run_coroutine_threadsafe(
- self.bot.Room.find(receiver_id), loop
- ).result()
+ receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id), loop).result()
else:
- receiver = asyncio.run_coroutine_threadsafe(
- self.bot.Contact.find(receiver_id), loop
- ).result()
+ receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id), loop).result()
msg = None
if reply.type == ReplyType.TEXT:
msg = reply.content
@@ -83,9 +79,7 @@ def send(self, reply: Reply, context: Context):
voiceLength = int(any_to_sil(file_path, sil_file))
if voiceLength >= 60000:
voiceLength = 60000
- logger.info(
- "[WX] voice too long, length={}, set to 60s".format(voiceLength)
- )
+ logger.info("[WX] voice too long, length={}, set to 60s".format(voiceLength))
# 发送语音
t = int(time.time())
msg = FileBox.from_file(sil_file, name=str(t) + ".sil")
@@ -98,9 +92,7 @@ def send(self, reply: Reply, context: Context):
os.remove(sil_file)
except Exception as e:
pass
- logger.info(
- "[WX] sendVoice={}, receiver={}".format(reply.content, receiver)
- )
+ logger.info("[WX] sendVoice={}, receiver={}".format(reply.content, receiver))
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content
t = int(time.time())
@@ -111,9 +103,7 @@ def send(self, reply: Reply, context: Context):
image_storage = reply.content
image_storage.seek(0)
t = int(time.time())
- msg = FileBox.from_base64(
- base64.b64encode(image_storage.read()), str(t) + ".png"
- )
+ msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + ".png")
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
logger.info("[WX] sendImage, receiver={}".format(receiver))
diff --git a/channel/wechat/wechaty_message.py b/channel/wechat/wechaty_message.py
index f7d27faf8..cdb41ddf2 100644
--- a/channel/wechat/wechaty_message.py
+++ b/channel/wechat/wechaty_message.py
@@ -45,16 +45,12 @@ async def __init__(self, wechaty_msg: Message):
def func():
loop = asyncio.get_event_loop()
- asyncio.run_coroutine_threadsafe(
- voice_file.to_file(self.content), loop
- ).result()
+ asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content), loop).result()
self._prepare_fn = func
else:
- raise NotImplementedError(
- "Unsupported message type: {}".format(wechaty_msg.type())
- )
+ raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))
from_contact = wechaty_msg.talker() # 获取消息的发送者
self.from_user_id = from_contact.contact_id
@@ -73,9 +69,7 @@ def func():
self.to_user_id = to_contact.contact_id
self.to_user_nickname = to_contact.name
- if (
- self.is_group or wechaty_msg.is_self()
- ): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
+ if self.is_group or wechaty_msg.is_self(): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
self.other_user_id = self.to_user_id
self.other_user_nickname = self.to_user_nickname
else:
@@ -86,7 +80,7 @@ def func():
self.is_at = await wechaty_msg.mention_self()
if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容
name = wechaty_msg.wechaty.user_self().name
- pattern = f"@{name}(\u2005|\u0020)"
+ pattern = f"@{re.escape(name)}(\u2005|\u0020)"
if re.search(pattern, self.content):
logger.debug(f"wechaty message {self.msg_id} include at")
self.is_at = True
diff --git a/channel/wechatmp/README.md b/channel/wechatmp/README.md
index 219d27675..98ff769c9 100644
--- a/channel/wechatmp/README.md
+++ b/channel/wechatmp/README.md
@@ -1,21 +1,24 @@
# 微信公众号channel
鉴于个人微信号在服务器上通过itchat登录有封号风险,这里新增了微信公众号channel,提供无风险的服务。
-目前支持订阅号和服务号两种类型的公众号。个人主体的微信订阅号由于无法通过微信认证,接口存在限制,目前仅支持最基本的文本交互和语音输入。通过微信认证的订阅号或者服务号可以回复图片和语音。
+
+目前支持订阅号和服务号两种类型的公众号,它们都支持文本交互,语音和图片输入。其中个人主体的微信订阅号由于无法通过微信认证,存在回复时间限制,每天的图片和声音回复次数也有限制。
+
## 使用方法(订阅号,服务号类似)
在开始部署前,你需要一个拥有公网IP的服务器,以提供微信服务器和我们自己服务器的连接。或者你需要进行内网穿透,否则微信服务器无法将消息发送给我们的服务器。
-此外,需要在我们的服务器上安装python的web框架web.py。
+此外,需要在我们的服务器上安装python的web框架web.py,以及微信开发库wechatpy。
以ubuntu为例(在ubuntu 22.04上测试):
```
pip3 install web.py
+pip3 install wechatpy
```
然后在[微信公众平台](https://mp.weixin.qq.com)注册一个自己的公众号,类型选择订阅号,主体为个人即可。
-然后根据[接入指南](https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/Access_Overview.html)的说明,在[微信公众平台](https://mp.weixin.qq.com)的“设置与开发”-“基本配置”-“服务器配置”中填写服务器地址`URL`和令牌`Token`。这里的`URL`是`example.com/wx`的形式,不可以使用IP,`Token`是你自己编的一个特定的令牌。消息加解密方式目前选择的是明文模式。
+然后根据[接入指南](https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/Access_Overview.html)的说明,在[微信公众平台](https://mp.weixin.qq.com)的“设置与开发”-“基本配置”-“服务器配置”中填写服务器地址`URL`和令牌`Token`。这里的`URL`是`example.com/wx`的形式,不可以使用IP,`Token`是你自己编的一个特定的令牌。消息加解密方式如果选择了需要加密的模式,需要在配置中填写`wechatmp_aes_key`。
相关的服务器验证代码已经写好,你不需要再添加任何代码。你只需要在本项目根目录的`config.json`中添加
```
@@ -24,6 +27,7 @@ pip3 install web.py
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
"wechatmp_app_id": "xxxx", # 微信公众平台的appID
"wechatmp_app_secret": "xxxx", # 微信公众平台的appsecret
+"wechatmp_aes_key": "", # 微信公众平台的EncodingAESKey,加密模式需要
"single_chat_prefix": [""], # 推荐设置,任意对话都可以触发回复,不添加前缀
"single_chat_reply_prefix": "", # 推荐设置,回复不设置前缀
"plugin_trigger_prefix": "&", # 推荐设置,在手机微信客户端中,$%^等符号与中文连在一起时会自动显示一段较大的间隔,用户体验不好。请不要使用管理员指令前缀"#",这会造成未知问题。
@@ -40,12 +44,14 @@ sudo iptables-save > /etc/iptables/rules.v4
程序启动并监听端口后,在刚才的“服务器配置”中点击`提交`即可验证你的服务器。
随后在[微信公众平台](https://mp.weixin.qq.com)启用服务器,关闭手动填写规则的自动回复,即可实现ChatGPT的自动回复。
-如果在启用后如果遇到如下报错:
+之后需要在公众号开发信息下将本机IP加入到IP白名单。
+
+不然在启用后,发送语音、图片等消息可能会遇到如下报错:
```
'errcode': 40164, 'errmsg': 'invalid ip xx.xx.xx.xx not in whitelist rid
```
-需要在公众号开发信息下将IP加入到IP白名单。
+
## 个人微信公众号的限制
由于个人微信公众号不能通过微信认证,所以没有客服接口,因此公众号无法主动发出消息,只能被动回复。而微信官方对被动回复有5秒的时间限制,最多重试2次,因此最多只有15秒的自动回复时间窗口。因此如果问题比较复杂或者我们的服务器比较忙,ChatGPT的回答就没办法及时回复给用户。为了解决这个问题,这里做了回答缓存,它需要你在回复超时后,再次主动发送任意文字(例如1)来尝试拿到回答缓存。为了优化使用体验,目前设置了两分钟(120秒)的timeout,用户在至多两分钟后即可得到查询到回复或者错误原因。
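
A minimal sketch of this cache-and-poll idea (names, timings, and the threading helper below are illustrative only; the actual logic lives in channel/wechatmp/passive_reply.py further down in this diff):

```python
import threading
import time

cache = {}       # user_id -> finished answer
running = set()  # user_ids whose answers are still being generated

def start_async_generation(user_id, text):
    # Stand-in for the real ChatGPT call, which may take longer than the 15s window.
    def work():
        time.sleep(20)  # pretend the model is slow
        cache[user_id] = "answer to: " + text
        running.discard(user_id)
    threading.Thread(target=work, daemon=True).start()

def on_wechat_request(user_id, text):
    if user_id in cache:            # an earlier answer is waiting: hand it over
        return cache.pop(user_id)
    if user_id not in running:      # new question: kick off generation in the background
        running.add(user_id)
        start_async_generation(user_id, text)
    deadline = time.time() + 4      # stay inside the 5-second passive-reply limit
    while time.time() < deadline:
        if user_id in cache:
            return cache.pop(user_id)
        time.sleep(0.1)
    return "【正在思考中,回复任意文字尝试获取回复】"  # the user re-sends any text later to fetch the cached answer
```

Each of the official server's retries (up to three, about five seconds apart) runs the handler again, so the user either gets the cached answer within the window or falls back to re-sending a message to fetch it.
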
@@ -91,7 +97,7 @@ python3 -m pip install pyttsx3
## TODO
- [x] 语音输入
- - [ ] 图片输入
+ - [x] 图片输入
- [x] 使用临时素材接口提供认证公众号的图片和语音回复
- [x] 使用永久素材接口提供未认证公众号的图片和语音回复
- [ ] 高并发支持
diff --git a/channel/wechatmp/active_reply.py b/channel/wechatmp/active_reply.py
index d8a8ddee1..37356959b 100644
--- a/channel/wechatmp/active_reply.py
+++ b/channel/wechatmp/active_reply.py
@@ -1,13 +1,16 @@
import time
import web
+from wechatpy import parse_message
+from wechatpy.replies import create_reply
+
-from channel.wechatmp.wechatmp_message import parse_xml
-from channel.wechatmp.passive_reply_message import TextMsg
from bridge.context import *
-from bridge.reply import ReplyType
+from bridge.reply import *
from channel.wechatmp.common import *
from channel.wechatmp.wechatmp_channel import WechatMPChannel
+from channel.wechatmp.wechatmp_message import WeChatMPMessage
+
from common.log import logger
from config import conf
@@ -19,18 +22,25 @@ def GET(self):
def POST(self):
# Make sure to return the instance that first created, @singleton will do that.
- channel = WechatMPChannel()
try:
- webData = web.data()
- # logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
- wechatmp_msg = parse_xml(webData)
- if (
- wechatmp_msg.msg_type == "text"
- or wechatmp_msg.msg_type == "voice"
- # or wechatmp_msg.msg_type == "image"
- ):
+ args = web.input()
+ verify_server(args)
+ channel = WechatMPChannel()
+ message = web.data()
+ encrypt_func = lambda x: x
+ if args.get("encrypt_type") == "aes":
+ logger.debug("[wechatmp] Receive encrypted post data:\n" + message.decode("utf-8"))
+ if not channel.crypto:
+ raise Exception("Crypto not initialized, Please set wechatmp_aes_key in config.json")
+ message = channel.crypto.decrypt_message(message, args.msg_signature, args.timestamp, args.nonce)
+ encrypt_func = lambda x: channel.crypto.encrypt_message(x, args.nonce, args.timestamp)
+ else:
+ logger.debug("[wechatmp] Receive post data:\n" + message.decode("utf-8"))
+ msg = parse_message(message)
+ if msg.type in ["text", "voice", "image"]:
+ wechatmp_msg = WeChatMPMessage(msg, client=channel.client)
from_user = wechatmp_msg.from_user_id
- message = wechatmp_msg.content
+ content = wechatmp_msg.content
message_id = wechatmp_msg.msg_id
logger.info(
@@ -39,38 +49,29 @@ def POST(self):
web.ctx.env.get("REMOTE_PORT"),
from_user,
message_id,
- message,
+ content,
)
)
- if (wechatmp_msg.msg_type == "voice" and conf().get("voice_reply_voice") == True):
- rtype = ReplyType.VOICE
+ if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
+ context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
else:
- rtype = None
- context = channel._compose_context(
- ContextType.TEXT, message, isgroup=False, desire_rtype=rtype, msg=wechatmp_msg
- )
+ context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
if context:
# set private openai_api_key
# if from_user is not changed in itchat, this can be placed at chat_channel
user_data = conf().get_user_data(from_user)
- context["openai_api_key"] = user_data.get(
- "openai_api_key"
- ) # None or user openai_api_key
+ context["openai_api_key"] = user_data.get("openai_api_key") # None or user openai_api_key
channel.produce(context)
# The reply will be sent by channel.send() in another thread
return "success"
-
- elif wechatmp_msg.msg_type == "event":
- logger.info(
- "[wechatmp] Event {} from {}".format(
- wechatmp_msg.Event, wechatmp_msg.from_user_id
- )
- )
- content = subscribe_msg()
- replyMsg = TextMsg(
- wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content
- )
- return replyMsg.send()
+ elif msg.type == "event":
+ logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
+ if msg.event in ["subscribe", "subscribe_scan"]:
+ reply_text = subscribe_msg()
+ replyPost = create_reply(reply_text, msg)
+ return encrypt_func(replyPost.render())
+ else:
+ return "success"
else:
logger.info("暂且不处理")
return "success"
diff --git a/channel/wechatmp/common.py b/channel/wechatmp/common.py
index 5efccfce1..b6f206c5a 100644
--- a/channel/wechatmp/common.py
+++ b/channel/wechatmp/common.py
@@ -1,6 +1,10 @@
-import hashlib
import textwrap
+import web
+from wechatpy.crypto import WeChatCrypto
+from wechatpy.exceptions import InvalidSignatureException
+from wechatpy.utils import check_signature
+
from config import conf
MAX_UTF8_LEN = 2048
@@ -12,38 +16,28 @@ class WeChatAPIException(Exception):
def verify_server(data):
try:
- if len(data) == 0:
- return "None"
signature = data.signature
timestamp = data.timestamp
nonce = data.nonce
- echostr = data.echostr
+ echostr = data.get("echostr", None)
token = conf().get("wechatmp_token") # 请按照公众平台官网\基本配置中信息填写
-
- data_list = [token, timestamp, nonce]
- data_list.sort()
- sha1 = hashlib.sha1()
- # map(sha1.update, data_list) #python2
- sha1.update("".join(data_list).encode("utf-8"))
- hashcode = sha1.hexdigest()
- print("handle/GET func: hashcode, signature: ", hashcode, signature)
- if hashcode == signature:
- return echostr
- else:
- return ""
- except Exception as Argument:
- return Argument
+ check_signature(token, signature, timestamp, nonce)
+ return echostr
+ except InvalidSignatureException:
+ raise web.Forbidden("Invalid signature")
+ except Exception as e:
+ raise web.Forbidden(str(e))
def subscribe_msg():
- trigger_prefix = conf().get("single_chat_prefix", [""])
+ trigger_prefix = conf().get("single_chat_prefix", [""])[0]
msg = textwrap.dedent(
f"""\
感谢您的关注!
这里是ChatGPT,可以自由对话。
资源有限,回复较慢,请勿着急。
支持语音对话。
- 暂时不支持图片输入。
+ 支持图片输入。
支持图片输出,画字开头的消息将按要求创作图片。
支持tool、角色扮演和文字冒险等丰富的插件。
输入'{trigger_prefix}#帮助' 查看详细指令。"""
@@ -59,7 +53,7 @@ def split_string_by_utf8_length(string, max_length, max_split=0):
if max_split > 0 and len(result) >= max_split:
result.append(encoded[start:].decode("utf-8"))
break
- end = start + max_length
+ end = min(start + max_length, len(encoded))
# 如果当前字节不是 UTF-8 编码的开始字节,则向前查找直到找到开始字节为止
while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000:
end -= 1
diff --git a/channel/wechatmp/passive_reply.py b/channel/wechatmp/passive_reply.py
index eca94ba36..6f3fb0a72 100644
--- a/channel/wechatmp/passive_reply.py
+++ b/channel/wechatmp/passive_reply.py
@@ -1,14 +1,16 @@
-import time
+
import asyncio
+import time
import web
+from wechatpy import parse_message
+from wechatpy.replies import ImageReply, VoiceReply, create_reply
-from channel.wechatmp.wechatmp_message import parse_xml
-from channel.wechatmp.passive_reply_message import TextMsg, VoiceMsg, ImageMsg
from bridge.context import *
-from bridge.reply import ReplyType
+from bridge.reply import *
from channel.wechatmp.common import *
from channel.wechatmp.wechatmp_channel import WechatMPChannel
+from channel.wechatmp.wechatmp_message import WeChatMPMessage
from common.log import logger
from config import conf
@@ -20,39 +22,44 @@ def GET(self):
def POST(self):
try:
+ args = web.input()
+ verify_server(args)
request_time = time.time()
channel = WechatMPChannel()
- webData = web.data()
- logger.debug("[wechatmp] Receive post data:\n" + webData.decode("utf-8"))
- wechatmp_msg = parse_xml(webData)
- if wechatmp_msg.msg_type == "text" or wechatmp_msg.msg_type == "voice":
+ message = web.data()
+ encrypt_func = lambda x: x
+ if args.get("encrypt_type") == "aes":
+ logger.debug("[wechatmp] Receive encrypted post data:\n" + message.decode("utf-8"))
+ if not channel.crypto:
+ raise Exception("Crypto not initialized, Please set wechatmp_aes_key in config.json")
+ message = channel.crypto.decrypt_message(message, args.msg_signature, args.timestamp, args.nonce)
+ encrypt_func = lambda x: channel.crypto.encrypt_message(x, args.nonce, args.timestamp)
+ else:
+ logger.debug("[wechatmp] Receive post data:\n" + message.decode("utf-8"))
+ msg = parse_message(message)
+ if msg.type in ["text", "voice", "image"]:
+ wechatmp_msg = WeChatMPMessage(msg, client=channel.client)
from_user = wechatmp_msg.from_user_id
- to_user = wechatmp_msg.to_user_id
- message = wechatmp_msg.content
+ content = wechatmp_msg.content
message_id = wechatmp_msg.msg_id
supported = True
- if "【收到不支持的消息类型,暂无法显示】" in message:
+ if "【收到不支持的消息类型,暂无法显示】" in content:
supported = False # not supported, used to refresh
# New request
if (
from_user not in channel.cache_dict
and from_user not in channel.running
- or message.startswith("#")
- and message_id not in channel.request_cnt # insert the godcmd
+ or content.startswith("#")
+ and message_id not in channel.request_cnt # insert the godcmd
):
# The first query begin
- if (wechatmp_msg.msg_type == "voice" and conf().get("voice_reply_voice") == True):
- rtype = ReplyType.VOICE
+ if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
+ context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
else:
- rtype = None
- context = channel._compose_context(
- ContextType.TEXT, message, isgroup=False, desire_rtype=rtype, msg=wechatmp_msg
- )
- logger.debug(
- "[wechatmp] context: {} {}".format(context, wechatmp_msg)
- )
+ context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
+ logger.debug("[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported))
if supported and context:
# set private openai_api_key
@@ -62,43 +69,38 @@ def POST(self):
channel.running.add(from_user)
channel.produce(context)
else:
- trigger_prefix = conf().get("single_chat_prefix", [""])
+ trigger_prefix = conf().get("single_chat_prefix", [""])[0]
if trigger_prefix or not supported:
if trigger_prefix:
- content = textwrap.dedent(
+ reply_text = textwrap.dedent(
f"""\
请输入'{trigger_prefix}'接你想说的话跟我说话。
例如:
{trigger_prefix}你好,很高兴见到你。"""
)
else:
- content = textwrap.dedent(
+ reply_text = textwrap.dedent(
"""\
你好,很高兴见到你。
请跟我说话吧。"""
)
else:
logger.error(f"[wechatmp] unknown error")
- content = textwrap.dedent(
+ reply_text = textwrap.dedent(
"""\
未知错误,请稍后再试"""
)
- replyPost = TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content).send()
- return replyPost
+ replyPost = create_reply(reply_text, msg)
+ return encrypt_func(replyPost.render())
# Wechat official server will request 3 times (5 seconds each), with the same message_id.
# Because the interval is 5 seconds, it is assumed here that there are no multithreading problems.
request_cnt = channel.request_cnt.get(message_id, 0) + 1
channel.request_cnt[message_id] = request_cnt
logger.info(
- "[wechatmp] Request {} from {} {}\n{}\n{}:{}".format(
- request_cnt,
- from_user,
- message_id,
- message,
- web.ctx.env.get("REMOTE_ADDR"),
- web.ctx.env.get("REMOTE_PORT"),
+ "[wechatmp] Request {} from {} {} {}:{}\n{}".format(
+ request_cnt, from_user, message_id, web.ctx.env.get("REMOTE_ADDR"), web.ctx.env.get("REMOTE_PORT"), content
)
)
@@ -118,76 +120,91 @@ def POST(self):
time.sleep(2)
# and do nothing, waiting for the next request
return "success"
- else: # request_cnt == 3:
+ else: # request_cnt == 3:
# return timeout message
reply_text = "【正在思考中,回复任意文字尝试获取回复】"
- replyPost = TextMsg(from_user, to_user, reply_text).send()
- return replyPost
+ replyPost = create_reply(reply_text, msg)
+ return encrypt_func(replyPost.render())
# reply is ready
channel.request_cnt.pop(message_id)
# no return because of bandwords or other reasons
- if (
- from_user not in channel.cache_dict
- and from_user not in channel.running
- ):
+ if from_user not in channel.cache_dict and from_user not in channel.running:
return "success"
# Only one request can access to the cached data
try:
- (reply_type, content) = channel.cache_dict.pop(from_user)
+ (reply_type, reply_content) = channel.cache_dict.pop(from_user)
except KeyError:
return "success"
- if (reply_type == "text"):
- if len(content.encode("utf8")) <= MAX_UTF8_LEN:
- reply_text = content
+ if reply_type == "text":
+ if len(reply_content.encode("utf8")) <= MAX_UTF8_LEN:
+ reply_text = reply_content
else:
continue_text = "\n【未完待续,回复任意文字以继续】"
splits = split_string_by_utf8_length(
- content,
+ reply_content,
MAX_UTF8_LEN - len(continue_text.encode("utf-8")),
max_split=1,
)
reply_text = splits[0] + continue_text
channel.cache_dict[from_user] = ("text", splits[1])
-
+
logger.info(
"[wechatmp] Request {} do send to {} {}: {}\n{}".format(
request_cnt,
from_user,
message_id,
- message,
+ content,
reply_text,
)
)
- replyPost = TextMsg(from_user, to_user, reply_text).send()
- return replyPost
+ replyPost = create_reply(reply_text, msg)
+ return encrypt_func(replyPost.render())
- elif (reply_type == "voice"):
- media_id = content
+ elif reply_type == "voice":
+ media_id = reply_content
asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
- replyPost = VoiceMsg(from_user, to_user, media_id).send()
- return replyPost
+ logger.info(
+ "[wechatmp] Request {} do send to {} {}: {} voice media_id {}".format(
+ request_cnt,
+ from_user,
+ message_id,
+ content,
+ media_id,
+ )
+ )
+ replyPost = VoiceReply(message=msg)
+ replyPost.media_id = media_id
+ return encrypt_func(replyPost.render())
- elif (reply_type == "image"):
- media_id = content
+ elif reply_type == "image":
+ media_id = reply_content
asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
- replyPost = ImageMsg(from_user, to_user, media_id).send()
- return replyPost
-
- elif wechatmp_msg.msg_type == "event":
- logger.info(
- "[wechatmp] Event {} from {}".format(
- wechatmp_msg.content, wechatmp_msg.from_user_id
+ logger.info(
+ "[wechatmp] Request {} do send to {} {}: {} image media_id {}".format(
+ request_cnt,
+ from_user,
+ message_id,
+ content,
+ media_id,
+ )
)
- )
- content = subscribe_msg()
- replyMsg = TextMsg(
- wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content
- )
- return replyMsg.send()
+ replyPost = ImageReply(message=msg)
+ replyPost.media_id = media_id
+ return encrypt_func(replyPost.render())
+
+ elif msg.type == "event":
+ logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
+ if msg.event in ["subscribe", "subscribe_scan"]:
+ reply_text = subscribe_msg()
+ replyPost = create_reply(reply_text, msg)
+ return encrypt_func(replyPost.render())
+ else:
+ return "success"
+
else:
logger.info("暂且不处理")
return "success"
diff --git a/channel/wechatmp/passive_reply_message.py b/channel/wechatmp/passive_reply_message.py
deleted file mode 100644
index ef58d7093..000000000
--- a/channel/wechatmp/passive_reply_message.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# -*- coding: utf-8 -*-#
-# filename: reply.py
-import time
-
-
-class Msg(object):
- def __init__(self):
- pass
-
- def send(self):
- return "success"
-
-
-class TextMsg(Msg):
- def __init__(self, toUserName, fromUserName, content):
- self.__dict = dict()
- self.__dict["ToUserName"] = toUserName
- self.__dict["FromUserName"] = fromUserName
- self.__dict["CreateTime"] = int(time.time())
- self.__dict["Content"] = content
-
- def send(self):
- XmlForm = """
- <xml>
- <ToUserName><![CDATA[{ToUserName}]]></ToUserName>
- <FromUserName><![CDATA[{FromUserName}]]></FromUserName>
- <CreateTime>{CreateTime}</CreateTime>
- <MsgType><![CDATA[text]]></MsgType>
- <Content><![CDATA[{Content}]]></Content>
- </xml>
- """
- return XmlForm.format(**self.__dict)
-
-
-class VoiceMsg(Msg):
- def __init__(self, toUserName, fromUserName, mediaId):
- self.__dict = dict()
- self.__dict["ToUserName"] = toUserName
- self.__dict["FromUserName"] = fromUserName
- self.__dict["CreateTime"] = int(time.time())
- self.__dict["MediaId"] = mediaId
-
- def send(self):
- XmlForm = """
- <xml>
- <ToUserName><![CDATA[{ToUserName}]]></ToUserName>
- <FromUserName><![CDATA[{FromUserName}]]></FromUserName>
- <CreateTime>{CreateTime}</CreateTime>
- <MsgType><![CDATA[voice]]></MsgType>
- <Voice>
- <MediaId><![CDATA[{MediaId}]]></MediaId>
- </Voice>
- </xml>
- """
- return XmlForm.format(**self.__dict)
-
-
-class ImageMsg(Msg):
- def __init__(self, toUserName, fromUserName, mediaId):
- self.__dict = dict()
- self.__dict["ToUserName"] = toUserName
- self.__dict["FromUserName"] = fromUserName
- self.__dict["CreateTime"] = int(time.time())
- self.__dict["MediaId"] = mediaId
-
- def send(self):
- XmlForm = """
- <xml>
- <ToUserName><![CDATA[{ToUserName}]]></ToUserName>
- <FromUserName><![CDATA[{FromUserName}]]></FromUserName>
- <CreateTime>{CreateTime}</CreateTime>
- <MsgType><![CDATA[image]]></MsgType>
- <Image>
- <MediaId><![CDATA[{MediaId}]]></MediaId>
- </Image>
- </xml>
- """
- return XmlForm.format(**self.__dict)
diff --git a/channel/wechatmp/wechatmp_channel.py b/channel/wechatmp/wechatmp_channel.py
index 9780048b9..aa1fc74d7 100644
--- a/channel/wechatmp/wechatmp_channel.py
+++ b/channel/wechatmp/wechatmp_channel.py
@@ -1,22 +1,26 @@
# -*- coding: utf-8 -*-
+import asyncio
+import imghdr
import io
import os
+import threading
import time
-import imghdr
+
import requests
+import web
+from wechatpy.crypto import WeChatCrypto
+from wechatpy.exceptions import WeChatClientException
+
from bridge.context import *
from bridge.reply import *
from channel.chat_channel import ChatChannel
-from channel.wechatmp.wechatmp_client import WechatMPClient
from channel.wechatmp.common import *
+from channel.wechatmp.wechatmp_client import WechatMPClient
from common.log import logger
from common.singleton import singleton
from config import conf
+from voice.audio_convert import any_to_mp3
-import asyncio
-from threading import Thread
-
-import web
# If using SSL, uncomment the following lines, and modify the certificate path.
# from cheroot.server import HTTPServer
# from cheroot.ssl.builtin import BuiltinSSLAdapter
@@ -31,7 +35,14 @@ def __init__(self, passive_reply=True):
super().__init__()
self.passive_reply = passive_reply
self.NOT_SUPPORT_REPLYTYPE = []
- self.client = WechatMPClient()
+ appid = conf().get("wechatmp_app_id")
+ secret = conf().get("wechatmp_app_secret")
+ token = conf().get("wechatmp_token")
+ aes_key = conf().get("wechatmp_aes_key")
+ self.client = WechatMPClient(appid, secret)
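+ # the crypto helper is only created when an EncodingAESKey is configured, i.e. the account uses encrypted (safe) mode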
+ self.crypto = None
+ if aes_key:
+ self.crypto = WeChatCrypto(token, aes_key, appid)
if self.passive_reply:
# Cache the reply to the user's first message
self.cache_dict = dict()
@@ -41,11 +52,10 @@ def __init__(self, passive_reply=True):
self.request_cnt = dict()
# The permanent media need to be deleted to avoid media number limit
self.delete_media_loop = asyncio.new_event_loop()
- t = Thread(target=self.start_loop, args=(self.delete_media_loop,))
+ t = threading.Thread(target=self.start_loop, args=(self.delete_media_loop,))
t.setDaemon(True)
t.start()
-
def startup(self):
if self.passive_reply:
urls = ("/wx", "channel.wechatmp.passive_reply.Query")
@@ -62,7 +72,7 @@ def start_loop(self, loop):
async def delete_media(self, media_id):
logger.debug("[wechatmp] permanent media {} will be deleted in 10s".format(media_id))
await asyncio.sleep(10)
- self.client.delete_permanent_media(media_id)
+ self.client.material.delete(media_id)
logger.info("[wechatmp] permanent media {} has been deleted".format(media_id))
def send(self, reply: Reply, context: Context):
@@ -70,97 +80,134 @@ def send(self, reply: Reply, context: Context):
if self.passive_reply:
if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
reply_text = reply.content
- logger.info("[wechatmp] reply to {} cached:\n{}".format(receiver, reply_text))
+ logger.info("[wechatmp] text cached, receiver {}\n{}".format(receiver, reply_text))
self.cache_dict[receiver] = ("text", reply_text)
elif reply.type == ReplyType.VOICE:
- voice_file_path = reply.content
- logger.info("[wechatmp] voice file path {}".format(voice_file_path))
- with open(voice_file_path, 'rb') as f:
- filename = receiver + "-" + context["msg"].msg_id + ".mp3"
- media_id = self.client.upload_permanent_media("voice", (filename, f, "audio/mpeg"))
- # 根据文件大小估计一个微信自动审核的时间,审核结束前返回将会导致语音无法播放,这个估计有待验证
- f_size = os.fstat(f.fileno()).st_size
- print(f_size)
- time.sleep(1.0 + 2 * f_size / 1024 / 1024)
- logger.info("[wechatmp] voice reply to {} uploaded: {}".format(receiver, media_id))
- self.cache_dict[receiver] = ("voice", media_id)
+ try:
+ voice_file_path = reply.content
+ with open(voice_file_path, "rb") as f:
+ # support: <2M, <60s, mp3/wma/wav/amr
+ response = self.client.material.add("voice", f)
+ logger.debug("[wechatmp] upload voice response: {}".format(response))
+ # 根据文件大小估计一个微信自动审核的时间,审核结束前返回将会导致语音无法播放,这个估计有待验证
+ f_size = os.fstat(f.fileno()).st_size
+ time.sleep(1.0 + 2 * f_size / 1024 / 1024)
+ # todo check media_id
+ except WeChatClientException as e:
+ logger.error("[wechatmp] upload voice failed: {}".format(e))
+ return
+ media_id = response["media_id"]
+ logger.info("[wechatmp] voice uploaded, receiver {}, media_id {}".format(receiver, media_id))
+ self.cache_dict[receiver] = ("voice", media_id)
+
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content
pic_res = requests.get(img_url, stream=True)
- print(pic_res.headers)
image_storage = io.BytesIO()
for block in pic_res.iter_content(1024):
image_storage.write(block)
image_storage.seek(0)
image_type = imghdr.what(image_storage)
- filename = receiver + "-" + context["msg"].msg_id + "." + image_type
+ filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
content_type = "image/" + image_type
- media_id = self.client.upload_permanent_media("image", (filename, image_storage, content_type))
- logger.info("[wechatmp] image reply to {} uploaded: {}".format(receiver, media_id))
+ try:
+ response = self.client.material.add("image", (filename, image_storage, content_type))
+ logger.debug("[wechatmp] upload image response: {}".format(response))
+ except WeChatClientException as e:
+ logger.error("[wechatmp] upload image failed: {}".format(e))
+ return
+ media_id = response["media_id"]
+ logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id))
self.cache_dict[receiver] = ("image", media_id)
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
image_storage = reply.content
image_storage.seek(0)
image_type = imghdr.what(image_storage)
- filename = receiver + "-" + context["msg"].msg_id + "." + image_type
+ filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
content_type = "image/" + image_type
- media_id = self.client.upload_permanent_media("image", (filename, image_storage, content_type))
- logger.info("[wechatmp] image reply to {} uploaded: {}".format(receiver, media_id))
+ try:
+ response = self.client.material.add("image", (filename, image_storage, content_type))
+ logger.debug("[wechatmp] upload image response: {}".format(response))
+ except WeChatClientException as e:
+ logger.error("[wechatmp] upload image failed: {}".format(e))
+ return
+ media_id = response["media_id"]
+ logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id))
self.cache_dict[receiver] = ("image", media_id)
else:
if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
reply_text = reply.content
- self.client.send_text(receiver, reply_text)
- logger.info("[wechatmp] Do send to {}: {}".format(receiver, reply_text))
+ texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
+ if len(texts) > 1:
+ logger.info("[wechatmp] text too long, split into {} parts".format(len(texts)))
+ for text in texts:
+ self.client.message.send_text(receiver, text)
+ logger.info("[wechatmp] Do send text to {}: {}".format(receiver, reply_text))
elif reply.type == ReplyType.VOICE:
- voice_file_path = reply.content
- logger.info("[wechatmp] voice file path {}".format(voice_file_path))
- with open(voice_file_path, 'rb') as f:
- filename = receiver + "-" + context["msg"].msg_id + ".mp3"
- media_id = self.client.upload_media("voice", (filename, f, "audio/mpeg"))
- self.client.send_voice(receiver, media_id)
- logger.info("[wechatmp] Do send voice to {}".format(receiver))
+ try:
+ file_path = reply.content
+ file_name = os.path.basename(file_path)
+ file_type = os.path.splitext(file_name)[1]
+ if file_type == ".mp3":
+ file_type = "audio/mpeg"
+ elif file_type == ".amr":
+ file_type = "audio/amr"
+ else:
+ mp3_file = os.path.splitext(file_path)[0] + ".mp3"
+ any_to_mp3(file_path, mp3_file)
+ file_path = mp3_file
+ file_name = os.path.basename(file_path)
+ file_type = "audio/mpeg"
+ logger.info("[wechatmp] file_name: {}, file_type: {} ".format(file_name, file_type))
+ # support: <2M, <60s, AMR\MP3
+ response = self.client.media.upload("voice", (file_name, open(file_path, "rb"), file_type))
+ logger.debug("[wechatmp] upload voice response: {}".format(response))
+ except WeChatClientException as e:
+ logger.error("[wechatmp] upload voice failed: {}".format(e))
+ return
+ self.client.message.send_voice(receiver, response["media_id"])
+ logger.info("[wechatmp] Do send voice to {}".format(receiver))
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content
pic_res = requests.get(img_url, stream=True)
- print(pic_res.headers)
image_storage = io.BytesIO()
for block in pic_res.iter_content(1024):
image_storage.write(block)
image_storage.seek(0)
image_type = imghdr.what(image_storage)
- filename = receiver + "-" + context["msg"].msg_id + "." + image_type
+ filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
content_type = "image/" + image_type
- # content_type = pic_res.headers.get('content-type')
- media_id = self.client.upload_media("image", (filename, image_storage, content_type))
- self.client.send_image(receiver, media_id)
- logger.info("[wechatmp] sendImage url={}, receiver={}".format(img_url, receiver))
+ try:
+ response = self.client.media.upload("image", (filename, image_storage, content_type))
+ logger.debug("[wechatmp] upload image response: {}".format(response))
+ except WeChatClientException as e:
+ logger.error("[wechatmp] upload image failed: {}".format(e))
+ return
+ self.client.message.send_image(receiver, response["media_id"])
+ logger.info("[wechatmp] Do send image to {}".format(receiver))
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
image_storage = reply.content
image_storage.seek(0)
image_type = imghdr.what(image_storage)
- filename = receiver + "-" + context["msg"].msg_id + "." + image_type
+ filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
content_type = "image/" + image_type
- media_id = self.client.upload_media("image", (filename, image_storage, content_type))
- self.client.send_image(receiver, media_id)
- logger.info("[wechatmp] sendImage, receiver={}".format(receiver))
+ try:
+ response = self.client.media.upload("image", (filename, image_storage, content_type))
+ logger.debug("[wechatmp] upload image response: {}".format(response))
+ except WeChatClientException as e:
+ logger.error("[wechatmp] upload image failed: {}".format(e))
+ return
+ self.client.message.send_image(receiver, response["media_id"])
+ logger.info("[wechatmp] Do send image to {}".format(receiver))
return
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数
- logger.debug(
- "[wechatmp] Success to generate reply, msgId={}".format(
- context["msg"].msg_id
- )
- )
+ logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context["msg"].msg_id))
if self.passive_reply:
self.running.remove(session_id)
def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数
- logger.exception(
- "[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(
- context["msg"].msg_id, exception
- )
- )
+ logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context["msg"].msg_id, exception))
if self.passive_reply:
assert session_id not in self.cache_dict
self.running.remove(session_id)
diff --git a/channel/wechatmp/wechatmp_client.py b/channel/wechatmp/wechatmp_client.py
index 96ebddb74..19dca3219 100644
--- a/channel/wechatmp/wechatmp_client.py
+++ b/channel/wechatmp/wechatmp_client.py
@@ -1,180 +1,49 @@
-import time
-import json
-import requests
import threading
-from channel.wechatmp.common import *
-from common.log import logger
-from config import conf
-
-
-class WechatMPClient:
- def __init__(self):
- self.app_id = conf().get("wechatmp_app_id")
- self.app_secret = conf().get("wechatmp_app_secret")
- self.access_token = None
- self.access_token_expires_time = 0
- self.access_token_lock = threading.Lock()
- self.get_access_token()
-
-
- def wechatmp_request(self, method, url, **kwargs):
- r = requests.request(method=method, url=url, **kwargs)
- r.raise_for_status()
- r.encoding = "utf-8"
- ret = r.json()
- if "errcode" in ret and ret["errcode"] != 0:
- if ret["errcode"] == 45009:
- self.clear_quota_v2()
- raise WeChatAPIException("{}".format(ret))
- return ret
-
- def get_access_token(self):
- # return the access_token
- if self.access_token:
- if self.access_token_expires_time - time.time() > 60:
- return self.access_token
-
- # Get new access_token
- # Do not request access_token in parallel! Only the last obtained is valid.
- if self.access_token_lock.acquire(blocking=False):
- # Wait for other threads that have previously obtained access_token to complete the request
- # This happens every 2 hours, so it doesn't affect the experience very much
- time.sleep(1)
- self.access_token = None
- url = "https://api.weixin.qq.com/cgi-bin/token"
- params = {
- "grant_type": "client_credential",
- "appid": self.app_id,
- "secret": self.app_secret,
- }
- ret = self.wechatmp_request(method="get", url=url, params=params)
- self.access_token = ret["access_token"]
- self.access_token_expires_time = int(time.time()) + ret["expires_in"]
- logger.info("[wechatmp] access_token: {}".format(self.access_token))
- self.access_token_lock.release()
- else:
- # Wait for token update
- while self.access_token_lock.locked():
- time.sleep(0.1)
- return self.access_token
-
-
- def send_text(self, receiver, reply_text):
- url = "https://api.weixin.qq.com/cgi-bin/message/custom/send"
- params = {"access_token": self.get_access_token()}
- json_data = {
- "touser": receiver,
- "msgtype": "text",
- "text": {"content": reply_text},
- }
- self.wechatmp_request(
- method="post",
- url=url,
- params=params,
- data=json.dumps(json_data, ensure_ascii=False).encode("utf8"),
- )
-
-
- def send_voice(self, receiver, media_id):
- url="https://api.weixin.qq.com/cgi-bin/message/custom/send"
- params = {"access_token": self.get_access_token()}
- json_data = {
- "touser": receiver,
- "msgtype": "voice",
- "voice": {
- "media_id": media_id
- }
- }
- self.wechatmp_request(
- method="post",
- url=url,
- params=params,
- data=json.dumps(json_data, ensure_ascii=False).encode("utf8"),
- )
-
- def send_image(self, receiver, media_id):
- url="https://api.weixin.qq.com/cgi-bin/message/custom/send"
- params = {"access_token": self.get_access_token()}
- json_data = {
- "touser": receiver,
- "msgtype": "image",
- "image": {
- "media_id": media_id
- }
- }
- self.wechatmp_request(
- method="post",
- url=url,
- params=params,
- data=json.dumps(json_data, ensure_ascii=False).encode("utf8"),
- )
-
-
- def upload_media(self, media_type, media_file):
- url="https://api.weixin.qq.com/cgi-bin/media/upload"
- params={
- "access_token": self.get_access_token(),
- "type": media_type
- }
- files={"media": media_file}
- ret = self.wechatmp_request(
- method="post",
- url=url,
- params=params,
- files=files
- )
- logger.debug("[wechatmp] media {} uploaded".format(media_file))
- return ret["media_id"]
+import time
+from wechatpy.client import WeChatClient
+from wechatpy.exceptions import APILimitedException
- def upload_permanent_media(self, media_type, media_file):
- url="https://api.weixin.qq.com/cgi-bin/material/add_material"
- params={
- "access_token": self.get_access_token(),
- "type": media_type
- }
- files={"media": media_file}
- ret = self.wechatmp_request(
- method="post",
- url=url,
- params=params,
- files=files
- )
- logger.debug("[wechatmp] permanent media {} uploaded".format(media_file))
- return ret["media_id"]
+from channel.wechatmp.common import *
+from common.log import logger
- def delete_permanent_media(self, media_id):
- url="https://api.weixin.qq.com/cgi-bin/material/del_material"
- params={
- "access_token": self.get_access_token()
- }
- self.wechatmp_request(
- method="post",
- url=url,
- params=params,
- data=json.dumps({"media_id": media_id}, ensure_ascii=False).encode("utf8")
- )
- logger.debug("[wechatmp] permanent media {} deleted".format(media_id))
+class WechatMPClient(WeChatClient):
+ def __init__(self, appid, secret, access_token=None, session=None, timeout=None, auto_retry=True):
+ super(WechatMPClient, self).__init__(appid, secret, access_token, session, timeout, auto_retry)
+ self.fetch_access_token_lock = threading.Lock()
+ self.clear_quota_lock = threading.Lock()
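+ # -1 marks "quota never cleared yet"; used below to rate-limit clear_quota_v2 calls to once per minute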
+ self.last_clear_quota_time = -1
def clear_quota(self):
- url="https://api.weixin.qq.com/cgi-bin/clear_quota"
- params = {
- "access_token": self.get_access_token()
- }
- self.wechatmp_request(
- method="post",
- url=url,
- params=params,
- data={"appid": self.app_id}
- )
- logger.debug("[wechatmp] API quata has been cleard")
+ return self.post("clear_quota", data={"appid": self.appid})
def clear_quota_v2(self):
- url="https://api.weixin.qq.com/cgi-bin/clear_quota/v2"
- self.wechatmp_request(
- method="post",
- url=url,
- data={"appid": self.app_id, "appsecret": self.app_secret}
- )
- logger.debug("[wechatmp] API quata has been cleard")
+ return self.post("clear_quota/v2", params={"appid": self.appid, "appsecret": self.secret})
+
+ def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token
+ with self.fetch_access_token_lock:
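+ # re-check the cached token while holding the lock so only one thread refreshes an expiring access_token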
+ access_token = self.session.get(self.access_token_key)
+ if access_token:
+ if not self.expires_at:
+ return access_token
+ timestamp = time.time()
+ if self.expires_at - timestamp > 60:
+ return access_token
+ return super().fetch_access_token()
+
+ def _request(self, method, url_or_endpoint, **kwargs): # 重载父类方法,遇到API限流时,清除quota后重试
+ try:
+ return super()._request(method, url_or_endpoint, **kwargs)
+ except APILimitedException as e:
+ logger.error("[wechatmp] API quata has been used up. {}".format(e))
+ if self.last_clear_quota_time == -1 or time.time() - self.last_clear_quota_time > 60:
+ with self.clear_quota_lock:
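+ # double-checked locking: re-test the timestamp inside the lock so concurrent threads do not clear the quota twice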
+ if self.last_clear_quota_time == -1 or time.time() - self.last_clear_quota_time > 60:
+ self.last_clear_quota_time = time.time()
+ response = self.clear_quota_v2()
+ logger.debug("[wechatmp] API quata has been cleard, {}".format(response))
+ return super()._request(method, url_or_endpoint, **kwargs)
+ else:
+ logger.error("[wechatmp] last clear quota time is {}, less than 60s, skip clear quota")
+ raise e
diff --git a/channel/wechatmp/wechatmp_message.py b/channel/wechatmp/wechatmp_message.py
index d385897c3..27c9fbb85 100644
--- a/channel/wechatmp/wechatmp_message.py
+++ b/channel/wechatmp/wechatmp_message.py
@@ -1,50 +1,56 @@
# -*- coding: utf-8 -*-#
-# filename: receive.py
-import xml.etree.ElementTree as ET
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.log import logger
-
-
-def parse_xml(web_data):
- if len(web_data) == 0:
- return None
- xmlData = ET.fromstring(web_data)
- return WeChatMPMessage(xmlData)
+from common.tmp_dir import TmpDir
class WeChatMPMessage(ChatMessage):
- def __init__(self, xmlData):
- super().__init__(xmlData)
- self.to_user_id = xmlData.find("ToUserName").text
- self.from_user_id = xmlData.find("FromUserName").text
- self.create_time = xmlData.find("CreateTime").text
- self.msg_type = xmlData.find("MsgType").text
- try:
- self.msg_id = xmlData.find("MsgId").text
- except:
- self.msg_id = self.from_user_id + self.create_time
+ def __init__(self, msg, client=None):
+ super().__init__(msg)
+ self.msg_id = msg.id
+ self.create_time = msg.time
self.is_group = False
- # reply to other_user_id
- self.other_user_id = self.from_user_id
-
- if self.msg_type == "text":
+ if msg.type == "text":
self.ctype = ContextType.TEXT
- self.content = xmlData.find("Content").text
- elif self.msg_type == "voice":
- self.ctype = ContextType.TEXT
- self.content = xmlData.find("Recognition").text # 接收语音识别结果
- # other voice_to_text method not implemented yet
- if self.content == None:
- self.content = "你好"
- elif self.msg_type == "image":
- # not implemented yet
- self.pic_url = xmlData.find("PicUrl").text
- self.media_id = xmlData.find("MediaId").text
- elif self.msg_type == "event":
- self.content = xmlData.find("Event").text
- else: # video, shortvideo, location, link
- # not implemented
- pass
+ self.content = msg.content
+ elif msg.type == "voice":
+ if msg.recognition == None:
+ self.ctype = ContextType.VOICE
+ self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径
+
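+ # lazy download: the media file is fetched from the WeChat server only when the message is prepared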
+ def download_voice():
+ # 如果响应状态码是200,则将响应内容写入本地文件
+ response = client.media.download(msg.media_id)
+ if response.status_code == 200:
+ with open(self.content, "wb") as f:
+ f.write(response.content)
+ else:
+ logger.info(f"[wechatmp] Failed to download voice file, {response.content}")
+
+ self._prepare_fn = download_voice
+ else:
+ self.ctype = ContextType.TEXT
+ self.content = msg.recognition
+ elif msg.type == "image":
+ self.ctype = ContextType.IMAGE
+ self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径
+
+ def download_image():
+ # 如果响应状态码是200,则将响应内容写入本地文件
+ response = client.media.download(msg.media_id)
+ if response.status_code == 200:
+ with open(self.content, "wb") as f:
+ f.write(response.content)
+ else:
+ logger.info(f"[wechatmp] Failed to download image file, {response.content}")
+
+ self._prepare_fn = download_image
+ else:
+ raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type))
+
+ self.from_user_id = msg.source
+ self.to_user_id = msg.target
+ self.other_user_id = msg.source
diff --git a/common/time_check.py b/common/time_check.py
index 808f71ab3..5c2dacba6 100644
--- a/common/time_check.py
+++ b/common/time_check.py
@@ -13,23 +13,15 @@ def _time_checker(self, *args, **kwargs):
if chat_time_module:
chat_start_time = _config.get("chat_start_time", "00:00")
chat_stopt_time = _config.get("chat_stop_time", "24:00")
- time_regex = re.compile(
- r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$"
- ) # 时间匹配,包含24:00
+ time_regex = re.compile(r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$") # 时间匹配,包含24:00
starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式
stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式
chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间
# 时间格式检查
- if not (
- starttime_format_check and stoptime_format_check and chat_time_check
- ):
- logger.warn(
- "时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(
- starttime_format_check, stoptime_format_check
- )
- )
+ if not (starttime_format_check and stoptime_format_check and chat_time_check):
+ logger.warn("时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(starttime_format_check, stoptime_format_check))
if chat_start_time > "23:59":
logger.error("启动时间可能存在问题,请修改!")
diff --git a/config.py b/config.py
index 2c6739b64..0ba5aa64c 100644
--- a/config.py
+++ b/config.py
@@ -68,6 +68,11 @@
"chat_time_module": False, # 是否开启服务时间限制
"chat_start_time": "00:00", # 服务开始时间
"chat_stop_time": "24:00", # 服务结束时间
+ # 翻译api
+ "translate": "baidu", # 翻译api,支持baidu
+ # baidu翻译api的配置
+ "baidu_translate_app_id": "", # 百度翻译api的appid
+ "baidu_translate_app_key": "", # 百度翻译api的秘钥
# itchat的配置
"hot_reload": False, # 是否开启热重载
# wechaty的配置
@@ -75,8 +80,9 @@
# wechatmp的配置
"wechatmp_token": "", # 微信公众平台的Token
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
- "wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要
- "wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要
+ "wechatmp_app_id": "", # 微信公众平台的appID
+ "wechatmp_app_secret": "", # 微信公众平台的appsecret
+ "wechatmp_aes_key": "", # 微信公众平台的EncodingAESKey,加密模式需要
# chatgpt指令自定义触发词
"clear_memory_commands": ["#清除记忆"], # 重置会话指令,必须以#开头
# channel配置
@@ -159,9 +165,7 @@ def load_config():
for name, value in os.environ.items():
name = name.lower()
if name in available_setting:
- logger.info(
- "[INIT] override config by environ args: {}={}".format(name, value)
- )
+ logger.info("[INIT] override config by environ args: {}={}".format(name, value))
try:
config[name] = eval(value)
except:
diff --git a/docker/Dockerfile.alpine b/docker/Dockerfile.alpine
index 324a76ebf..61f80c28f 100644
--- a/docker/Dockerfile.alpine
+++ b/docker/Dockerfile.alpine
@@ -22,8 +22,8 @@ RUN apk add --no-cache \
&& cd ${BUILD_PREFIX} \
&& cp config-template.json ${BUILD_PREFIX}/config.json \
&& /usr/local/bin/python -m pip install --no-cache --upgrade pip \
- && pip install --no-cache -r requirements.txt \
- && pip install --no-cache -r requirements-optional.txt \
+ && pip install --no-cache -r requirements.txt --extra-index-url https://alpine-wheels.github.io/index\
+ && pip install --no-cache -r requirements-optional.txt --extra-index-url https://alpine-wheels.github.io/index\
&& apk del curl wget
WORKDIR ${BUILD_PREFIX}
diff --git a/docker/Dockerfile.latest b/docker/Dockerfile.latest
index c9a5a55e6..432075cd1 100644
--- a/docker/Dockerfile.latest
+++ b/docker/Dockerfile.latest
@@ -13,8 +13,8 @@ RUN apk add --no-cache bash ffmpeg espeak \
&& cd ${BUILD_PREFIX} \
&& cp config-template.json config.json \
&& /usr/local/bin/python -m pip install --no-cache --upgrade pip \
- && pip install --no-cache -r requirements.txt \
- && pip install --no-cache -r requirements-optional.txt
+ && pip install --no-cache -r requirements.txt --extra-index-url https://alpine-wheels.github.io/index\
+ && pip install --no-cache -r requirements-optional.txt --extra-index-url https://alpine-wheels.github.io/index
WORKDIR ${BUILD_PREFIX}
diff --git a/plugins/banwords/banwords.py b/plugins/banwords/banwords.py
index 4f7f75cde..118b9631c 100644
--- a/plugins/banwords/banwords.py
+++ b/plugins/banwords/banwords.py
@@ -50,9 +50,7 @@ def __init__(self):
self.reply_action = conf.get("reply_action", "ignore")
logger.info("[Banwords] inited")
except Exception as e:
- logger.warn(
- "[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords ."
- )
+ logger.warn("[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords .")
raise e
def on_handle_context(self, e_context: EventContext):
@@ -72,9 +70,7 @@ def on_handle_context(self, e_context: EventContext):
return
elif self.action == "replace":
if self.searchr.ContainsAny(content):
- reply = Reply(
- ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content)
- )
+ reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content))
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
return
@@ -94,9 +90,7 @@ def on_decorate_reply(self, e_context: EventContext):
return
elif self.reply_action == "replace":
if self.searchr.ContainsAny(content):
- reply = Reply(
- ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content)
- )
+ reply = Reply(ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content))
e_context["reply"] = reply
e_context.action = EventAction.CONTINUE
return
diff --git a/plugins/bdunit/bdunit.py b/plugins/bdunit/bdunit.py
index 53aaec270..c99f2b045 100644
--- a/plugins/bdunit/bdunit.py
+++ b/plugins/bdunit/bdunit.py
@@ -82,9 +82,7 @@ def get_token(self):
Returns:
string: access_token
"""
- url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(
- self.api_key, self.secret_key
- )
+ url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(self.api_key, self.secret_key)
payload = ""
headers = {"Content-Type": "application/json", "Accept": "application/json"}
@@ -100,10 +98,7 @@ def getUnit(self, query):
:returns: UNIT 解析结果。如果解析失败,返回 None
"""
- url = (
- "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
- + self.access_token
- )
+ url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + self.access_token
request = {
"query": query,
"user_id": str(get_mac())[:32],
@@ -130,10 +125,7 @@ def getUnit2(self, query):
:param query: 用户的指令字符串
:returns: UNIT 解析结果。如果解析失败,返回 None
"""
- url = (
- "https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token="
- + self.access_token
- )
+ url = "https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token=" + self.access_token
request = {"query": query, "user_id": str(get_mac())[:32]}
body = {
"log_id": str(uuid.uuid1()),
@@ -176,11 +168,7 @@ def hasIntent(self, parsed, intent):
if parsed and "result" in parsed and "response_list" in parsed["result"]:
response_list = parsed["result"]["response_list"]
for response in response_list:
- if (
- "schema" in response
- and "intent" in response["schema"]
- and response["schema"]["intent"] == intent
- ):
+ if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
return True
return False
else:
@@ -204,12 +192,7 @@ def getSlots(self, parsed, intent=""):
logger.warning(e)
return []
for response in response_list:
- if (
- "schema" in response
- and "intent" in response["schema"]
- and "slots" in response["schema"]
- and response["schema"]["intent"] == intent
- ):
+ if "schema" in response and "intent" in response["schema"] and "slots" in response["schema"] and response["schema"]["intent"] == intent:
return response["schema"]["slots"]
return []
else:
@@ -245,11 +228,7 @@ def getSayByConfidence(self, parsed):
if (
"schema" in response
and "intent_confidence" in response["schema"]
- and (
- not answer
- or response["schema"]["intent_confidence"]
- > answer["schema"]["intent_confidence"]
- )
+ and (not answer or response["schema"]["intent_confidence"] > answer["schema"]["intent_confidence"])
):
answer = response
return answer["action_list"][0]["say"]
@@ -273,11 +252,7 @@ def getSay(self, parsed, intent=""):
logger.warning(e)
return ""
for response in response_list:
- if (
- "schema" in response
- and "intent" in response["schema"]
- and response["schema"]["intent"] == intent
- ):
+ if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
try:
return response["action_list"][0]["say"]
except Exception as e:
diff --git a/plugins/dungeon/dungeon.py b/plugins/dungeon/dungeon.py
index 2e3cdf1ad..5b129d606 100644
--- a/plugins/dungeon/dungeon.py
+++ b/plugins/dungeon/dungeon.py
@@ -84,9 +84,7 @@ def on_handle_context(self, e_context: EventContext):
if len(clist) > 1:
story = clist[1]
else:
- story = (
- "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。"
- )
+ story = "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。"
self.games[sessionid] = StoryTeller(bot, sessionid, story)
reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story)
e_context["reply"] = reply
@@ -102,11 +100,7 @@ def get_help_text(self, **kwargs):
if kwargs.get("verbose") != True:
return help_text
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
- help_text = (
- f"{trigger_prefix}开始冒险 "
- + "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n"
- + f"{trigger_prefix}停止冒险: 结束游戏。\n"
- )
+ help_text = f"{trigger_prefix}开始冒险 " + "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n" + f"{trigger_prefix}停止冒险: 结束游戏。\n"
if kwargs.get("verbose") == True:
help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'"
return help_text
diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py
index 9f0b3a583..97595e4d4 100644
--- a/plugins/godcmd/godcmd.py
+++ b/plugins/godcmd/godcmd.py
@@ -140,9 +140,7 @@ def get_help_text(isadmin, isgroup):
if plugins[plugin].enabled and not plugins[plugin].hidden:
namecn = plugins[plugin].namecn
help_text += "\n%s:" % namecn
- help_text += (
- PluginManager().instances[plugin].get_help_text(verbose=False).strip()
- )
+ help_text += PluginManager().instances[plugin].get_help_text(verbose=False).strip()
if ADMIN_COMMANDS and isadmin:
help_text += "\n\n管理员指令:\n"
@@ -191,9 +189,7 @@ def __init__(self):
COMMANDS["reset"]["alias"].append(custom_command)
self.password = gconf["password"]
- self.admin_users = gconf[
- "admin_users"
- ] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用
+ self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用
self.isrunning = True # 机器人是否运行中
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
@@ -215,7 +211,7 @@ def on_handle_context(self, e_context: EventContext):
reply.content = f"空指令,输入#help查看指令列表\n"
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
- return
+ return
# msg = e_context['context']['msg']
channel = e_context["channel"]
user = e_context["context"]["receiver"]
@@ -248,11 +244,7 @@ def on_handle_context(self, e_context: EventContext):
if not plugincls.enabled:
continue
if query_name == name or query_name == plugincls.namecn:
- ok, result = True, PluginManager().instances[
- name
- ].get_help_text(
- isgroup=isgroup, isadmin=isadmin, verbose=True
- )
+ ok, result = True, PluginManager().instances[name].get_help_text(isgroup=isgroup, isadmin=isadmin, verbose=True)
break
if not ok:
result = "插件不存在或未启用"
@@ -285,11 +277,7 @@ def on_handle_context(self, e_context: EventContext):
if isgroup:
ok, result = False, "群聊不可执行管理员指令"
else:
- cmd = next(
- c
- for c, info in ADMIN_COMMANDS.items()
- if cmd in info["alias"]
- )
+ cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info["alias"])
if cmd == "stop":
self.isrunning = False
ok, result = True, "服务已暂停"
@@ -325,18 +313,14 @@ def on_handle_context(self, e_context: EventContext):
PluginManager().activate_plugins()
if len(new_plugins) > 0:
result += "\n发现新插件:\n"
- result += "\n".join(
- [f"{p.name}_v{p.version}" for p in new_plugins]
- )
+ result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins])
else:
result += ", 未发现新插件"
elif cmd == "setpri":
if len(args) != 2:
ok, result = False, "请提供插件名和优先级"
else:
- ok = PluginManager().set_plugin_priority(
- args[0], int(args[1])
- )
+ ok = PluginManager().set_plugin_priority(args[0], int(args[1]))
if ok:
result = "插件" + args[0] + "优先级已设置为" + args[1]
else:
diff --git a/plugins/hello/hello.py b/plugins/hello/hello.py
index 254b17254..fc8fe7051 100644
--- a/plugins/hello/hello.py
+++ b/plugins/hello/hello.py
@@ -33,9 +33,7 @@ def on_handle_context(self, e_context: EventContext):
if e_context["context"].type == ContextType.JOIN_GROUP:
e_context["context"].type = ContextType.TEXT
msg: ChatMessage = e_context["context"]["msg"]
- e_context[
- "context"
- ].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。'
+ e_context["context"].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。'
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
return
@@ -53,9 +51,7 @@ def on_handle_context(self, e_context: EventContext):
reply.type = ReplyType.TEXT
msg: ChatMessage = e_context["context"]["msg"]
if e_context["context"]["isgroup"]:
- reply.content = (
- f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
- )
+ reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
else:
reply.content = f"Hello, {msg.from_user_nickname}"
e_context["reply"] = reply
diff --git a/plugins/keyword/config.json.template b/plugins/keyword/config.json.template
index 9a8332f3e..dbd5efe34 100644
--- a/plugins/keyword/config.json.template
+++ b/plugins/keyword/config.json.template
@@ -1,5 +1,5 @@
{
"keyword": {
- "关键字匹配": "测试成功"
+ "关键字匹配": "测试成功"
}
-}
\ No newline at end of file
+}
diff --git a/plugins/keyword/keyword.py b/plugins/keyword/keyword.py
index 376f748e9..97ebe26ac 100644
--- a/plugins/keyword/keyword.py
+++ b/plugins/keyword/keyword.py
@@ -41,9 +41,7 @@ def __init__(self):
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
logger.info("[keyword] inited.")
except Exception as e:
- logger.warn(
- "[keyword] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword ."
- )
+ logger.warn("[keyword] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword .")
raise e
def on_handle_context(self, e_context: EventContext):
diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py
index b014e5f06..d2ee75e39 100644
--- a/plugins/plugin_manager.py
+++ b/plugins/plugin_manager.py
@@ -31,23 +31,14 @@ def wrapper(plugincls):
plugincls.desc = kwargs.get("desc")
plugincls.author = kwargs.get("author")
plugincls.path = self.current_plugin_path
- plugincls.version = (
- kwargs.get("version") if kwargs.get("version") != None else "1.0"
- )
- plugincls.namecn = (
- kwargs.get("namecn") if kwargs.get("namecn") != None else name
- )
- plugincls.hidden = (
- kwargs.get("hidden") if kwargs.get("hidden") != None else False
- )
+ plugincls.version = kwargs.get("version") if kwargs.get("version") != None else "1.0"
+ plugincls.namecn = kwargs.get("namecn") if kwargs.get("namecn") != None else name
+ plugincls.hidden = kwargs.get("hidden") if kwargs.get("hidden") != None else False
plugincls.enabled = True
if self.current_plugin_path == None:
raise Exception("Plugin path not set")
self.plugins[name.upper()] = plugincls
- logger.info(
- "Plugin %s_v%s registered, path=%s"
- % (name, plugincls.version, plugincls.path)
- )
+ logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path))
return wrapper
@@ -62,9 +53,7 @@ def load_config(self):
if os.path.exists("./plugins/plugins.json"):
with open("./plugins/plugins.json", "r", encoding="utf-8") as f:
pconf = json.load(f)
- pconf["plugins"] = SortedDict(
- lambda k, v: v["priority"], pconf["plugins"], reverse=True
- )
+ pconf["plugins"] = SortedDict(lambda k, v: v["priority"], pconf["plugins"], reverse=True)
else:
modified = True
pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)}
@@ -90,26 +79,16 @@ def scan_plugins(self):
if plugin_path in self.loaded:
if self.loaded[plugin_path] == None:
logger.info("reload module %s" % plugin_name)
- self.loaded[plugin_path] = importlib.reload(
- sys.modules[import_path]
- )
- dependent_module_names = [
- name
- for name in sys.modules.keys()
- if name.startswith(import_path + ".")
- ]
+ self.loaded[plugin_path] = importlib.reload(sys.modules[import_path])
+ dependent_module_names = [name for name in sys.modules.keys() if name.startswith(import_path + ".")]
for name in dependent_module_names:
logger.info("reload module %s" % name)
importlib.reload(sys.modules[name])
else:
- self.loaded[plugin_path] = importlib.import_module(
- import_path
- )
+ self.loaded[plugin_path] = importlib.import_module(import_path)
self.current_plugin_path = None
except Exception as e:
- logger.exception(
- "Failed to import plugin %s: %s" % (plugin_name, e)
- )
+ logger.exception("Failed to import plugin %s: %s" % (plugin_name, e))
continue
pconf = self.pconf
news = [self.plugins[name] for name in self.plugins]
@@ -119,9 +98,7 @@ def scan_plugins(self):
rawname = plugincls.name
if rawname not in pconf["plugins"]:
modified = True
- logger.info(
- "Plugin %s not found in pconfig, adding to pconfig..." % name
- )
+ logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name)
pconf["plugins"][rawname] = {
"enabled": plugincls.enabled,
"priority": plugincls.priority,
@@ -136,9 +113,7 @@ def scan_plugins(self):
def refresh_order(self):
for event in self.listening_plugins.keys():
- self.listening_plugins[event].sort(
- key=lambda name: self.plugins[name].priority, reverse=True
- )
+ self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True)
def activate_plugins(self): # 生成新开启的插件实例
failed_plugins = []
@@ -184,13 +159,8 @@ def load_plugins(self):
def emit_event(self, e_context: EventContext, *args, **kwargs):
if e_context.event in self.listening_plugins:
for name in self.listening_plugins[e_context.event]:
- if (
- self.plugins[name].enabled
- and e_context.action == EventAction.CONTINUE
- ):
- logger.debug(
- "Plugin %s triggered by event %s" % (name, e_context.event)
- )
+ if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE:
+ logger.debug("Plugin %s triggered by event %s" % (name, e_context.event))
instance = self.instances[name]
instance.handlers[e_context.event](e_context, *args, **kwargs)
return e_context
@@ -262,9 +232,7 @@ def install_plugin(self, repo: str):
source = json.load(f)
if repo in source["repo"]:
repo = source["repo"][repo]["url"]
- match = re.match(
- r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo
- )
+ match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo)
if not match:
return False, "安装插件失败,source中的仓库地址不合法"
else:
diff --git a/plugins/role/role.py b/plugins/role/role.py
index 9788cc1cd..69c523347 100644
--- a/plugins/role/role.py
+++ b/plugins/role/role.py
@@ -69,13 +69,9 @@ def __init__(self):
logger.info("[Role] inited")
except Exception as e:
if isinstance(e, FileNotFoundError):
- logger.warn(
- f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ."
- )
+ logger.warn(f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
else:
- logger.warn(
- "[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ."
- )
+ logger.warn("[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
raise e
def get_role(self, name, find_closest=True, min_sim=0.35):
@@ -143,9 +139,7 @@ def on_handle_context(self, e_context: EventContext):
else:
help_text = f"未知角色类型。\n"
help_text += "目前的角色类型有: \n"
- help_text += (
- ",".join([self.tags[tag][0] for tag in self.tags]) + "\n"
- )
+ help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "\n"
else:
help_text = f"请输入角色类型。\n"
help_text += "目前的角色类型有: \n"
@@ -158,9 +152,7 @@ def on_handle_context(self, e_context: EventContext):
return
logger.debug("[Role] on_handle_context. content: %s" % content)
if desckey is not None:
- if len(clist) == 1 or (
- len(clist) > 1 and clist[1].lower() in ["help", "帮助"]
- ):
+ if len(clist) == 1 or (len(clist) > 1 and clist[1].lower() in ["help", "帮助"]):
reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True))
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
@@ -178,9 +170,7 @@ def on_handle_context(self, e_context: EventContext):
self.roles[role][desckey],
self.roles[role].get("wrapper", "%s"),
)
- reply = Reply(
- ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey]
- )
+ reply = Reply(ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey])
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
elif customize == True:
@@ -199,17 +189,10 @@ def get_help_text(self, verbose=False, **kwargs):
if not verbose:
return help_text
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
- help_text = (
- f"使用方法:\n{trigger_prefix}角色"
- + " 预设角色名: 设定角色为{预设角色名}。\n"
- + f"{trigger_prefix}role"
- + " 预设角色名: 同上,但使用英文设定。\n"
- )
+ help_text = f"使用方法:\n{trigger_prefix}角色" + " 预设角色名: 设定角色为{预设角色名}。\n" + f"{trigger_prefix}role" + " 预设角色名: 同上,但使用英文设定。\n"
help_text += f"{trigger_prefix}设定扮演" + " 角色设定: 设定自定义角色人设为{角色设定}。\n"
help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n"
- help_text += (
- f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n"
- )
+ help_text += f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n"
help_text += "\n目前的角色类型有: \n"
help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "。\n"
help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n"
diff --git a/plugins/tool/README.md b/plugins/tool/README.md
index 3ce92da47..9f91b288a 100644
--- a/plugins/tool/README.md
+++ b/plugins/tool/README.md
@@ -60,7 +60,7 @@
> 该tool每天返回内容相同
-#### 6.3. finance-news
+#### 6.3. finance-news
###### 获取实时的金融财政新闻
> 该工具需要解决browser tool 的google-chrome依赖安装
diff --git a/plugins/tool/tool.py b/plugins/tool/tool.py
index 2fb8a12bd..a3a3e76c6 100644
--- a/plugins/tool/tool.py
+++ b/plugins/tool/tool.py
@@ -82,9 +82,7 @@ def on_handle_context(self, e_context: EventContext):
return
elif content_list[1].startswith("reset"):
logger.debug("[tool]: remind")
- e_context[
- "context"
- ].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符"
+ e_context["context"].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符"
e_context.action = EventAction.BREAK
return
@@ -93,18 +91,14 @@ def on_handle_context(self, e_context: EventContext):
# Don't modify bot name
all_sessions = Bridge().get_bot("chat").sessions
- user_session = all_sessions.session_query(
- query, e_context["context"]["session_id"]
- ).messages
+ user_session = all_sessions.session_query(query, e_context["context"]["session_id"]).messages
# chatgpt-tool-hub will reply you with many tools
logger.debug("[tool]: just-go")
try:
_reply = self.app.ask(query, user_session)
e_context.action = EventAction.BREAK_PASS
- all_sessions.session_reply(
- _reply, e_context["context"]["session_id"]
- )
+ all_sessions.session_reply(_reply, e_context["context"]["session_id"])
except Exception as e:
logger.exception(e)
logger.error(str(e))
@@ -213,4 +207,4 @@ def _reset_app(self) -> App:
# filter not support tool
tool_list = self._filter_tool_list(tool_config.get("tools", []))
- return app.create_app(tools_list=tool_list, **app_kwargs)
\ No newline at end of file
+ return app.create_app(tools_list=tool_list, **app_kwargs)
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 000000000..abdab5733
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,8 @@
+[tool.black]
+line-length = 176
+target-version = ['py37']
+include = '\.pyi?$'
+extend-exclude = '.+/(dist|.venv|venv|build|lib)/.+'
+
+[tool.isort]
+profile = "black"
\ No newline at end of file
diff --git a/requirements-optional.txt b/requirements-optional.txt
index cfb52c953..0bcd1965d 100644
--- a/requirements-optional.txt
+++ b/requirements-optional.txt
@@ -7,6 +7,8 @@ gTTS>=2.3.1 # google text to speech
pyttsx3>=2.90 # pytsx text to speech
baidu_aip>=4.16.10 # baidu voice
# azure-cognitiveservices-speech # azure voice
+numpy<=1.24.2
+langid # language detect
#install plugin
dulwich
@@ -18,6 +20,7 @@ pysilk_mod>=1.6.0 # needed by send voice
# wechatmp
web.py
+wechatpy
# chatgpt-tool-hub plugin
diff --git a/requirements.txt b/requirements.txt
index 2980e2ded..e7cf147fd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,3 +17,6 @@ lxml==4.9.2
pre-commit
webdriver_manager
selenium>=3.13.0
+
+Pillow
+pre-commit
diff --git a/translate/baidu/baidu_translate.py b/translate/baidu/baidu_translate.py
new file mode 100644
index 000000000..bf0a72143
--- /dev/null
+++ b/translate/baidu/baidu_translate.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+
+import random
+from hashlib import md5
+
+import requests
+
+from config import conf
+from translate.translator import Translator
+
+
+class BaiduTranslator(Translator):
+ def __init__(self) -> None:
+ super().__init__()
+ endpoint = "http://api.fanyi.baidu.com"
+ path = "/api/trans/vip/translate"
+ self.url = endpoint + path
+ self.appid = conf().get("baidu_translate_app_id")
+ self.appkey = conf().get("baidu_translate_app_key")
+
+ # For list of language codes, please refer to `https://api.fanyi.baidu.com/doc/21`, need to convert to ISO 639-1 codes
+ def translate(self, query: str, from_lang: str = "", to_lang: str = "en") -> str:
+ if not from_lang:
+ from_lang = "auto" # baidu suppport auto detect
+ salt = random.randint(32768, 65536)
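+ # per the Baidu translate API, sign = MD5(appid + query + salt + appkey)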
+ sign = self.make_md5(self.appid + query + str(salt) + self.appkey)
+ headers = {"Content-Type": "application/x-www-form-urlencoded"}
+ payload = {"appid": self.appid, "q": query, "from": from_lang, "to": to_lang, "salt": salt, "sign": sign}
+
+ retry_cnt = 3
+ while retry_cnt:
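+ # retry on 52001 (request timeout) and 52002 (system error); any other error code aborts the translation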
+ r = requests.post(self.url, params=payload, headers=headers)
+ result = r.json()
+ if (errcode := result.get("error_code", "52000")) != "52000":
+ if errcode == "52001" or errcode == "52002":
+ retry_cnt -= 1
+ continue
+ else:
+ raise Exception(result["error_msg"])
+ else:
+ break
+ text = "\n".join([item["dst"] for item in result["trans_result"]])
+ return text
+
+ def make_md5(self, s, encoding="utf-8"):
+ return md5(s.encode(encoding)).hexdigest()
diff --git a/translate/factory.py b/translate/factory.py
new file mode 100644
index 000000000..ba80aa59d
--- /dev/null
+++ b/translate/factory.py
@@ -0,0 +1,6 @@
+def create_translator(translator_type):
+ if translator_type == "baidu":
+ from translate.baidu.baidu_translate import BaiduTranslator
+
+ return BaiduTranslator()
+ raise RuntimeError("unsupported translator type: {}".format(translator_type))
diff --git a/translate/translator.py b/translate/translator.py
new file mode 100644
index 000000000..b394f4e4d
--- /dev/null
+++ b/translate/translator.py
@@ -0,0 +1,12 @@
+"""
+Translation service abstract class
+"""
+
+
+class Translator(object):
+ # please use https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes to specify language
+ def translate(self, query: str, from_lang: str = "", to_lang: str = "en") -> str:
+ """
+ Translate text from one language to another
+ """
+ raise NotImplementedError
diff --git a/voice/audio_convert.py b/voice/audio_convert.py
index ce0601da1..610170038 100644
--- a/voice/audio_convert.py
+++ b/voice/audio_convert.py
@@ -34,6 +34,20 @@ def get_pcm_from_wav(wav_path):
return wav.readframes(wav.getnframes())
+def any_to_mp3(any_path, mp3_path):
+ """
+ 把任意格式转成mp3文件
+ """
+ if any_path.endswith(".mp3"):
+ shutil.copy2(any_path, mp3_path)
+ return
+ if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
+ sil_to_wav(any_path, any_path)
+ any_path = mp3_path
+ audio = AudioSegment.from_file(any_path)
+ audio.export(mp3_path, format="mp3")
+
+
def any_to_wav(any_path, wav_path):
"""
把任意格式转成wav文件
@@ -41,11 +55,7 @@ def any_to_wav(any_path, wav_path):
if any_path.endswith(".wav"):
shutil.copy2(any_path, wav_path)
return
- if (
- any_path.endswith(".sil")
- or any_path.endswith(".silk")
- or any_path.endswith(".slk")
- ):
+ if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
return sil_to_wav(any_path, wav_path)
audio = AudioSegment.from_file(any_path)
audio.export(wav_path, format="wav")
@@ -55,11 +65,7 @@ def any_to_sil(any_path, sil_path):
"""
把任意格式转成sil文件
"""
- if (
- any_path.endswith(".sil")
- or any_path.endswith(".silk")
- or any_path.endswith(".slk")
- ):
+ if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
shutil.copy2(any_path, sil_path)
return 10000
audio = AudioSegment.from_file(any_path)
diff --git a/voice/azure/azure_voice.py b/voice/azure/azure_voice.py
index 3ee95043e..1a0a8ed3f 100644
--- a/voice/azure/azure_voice.py
+++ b/voice/azure/azure_voice.py
@@ -6,6 +6,7 @@
import time
import azure.cognitiveservices.speech as speechsdk
+from langid import classify
from bridge.reply import Reply, ReplyType
from common.log import logger
@@ -30,7 +31,15 @@ def __init__(self):
config = None
if not os.path.exists(config_path): # 如果没有配置文件,创建本地配置文件
config = {
- "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural",
+ "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", # 识别不出时的默认语音
+ "auto_detect": True, # 是否自动检测语言
+ "speech_synthesis_zh": "zh-CN-XiaozhenNeural",
+ "speech_synthesis_en": "en-US-JacobNeural",
+ "speech_synthesis_ja": "ja-JP-AoiNeural",
+ "speech_synthesis_ko": "ko-KR-SoonBokNeural",
+ "speech_synthesis_de": "de-DE-LouisaNeural",
+ "speech_synthesis_fr": "fr-FR-BrigitteNeural",
+ "speech_synthesis_es": "es-ES-LaiaNeural",
"speech_recognition_language": "zh-CN",
}
with open(config_path, "w") as fw:
@@ -38,59 +47,47 @@ def __init__(self):
else:
with open(config_path, "r") as fr:
config = json.load(fr)
+ self.config = config
self.api_key = conf().get("azure_voice_api_key")
self.api_region = conf().get("azure_voice_region")
- self.speech_config = speechsdk.SpeechConfig(
- subscription=self.api_key, region=self.api_region
- )
- self.speech_config.speech_synthesis_voice_name = config[
- "speech_synthesis_voice_name"
- ]
- self.speech_config.speech_recognition_language = config[
- "speech_recognition_language"
- ]
+ self.speech_config = speechsdk.SpeechConfig(subscription=self.api_key, region=self.api_region)
+ self.speech_config.speech_synthesis_voice_name = self.config["speech_synthesis_voice_name"]
+ self.speech_config.speech_recognition_language = self.config["speech_recognition_language"]
except Exception as e:
logger.warn("AzureVoice init failed: %s, ignore " % e)
def voiceToText(self, voice_file):
audio_config = speechsdk.AudioConfig(filename=voice_file)
- speech_recognizer = speechsdk.SpeechRecognizer(
- speech_config=self.speech_config, audio_config=audio_config
- )
+ speech_recognizer = speechsdk.SpeechRecognizer(speech_config=self.speech_config, audio_config=audio_config)
result = speech_recognizer.recognize_once()
if result.reason == speechsdk.ResultReason.RecognizedSpeech:
- logger.info(
- "[Azure] voiceToText voice file name={} text={}".format(
- voice_file, result.text
- )
- )
+ logger.info("[Azure] voiceToText voice file name={} text={}".format(voice_file, result.text))
reply = Reply(ReplyType.TEXT, result.text)
else:
- logger.error(
- "[Azure] voiceToText error, result={}, canceldetails={}".format(
- result, result.cancellation_details
- )
- )
+ logger.error("[Azure] voiceToText error, result={}, canceldetails={}".format(result, result.cancellation_details))
reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败")
return reply
def textToVoice(self, text):
- fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav"
+ if self.config.get("auto_detect"):
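+ # langid.classify returns a (language_code, score) tuple; only the code is needed here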
+ lang = classify(text)[0]
+ key = "speech_synthesis_" + lang
+ if key in self.config:
+ logger.info("[Azure] textToVoice auto detect language={}, voice={}".format(lang, self.config[key]))
+ self.speech_config.speech_synthesis_voice_name = self.config[key]
+ else:
+ self.speech_config.speech_synthesis_voice_name = self.config["speech_synthesis_voice_name"]
+ else:
+ self.speech_config.speech_synthesis_voice_name = self.config["speech_synthesis_voice_name"]
+ # append a hash of the text so concurrent requests in the same second do not collide on the filename
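+ # note: str hashes are salted per interpreter process, which is fine here since the value only needs to differ between texts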
+ fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".wav"
audio_config = speechsdk.AudioConfig(filename=fileName)
- speech_synthesizer = speechsdk.SpeechSynthesizer(
- speech_config=self.speech_config, audio_config=audio_config
- )
+ speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config)
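+ # speak_text() synthesises synchronously and writes the audio to the file passed via audio_config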
result = speech_synthesizer.speak_text(text)
if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
- logger.info(
- "[Azure] textToVoice text={} voice file name={}".format(text, fileName)
- )
+ logger.info("[Azure] textToVoice text={} voice file name={}".format(text, fileName))
reply = Reply(ReplyType.VOICE, fileName)
else:
- logger.error(
- "[Azure] textToVoice error, result={}, canceldetails={}".format(
- result, result.cancellation_details
- )
- )
+ logger.error("[Azure] textToVoice error, result={}, canceldetails={}".format(result, result.cancellation_details))
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
return reply
diff --git a/voice/azure/config.json.template b/voice/azure/config.json.template
index 2dc2176f9..8f3f546f7 100644
--- a/voice/azure/config.json.template
+++ b/voice/azure/config.json.template
@@ -1,4 +1,12 @@
{
"speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural",
+ "auto_detect": true,
+ "speech_synthesis_zh": "zh-CN-YunxiNeural",
+ "speech_synthesis_en": "en-US-JacobNeural",
+ "speech_synthesis_ja": "ja-JP-AoiNeural",
+ "speech_synthesis_ko": "ko-KR-SoonBokNeural",
+ "speech_synthesis_de": "de-DE-LouisaNeural",
+ "speech_synthesis_fr": "fr-FR-BrigitteNeural",
+ "speech_synthesis_es": "es-ES-LaiaNeural",
"speech_recognition_language": "zh-CN"
}
diff --git a/voice/baidu/baidu_voice.py b/voice/baidu/baidu_voice.py
index ccde6c4d1..406157b96 100644
--- a/voice/baidu/baidu_voice.py
+++ b/voice/baidu/baidu_voice.py
@@ -82,12 +82,11 @@ def textToVoice(self, text):
{"spd": self.spd, "pit": self.pit, "vol": self.vol, "per": self.per},
)
if not isinstance(result, dict):
- fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
+ # append a hash of the text so concurrent requests in the same second do not collide on the filename
+ fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3"
with open(fileName, "wb") as f:
f.write(result)
- logger.info(
- "[Baidu] textToVoice text={} voice file name={}".format(text, fileName)
- )
+ logger.info("[Baidu] textToVoice text={} voice file name={}".format(text, fileName))
reply = Reply(ReplyType.VOICE, fileName)
else:
logger.error("[Baidu] textToVoice error={}".format(result))
diff --git a/voice/voice_factory.py b/voice/factory.py
similarity index 100%
rename from voice/voice_factory.py
rename to voice/factory.py
diff --git a/voice/google/google_voice.py b/voice/google/google_voice.py
index 4f7b8ade3..6dcadad3d 100644
--- a/voice/google/google_voice.py
+++ b/voice/google/google_voice.py
@@ -24,11 +24,7 @@ def voiceToText(self, voice_file):
audio = self.recognizer.record(source)
try:
text = self.recognizer.recognize_google(audio, language="zh-CN")
- logger.info(
- "[Google] voiceToText text={} voice file name={}".format(
- text, voice_file
- )
- )
+ logger.info("[Google] voiceToText text={} voice file name={}".format(text, voice_file))
reply = Reply(ReplyType.TEXT, text)
except speech_recognition.UnknownValueError:
reply = Reply(ReplyType.ERROR, "抱歉,我听不懂")
@@ -39,12 +35,11 @@ def voiceToText(self, voice_file):
def textToVoice(self, text):
try:
- mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
+ # append a hash of the text so concurrent requests in the same second do not collide on the filename
+ mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3"
tts = gTTS(text=text, lang="zh")
tts.save(mp3File)
- logger.info(
- "[Google] textToVoice text={} voice file name={}".format(text, mp3File)
- )
+ logger.info("[Google] textToVoice text={} voice file name={}".format(text, mp3File))
reply = Reply(ReplyType.VOICE, mp3File)
except Exception as e:
reply = Reply(ReplyType.ERROR, str(e))
diff --git a/voice/openai/openai_voice.py b/voice/openai/openai_voice.py
index 06c221b21..b02d92651 100644
--- a/voice/openai/openai_voice.py
+++ b/voice/openai/openai_voice.py
@@ -22,11 +22,7 @@ def voiceToText(self, voice_file):
result = openai.Audio.transcribe("whisper-1", file)
text = result["text"]
reply = Reply(ReplyType.TEXT, text)
- logger.info(
- "[Openai] voiceToText text={} voice file name={}".format(
- text, voice_file
- )
- )
+ logger.info("[Openai] voiceToText text={} voice file name={}".format(text, voice_file))
except Exception as e:
reply = Reply(ReplyType.ERROR, str(e))
finally:
diff --git a/voice/pytts/pytts_voice.py b/voice/pytts/pytts_voice.py
index 072e28b41..bd70086d3 100644
--- a/voice/pytts/pytts_voice.py
+++ b/voice/pytts/pytts_voice.py
@@ -2,6 +2,8 @@
pytts voice service (offline)
"""
+import os
+import sys
import time
import pyttsx3
@@ -20,19 +22,42 @@ def __init__(self):
self.engine.setProperty("rate", 125)
# 音量
self.engine.setProperty("volume", 1.0)
- for voice in self.engine.getProperty("voices"):
- if "Chinese" in voice.name:
- self.engine.setProperty("voice", voice.id)
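+ # on Windows pyttsx3 drives SAPI5, where voices are matched by display name; on Linux it typically drives espeak, which accepts a plain language code such as "zh"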
+ if sys.platform == "win32":
+ for voice in self.engine.getProperty("voices"):
+ if "Chinese" in voice.name:
+ self.engine.setProperty("voice", voice.id)
+ else:
+ self.engine.setProperty("voice", "zh")
+ # Once the espeak issue is fixed, switch back to runAndWait() and remove this startLoop()
+ # TODO: check whether this also works on win32
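+ # with useDriverLoop=False the engine's event loop is pumped manually via iterate() in textToVoice below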
+ self.engine.startLoop(useDriverLoop=False)
def textToVoice(self, text):
try:
- wavFile = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav"
+ # append a hash of the text so concurrent requests in the same second do not collide on the filename
+ wavFileName = "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".wav"
+ wavFile = TmpDir().path() + wavFileName
+ logger.info("[Pytts] textToVoice text={} voice file name={}".format(text, wavFile))
+
self.engine.save_to_file(text, wavFile)
- self.engine.runAndWait()
- logger.info(
- "[Pytts] textToVoice text={} voice file name={}".format(text, wavFile)
- )
+
+ if sys.platform == "win32":
+ self.engine.runAndWait()
+ else:
+ # On Ubuntu, runAndWait() does not actually wait until the file has been created:
+ # it returns once the task queue is empty while the work is still running in a coroutine.
+ # Calling runAndWait() plus time.sleep() a second time then hangs, so it cannot be used here.
+ # A proper fix would be to call self._proxy.setBusy(True) at the beginning of save_to_file, around line 127 of espeak.py.
+ # self.engine.runAndWait()
+
+ # Until espeak fixes this, pump the engine loop with iterate() and wait for the file ourselves.
+ # This is not the canonical usage; for instance, if the file already exists the wait below cannot tell when synthesis is done.
+ self.engine.iterate()
+ while self.engine.isBusy() or wavFileName not in os.listdir(TmpDir().path()):
+ time.sleep(0.1)
+
reply = Reply(ReplyType.VOICE, wavFile)
+
except Exception as e:
reply = Reply(ReplyType.ERROR, str(e))
finally: