diff --git a/.flake8 b/.flake8 index c614a6302..0af19c5f1 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,5 @@ [flake8] -max-line-length = 88 +max-line-length = 176 select = E303,W293,W291,W292,E305,E231,E302 exclude = .tox, diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md deleted file mode 100644 index 7d697ec53..000000000 --- a/.github/ISSUE_TEMPLATE.md +++ /dev/null @@ -1,31 +0,0 @@ -### 前置确认 - -1. 网络能够访问openai接口 -2. python 已安装:版本在 3.7 ~ 3.10 之间 -3. `git pull` 拉取最新代码 -4. 执行`pip3 install -r requirements.txt`,检查依赖是否满足 -5. 拓展功能请执行`pip3 install -r requirements-optional.txt`,检查依赖是否满足 -6. 在已有 issue 中未搜索到类似问题 -7. [FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) 中无类似问题 - - -### 问题描述 - -> 简要说明、截图、复现步骤等,也可以是需求或想法 - - - - -### 终端日志 (如有报错) - -``` -[在此处粘贴终端日志, 可在主目录下`run.log`文件中找到] -``` - - - -### 环境 - - - 操作系统类型 (Mac/Windows/Linux): - - Python版本 ( 执行 `python3 -V` ): - - pip版本 ( 依赖问题此项必填,执行 `pip3 -V`): diff --git a/.github/ISSUE_TEMPLATE/1.bug.yml b/.github/ISSUE_TEMPLATE/1.bug.yml new file mode 100644 index 000000000..2f762c076 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/1.bug.yml @@ -0,0 +1,133 @@ +name: Bug report 🐛 +description: 项目运行中遇到的Bug或问题。 +labels: ['status: needs check'] +body: + - type: markdown + attributes: + value: | + ### ⚠️ 前置确认 + 1. 网络能够访问openai接口 + 2. python 已安装:版本在 3.7 ~ 3.10 之间 + 3. `git pull` 拉取最新代码 + 4. 执行`pip3 install -r requirements.txt`,检查依赖是否满足 + 5. 拓展功能请执行`pip3 install -r requirements-optional.txt`,检查依赖是否满足 + 6. 
[FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) 中无类似问题 + - type: checkboxes + attributes: + label: 前置确认 + options: + - label: 我确认我运行的是最新版本的代码,并且安装了所需的依赖,在[FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs)中也未找到类似问题。 + required: true + - type: checkboxes + attributes: + label: ⚠️ 搜索issues中是否已存在类似问题 + description: > + 请在 [历史issue](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中清空输入框,搜索你的问题 + 或相关日志的关键词来查找是否存在类似问题。 + options: + - label: 我已经搜索过issues和discussions,没有跟我遇到的问题相关的issue + required: true + - type: markdown + attributes: + value: | + 请在上方的`title`中填写你对你所遇到问题的简略总结,这将帮助其他人更好的找到相似问题,谢谢❤️。 + - type: dropdown + attributes: + label: 操作系统类型? + description: > + 请选择你运行程序的操作系统类型。 + options: + - Windows + - Linux + - MacOS + - Docker + - Railway + - Windows Subsystem for Linux (WSL) + - Other (请在问题中说明) + validations: + required: true + - type: dropdown + attributes: + label: 运行的python版本是? + description: | + 请选择你运行程序的`python`版本。 + 注意:在`python 3.7`中,有部分可选依赖无法安装。 + 经过长时间的观察,我们认为`python 3.8`是兼容性最好的版本。 + `python 3.7`~`python 3.10`以外版本的issue,将视情况直接关闭。 + options: + - python 3.7 + - python 3.8 + - python 3.9 + - python 3.10 + - other + validations: + required: true + - type: dropdown + attributes: + label: 使用的chatgpt-on-wechat版本是? + description: | + 请确保你使用的是 [releases](https://github.com/zhayujie/chatgpt-on-wechat/releases) 中的最新版本。 + 如果你使用git, 请使用`git branch`命令来查看分支。 + options: + - Latest Release + - Master (branch) + validations: + required: true + - type: dropdown + attributes: + label: 运行的`channel`类型是? 
+ description: | + 请确保你正确配置了该`channel`所需的配置项,所有可选的配置项都写在了[该文件中](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py),请将所需配置项填写在根目录下的`config.json`文件中。 + options: + - wx(个人微信, itchat) + - wxy(个人微信, wechaty) + - wechatmp(公众号, 订阅号) + - wechatmp_service(公众号, 服务号) + - terminal + - other + validations: + required: true + - type: textarea + attributes: + label: 复现步骤 🕹 + description: | + **⚠️ 不能复现将会关闭issue.** + - type: textarea + attributes: + label: 问题描述 😯 + description: 详细描述出现的问题,或提供有关截图。 + - type: textarea + attributes: + label: 终端日志 📒 + description: | + 在此处粘贴终端日志,可在主目录下`run.log`文件中找到,这会帮助我们更好的分析问题,注意隐去你的API key。 + 如果在配置文件中加入`"debug": true`,打印出的日志会更有帮助。 + +
+ 示例 + ```log + [DEBUG][2023-04-16 00:23:22][plugin_manager.py:157] - Plugin SUMMARY triggered by event Event.ON_HANDLE_CONTEXT + [DEBUG][2023-04-16 00:23:22][main.py:221] - [Summary] on_handle_context. content: $总结前100条消息 + [DEBUG][2023-04-16 00:23:24][main.py:240] - [Summary] limit: 100, duration: -1 seconds + [ERROR][2023-04-16 00:23:24][chat_channel.py:244] - Worker return exception: name 'start_date' is not defined + Traceback (most recent call last): + File "C:\ProgramData\Anaconda3\lib\concurrent\futures\thread.py", line 57, in run + result = self.fn(*self.args, **self.kwargs) + File "D:\project\chatgpt-on-wechat\channel\chat_channel.py", line 132, in _handle + reply = self._generate_reply(context) + File "D:\project\chatgpt-on-wechat\channel\chat_channel.py", line 142, in _generate_reply + e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, { + File "D:\project\chatgpt-on-wechat\plugins\plugin_manager.py", line 159, in emit_event + instance.handlers[e_context.event](e_context, *args, **kwargs) + File "D:\project\chatgpt-on-wechat\plugins\summary\main.py", line 255, in on_handle_context + records = self._get_records(session_id, start_time, limit) + File "D:\project\chatgpt-on-wechat\plugins\summary\main.py", line 96, in _get_records + c.execute("SELECT * FROM chat_records WHERE sessionid=? and timestamp>? ORDER BY timestamp DESC LIMIT ?", (session_id, start_date, limit)) + NameError: name 'start_date' is not defined + [INFO][2023-04-16 00:23:36][app.py:14] - signal 2 received, exiting... + ``` +
+ value: | + ```log + <此处粘贴终端日志> + ``` \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/2.feature.yml b/.github/ISSUE_TEMPLATE/2.feature.yml new file mode 100644 index 000000000..bbf0888a2 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/2.feature.yml @@ -0,0 +1,28 @@ +name: Feature request 🚀 +description: 提出你对项目的新想法或建议。 +labels: ['status: needs check'] +body: + - type: markdown + attributes: + value: | + 请在上方的`title`中填写简略总结,谢谢❤️。 + - type: checkboxes + attributes: + label: ⚠️ 搜索是否存在类似issue + description: > + 请在 [历史issue](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中清空输入框,搜索关键词查找是否存在相似issue。 + options: + - label: 我已经搜索过issues和discussions,没有发现相似issue + required: true + - type: textarea + attributes: + label: 总结 + description: 描述feature的功能。 + - type: textarea + attributes: + label: 举例 + description: 提供聊天示例,草图或相关网址。 + - type: textarea + attributes: + label: 动机 + description: 描述你提出该feature的动机,比如没有这项feature对你的使用造成了怎样的影响。 请提供更详细的场景描述,这可能会帮助我们发现并提出更好的解决方案。 \ No newline at end of file diff --git a/README.md b/README.md index bb86b4130..2fad81bc0 100644 --- a/README.md +++ b/README.md @@ -2,27 +2,29 @@ > ChatGPT近期以强大的对话和信息整合能力风靡全网,可以写代码、改论文、讲故事,几乎无所不能,这让人不禁有个大胆的想法,能否用他的对话模型把我们的微信打造成一个智能机器人,可以在与好友对话中给出意想不到的回应,而且再也不用担心女朋友影响我们 ~~打游戏~~ 工作了。 +最新版本支持的功能如下: -基于ChatGPT的微信聊天机器人,通过 [ChatGPT](https://github.com/openai/openai-python) 接口生成对话内容,使用 [itchat](https://github.com/littlecodersh/ItChat) 实现微信消息的接收和自动回复。已实现的特性如下: - -- [x] **文本对话:** 接收私聊及群组中的微信消息,使用ChatGPT生成回复内容,完成自动回复 -- [x] **规则定制化:** 支持私聊中按指定规则触发自动回复,支持对群组设置自动回复白名单 -- [x] **图片生成:** 支持根据描述生成图片,支持图片修复 -- [x] **上下文记忆**:支持多轮对话记忆,且为每个好友维护独立的上下会话 -- [x] **语音识别:** 支持接收和处理语音消息,通过文字或语音回复 -- [x] **插件化:** 支持个性化插件,提供角色扮演、文字冒险、与操作系统交互、访问网络数据等能力 - -> 目前支持微信和微信公众号部署,欢迎接入更多应用,参考 [Terminal代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py)实现接收和发送消息逻辑即可接入。 同时欢迎增加新的插件,参考 [插件说明文档](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins)。 +- [x] **基础对话:** 
私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3,GPT-3.5,GPT-4模型 +- [x] **语音识别:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai等多种语音模型 +- [x] **图片生成:** 支持图片生成 和 照片修复,可选择 DALL-E, stable diffusion, replicate 模型 +- [x] **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结等插件 +- [x] **Tool工具:** 与操作系统和互联网交互,支持最新信息搜索、数学计算、天气和资讯查询、网页总结,基于 [chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub) 实现 +> 目前已支持 个人微信 和 微信公众号,欢迎接入更多应用,参考 [Terminal代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py)实现接收和发送消息逻辑即可接入。 同时欢迎增加新的插件,参考 [插件说明文档](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins)。 **一键部署:** [![Deploy on Railway](https://railway.app/button.svg)](https://railway.app/template/qApznZ?referralCode=RC3znh) +# 演示 + +https://user-images.githubusercontent.com/26161723/233777277-e3b9928e-b88f-43e2-b0e0-3cbc923bc799.mp4 + +Demo made by [Visionn](https://www.wangpc.cc/) # 更新日志 ->**2023.04.05:** 支持微信个人号部署,兼容角色扮演等预设插件,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatmp/README.md)。(contributed by [@JS00000](https://github.com/JS00000) in [#686](https://github.com/zhayujie/chatgpt-on-wechat/pull/686)) >**2023.04.05:** 支持微信公众号部署,兼容角色扮演等预设插件,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatmp/README.md)。(contributed by [@JS00000](https://github.com/JS00000) in [#686](https://github.com/zhayujie/chatgpt-on-wechat/pull/686)) >**2023.04.05:** 增加能让ChatGPT使用工具的`tool`插件,[使用文档](https://github.com/goldfishh/chatgpt-on-wechat/blob/master/plugins/tool/README.md)。工具相关issue可反馈至[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)。(contributed by [@goldfishh](https://github.com/goldfishh) in [#663](https://github.com/zhayujie/chatgpt-on-wechat/pull/663)) @@ -32,28 +34,7 @@ >**2023.03.02:** 接入[ChatGPT API](https://platform.openai.com/docs/guides/chat) (gpt-3.5-turbo),默认使用该模型进行对话,需升级openai依赖 (`pip3 install --upgrade openai`)。网络问题参考 
[#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351) ->**2023.02.09:** 扫码登录存在封号风险,请谨慎使用,参考[#58](https://github.com/AutumnWhj/ChatGPT-wechat-bot/issues/158) - ->**2023.02.05:** 在openai官方接口方案中 (GPT-3模型) 实现上下文对话 - ->**2022.12.18:** 支持根据描述生成图片并发送,openai版本需大于0.25.0 - ->**2022.12.17:** 原来的方案是从 [ChatGPT页面](https://chat.openai.com/chat) 获取session_token,使用 [revChatGPT](https://github.com/acheong08/ChatGPT) 直接访问web接口,但随着ChatGPT接入Cloudflare人机验证,这一方案难以在服务器顺利运行。 所以目前使用的方案是调用 OpenAI 官方提供的 [API](https://beta.openai.com/docs/api-reference/introduction),回复质量上基本接近于ChatGPT的内容,劣势是暂不支持有上下文记忆的对话,优势是稳定性和响应速度较好。 - -# 使用效果 - -### 个人聊天 - -![single-chat-sample.jpg](docs/images/single-chat-sample.jpg) - -### 群组聊天 - -![group-chat-sample.jpg](docs/images/group-chat-sample.jpg) - -### 图片生成 - -![group-chat-sample.jpg](docs/images/image-create-sample.jpg) - +>**2023.02.09:** 扫码登录存在账号限制风险,请谨慎使用,参考[#58](https://github.com/AutumnWhj/ChatGPT-wechat-bot/issues/158) # 快速开始 @@ -63,7 +44,7 @@ 前往 [OpenAI注册页面](https://beta.openai.com/signup) 创建账号,参考这篇 [教程](https://www.pythonthree.com/register-openai-chatgpt/) 可以通过虚拟手机号来接收验证码。创建完账号则前往 [API管理页面](https://beta.openai.com/account/api-keys) 创建一个 API Key 并保存下来,后面需要在项目中配置这个key。 -> 项目中使用的对话模型是 davinci,计费方式是约每 750 字 (包含请求和回复) 消耗 $0.02,图片生成是每张消耗 $0.016,账号创建有免费的 $18 额度 (更新3.25: 最新注册的已经无免费额度了),使用完可以更换邮箱重新注册。 +> 项目中默认使用的对话模型是 gpt3.5 turbo,计费方式是约每 500 汉字 (包含请求和回复) 消耗 $0.002,图片生成是每张消耗 $0.016。 ### 2.运行环境 @@ -203,7 +184,7 @@ nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通 参考文档 [Docker部署](https://github.com/limccn/chatgpt-on-wechat/wiki/Docker%E9%83%A8%E7%BD%B2) (Contributed by [limccn](https://github.com/limccn))。 -### 4. Railway部署(✅推荐) +### 4. Railway部署 (✅推荐) > Railway每月提供5刀和最多500小时的免费额度。 1. 进入 [Railway](https://railway.app/template/qApznZ?referralCode=RC3znh)。 2. 
点击 `Deploy Now` 按钮。 diff --git a/app.py b/app.py index e11f46c2a..637b6e462 100644 --- a/app.py +++ b/app.py @@ -19,6 +19,7 @@ def func(_signo, _stack_frame): if callable(old_handler): # check old_handler return old_handler(_signo, _stack_frame) sys.exit(0) + signal.signal(_signo, func) diff --git a/bot/baidu/baidu_unit_bot.py b/bot/baidu/baidu_unit_bot.py index d8a0aca11..f7714e4f4 100644 --- a/bot/baidu/baidu_unit_bot.py +++ b/bot/baidu/baidu_unit_bot.py @@ -10,10 +10,7 @@ class BaiduUnitBot(Bot): def reply(self, query, context=None): token = self.get_token() - url = ( - "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" - + token - ) + url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + token post_data = ( '{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"' + query @@ -32,12 +29,7 @@ def reply(self, query, context=None): def get_token(self): access_key = "YOUR_ACCESS_KEY" secret_key = "YOUR_SECRET_KEY" - host = ( - "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" - + access_key - + "&client_secret=" - + secret_key - ) + host = "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + access_key + "&client_secret=" + secret_key response = requests.get(host) if response: print(response.json()) diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index d8e4b0e26..b045311a3 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -30,23 +30,15 @@ def __init__(self): if conf().get("rate_limit_chatgpt"): self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20)) - self.sessions = SessionManager( - ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo" - ) + self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo") self.args = { "model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称 
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 # "max_tokens":4096, # 回复最大的字符数 "top_p": 1, - "frequency_penalty": conf().get( - "frequency_penalty", 0.0 - ), # [-2,2]之间,该值越大则更倾向于产生不同的内容 - "presence_penalty": conf().get( - "presence_penalty", 0.0 - ), # [-2,2]之间,该值越大则更倾向于产生不同的内容 - "request_timeout": conf().get( - "request_timeout", None - ), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 + "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 + "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 + "request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 } @@ -87,15 +79,10 @@ def reply(self, query, context=None): reply_content["completion_tokens"], ) ) - if ( - reply_content["completion_tokens"] == 0 - and len(reply_content["content"]) > 0 - ): + if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0: reply = Reply(ReplyType.ERROR, reply_content["content"]) elif reply_content["completion_tokens"] > 0: - self.sessions.session_reply( - reply_content["content"], session_id, reply_content["total_tokens"] - ) + self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"]) reply = Reply(ReplyType.TEXT, reply_content["content"]) else: reply = Reply(ReplyType.ERROR, reply_content["content"]) @@ -126,9 +113,7 @@ def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> di if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token(): raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") # if api_key == None, the default openai.api_key will be used - response = openai.ChatCompletion.create( - api_key=api_key, messages=session.messages, **self.args - ) + response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **self.args) # logger.info("[ChatGPT] 
reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) return { "total_tokens": response["usage"]["total_tokens"], diff --git a/bot/chatgpt/chat_gpt_session.py b/bot/chatgpt/chat_gpt_session.py index 525793ffa..e6c319b36 100644 --- a/bot/chatgpt/chat_gpt_session.py +++ b/bot/chatgpt/chat_gpt_session.py @@ -25,9 +25,7 @@ def discard_exceeding(self, max_tokens, cur_tokens=None): precise = False if cur_tokens is None: raise e - logger.debug( - "Exception when counting tokens precisely for query: {}".format(e) - ) + logger.debug("Exception when counting tokens precisely for query: {}".format(e)) while cur_tokens > max_tokens: if len(self.messages) > 2: self.messages.pop(1) @@ -39,16 +37,10 @@ def discard_exceeding(self, max_tokens, cur_tokens=None): cur_tokens = cur_tokens - max_tokens break elif len(self.messages) == 2 and self.messages[1]["role"] == "user": - logger.warn( - "user message exceed max_tokens. total_tokens={}".format(cur_tokens) - ) + logger.warn("user message exceed max_tokens. 
total_tokens={}".format(cur_tokens)) break else: - logger.debug( - "max_tokens={}, total_tokens={}, len(messages)={}".format( - max_tokens, cur_tokens, len(self.messages) - ) - ) + logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages))) break if precise: cur_tokens = self.calc_tokens() @@ -75,17 +67,13 @@ def num_tokens_from_messages(messages, model): elif model == "gpt-4": return num_tokens_from_messages(messages, model="gpt-4-0314") elif model == "gpt-3.5-turbo-0301": - tokens_per_message = ( - 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n - ) + tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n tokens_per_name = -1 # if there's a name, the role is omitted elif model == "gpt-4-0314": tokens_per_message = 3 tokens_per_name = 1 else: - logger.warn( - f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301." - ) + logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. 
Returning num tokens assuming gpt-3.5-turbo-0301.") return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301") num_tokens = 0 for message in messages: diff --git a/bot/openai/open_ai_bot.py b/bot/openai/open_ai_bot.py index 1cfbf10d8..160562526 100644 --- a/bot/openai/open_ai_bot.py +++ b/bot/openai/open_ai_bot.py @@ -28,23 +28,15 @@ def __init__(self): if proxy: openai.proxy = proxy - self.sessions = SessionManager( - OpenAISession, model=conf().get("model") or "text-davinci-003" - ) + self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003") self.args = { "model": conf().get("model") or "text-davinci-003", # 对话模型的名称 "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 "max_tokens": 1200, # 回复最大的字符数 "top_p": 1, - "frequency_penalty": conf().get( - "frequency_penalty", 0.0 - ), # [-2,2]之间,该值越大则更倾向于产生不同的内容 - "presence_penalty": conf().get( - "presence_penalty", 0.0 - ), # [-2,2]之间,该值越大则更倾向于产生不同的内容 - "request_timeout": conf().get( - "request_timeout", None - ), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 + "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 + "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 + "request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 "stop": ["\n\n\n"], } @@ -71,17 +63,13 @@ def reply(self, query, context=None): result["content"], ) logger.debug( - "[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( - str(session), session_id, reply_content, completion_tokens - ) + "[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens) ) if total_tokens == 0: reply = Reply(ReplyType.ERROR, reply_content) else: - self.sessions.session_reply( - reply_content, session_id, total_tokens - ) + 
self.sessions.session_reply(reply_content, session_id, total_tokens) reply = Reply(ReplyType.TEXT, reply_content) return reply elif context.type == ContextType.IMAGE_CREATE: @@ -96,9 +84,7 @@ def reply(self, query, context=None): def reply_text(self, session: OpenAISession, retry_count=0): try: response = openai.Completion.create(prompt=str(session), **self.args) - res_content = ( - response.choices[0]["text"].strip().replace("<|endoftext|>", "") - ) + res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "") total_tokens = response["usage"]["total_tokens"] completion_tokens = response["usage"]["completion_tokens"] logger.info("[OPEN_AI] reply={}".format(res_content)) diff --git a/bot/openai/open_ai_image.py b/bot/openai/open_ai_image.py index 5dbbd23ed..b188557f3 100644 --- a/bot/openai/open_ai_image.py +++ b/bot/openai/open_ai_image.py @@ -23,9 +23,7 @@ def create_img(self, query, retry_count=0): response = openai.Image.create( prompt=query, # 图片描述 n=1, # 每次生成图片的数量 - size=conf().get( - "image_create_size", "256x256" - ), # 图片大小,可选有 256x256, 512x512, 1024x1024 + size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024 ) image_url = response["data"][0]["url"] logger.info("[OPEN_AI] image_url={}".format(image_url)) @@ -34,11 +32,7 @@ def create_img(self, query, retry_count=0): logger.warn(e) if retry_count < 1: time.sleep(5) - logger.warn( - "[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format( - retry_count + 1 - ) - ) + logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1)) return self.create_img(query, retry_count + 1) else: return False, "提问太快啦,请休息一下再问我吧" diff --git a/bot/openai/open_ai_session.py b/bot/openai/open_ai_session.py index 78cf43900..8f6aa4f5b 100644 --- a/bot/openai/open_ai_session.py +++ b/bot/openai/open_ai_session.py @@ -36,9 +36,7 @@ def discard_exceeding(self, max_tokens, cur_tokens=None): precise = False if cur_tokens is None: raise e - logger.debug( - "Exception 
when counting tokens precisely for query: {}".format(e) - ) + logger.debug("Exception when counting tokens precisely for query: {}".format(e)) while cur_tokens > max_tokens: if len(self.messages) > 1: self.messages.pop(0) @@ -50,18 +48,10 @@ def discard_exceeding(self, max_tokens, cur_tokens=None): cur_tokens = len(str(self)) break elif len(self.messages) == 1 and self.messages[0]["role"] == "user": - logger.warn( - "user question exceed max_tokens. total_tokens={}".format( - cur_tokens - ) - ) + logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens)) break else: - logger.debug( - "max_tokens={}, total_tokens={}, len(conversation)={}".format( - max_tokens, cur_tokens, len(self.messages) - ) - ) + logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages))) break if precise: cur_tokens = self.calc_tokens() diff --git a/bot/session_manager.py b/bot/session_manager.py index 1aff647b5..8d70886e0 100644 --- a/bot/session_manager.py +++ b/bot/session_manager.py @@ -55,9 +55,7 @@ def build_session(self, session_id, system_prompt=None): return self.sessioncls(session_id, system_prompt, **self.session_args) if session_id not in self.sessions: - self.sessions[session_id] = self.sessioncls( - session_id, system_prompt, **self.session_args - ) + self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args) elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session self.sessions[session_id].set_system_prompt(system_prompt) session = self.sessions[session_id] @@ -71,9 +69,7 @@ def session_query(self, query, session_id): total_tokens = session.discard_exceeding(max_tokens, None) logger.debug("prompt tokens used={}".format(total_tokens)) except Exception as e: - logger.debug( - "Exception when counting tokens precisely for prompt: {}".format(str(e)) - ) + logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e))) return session def 
session_reply(self, reply, session_id, total_tokens=None): @@ -82,17 +78,9 @@ def session_reply(self, reply, session_id, total_tokens=None): try: max_tokens = conf().get("conversation_max_tokens", 1000) tokens_cnt = session.discard_exceeding(max_tokens, total_tokens) - logger.debug( - "raw total_tokens={}, savesession tokens={}".format( - total_tokens, tokens_cnt - ) - ) + logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt)) except Exception as e: - logger.debug( - "Exception when counting tokens precisely for session: {}".format( - str(e) - ) - ) + logger.debug("Exception when counting tokens precisely for session: {}".format(str(e))) return session def clear_session(self, session_id): diff --git a/bridge/bridge.py b/bridge/bridge.py index dcf6e7e07..78fe23458 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -1,11 +1,12 @@ -from bot import bot_factory +from bot.bot_factory import create_bot from bridge.context import Context from bridge.reply import Reply from common import const from common.log import logger from common.singleton import singleton from config import conf -from voice import voice_factory +from translate.factory import create_translator +from voice.factory import create_voice @singleton @@ -15,6 +16,7 @@ def __init__(self): "chat": const.CHATGPT, "voice_to_text": conf().get("voice_to_text", "openai"), "text_to_voice": conf().get("text_to_voice", "google"), + "translate": conf().get("translate", "baidu"), } model_type = conf().get("model") if model_type in ["text-davinci-003"]: @@ -27,11 +29,13 @@ def get_bot(self, typename): if self.bots.get(typename) is None: logger.info("create bot {} for {}".format(self.btype[typename], typename)) if typename == "text_to_voice": - self.bots[typename] = voice_factory.create_voice(self.btype[typename]) + self.bots[typename] = create_voice(self.btype[typename]) elif typename == "voice_to_text": - self.bots[typename] = voice_factory.create_voice(self.btype[typename]) + 
self.bots[typename] = create_voice(self.btype[typename]) elif typename == "chat": - self.bots[typename] = bot_factory.create_bot(self.btype[typename]) + self.bots[typename] = create_bot(self.btype[typename]) + elif typename == "translate": + self.bots[typename] = create_translator(self.btype[typename]) return self.bots[typename] def get_bot_type(self, typename): @@ -45,3 +49,6 @@ def fetch_voice_to_text(self, voiceFile) -> Reply: def fetch_text_to_voice(self, text) -> Reply: return self.get_bot("text_to_voice").textToVoice(text) + + def fetch_translate(self, text, from_lang="", to_lang="en") -> Reply: + return self.get_bot("translate").translate(text, from_lang, to_lang) diff --git a/bridge/context.py b/bridge/context.py index c1eb10c1f..ab004c0f6 100644 --- a/bridge/context.py +++ b/bridge/context.py @@ -60,6 +60,4 @@ def __delitem__(self, key): del self.kwargs[key] def __str__(self): - return "Context(type={}, content={}, kwargs={})".format( - self.type, self.content, self.kwargs - ) + return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs) diff --git a/channel/chat_channel.py b/channel/chat_channel.py index 89dfe145a..d57cdb569 100644 --- a/channel/chat_channel.py +++ b/channel/chat_channel.py @@ -53,9 +53,7 @@ def _compose_context(self, ctype: ContextType, content, **kwargs): group_id = cmsg.other_user_id group_name_white_list = config.get("group_name_white_list", []) - group_name_keyword_white_list = config.get( - "group_name_keyword_white_list", [] - ) + group_name_keyword_white_list = config.get("group_name_keyword_white_list", []) if any( [ group_name in group_name_white_list, @@ -63,9 +61,7 @@ def _compose_context(self, ctype: ContextType, content, **kwargs): check_contain(group_name, group_name_keyword_white_list), ] ): - group_chat_in_one_session = conf().get( - "group_chat_in_one_session", [] - ) + group_chat_in_one_session = conf().get("group_chat_in_one_session", []) session_id = cmsg.actual_user_id if any( [ @@ 
-81,17 +77,11 @@ def _compose_context(self, ctype: ContextType, content, **kwargs): else: context["session_id"] = cmsg.other_user_id context["receiver"] = cmsg.other_user_id - e_context = PluginManager().emit_event( - EventContext( - Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context} - ) - ) + e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context})) context = e_context["context"] if e_context.is_pass() or context is None: return context - if cmsg.from_user_id == self.user_id and not config.get( - "trigger_by_self", True - ): + if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True): logger.debug("[WX]self message skipped") return None @@ -114,28 +104,22 @@ def _compose_context(self, ctype: ContextType, content, **kwargs): logger.info("[WX]receive group at") if not conf().get("group_at_off", False): flag = True - pattern = f"@{self.name}(\u2005|\u0020)" + pattern = f"@{re.escape(self.name)}(\u2005|\u0020)" content = re.sub(pattern, r"", content) if not flag: if context["origin_ctype"] == ContextType.VOICE: - logger.info( - "[WX]receive group voice, but checkprefix didn't match" - ) + logger.info("[WX]receive group voice, but checkprefix didn't match") return None else: # 单聊 - match_prefix = check_prefix( - content, conf().get("single_chat_prefix", [""]) - ) + match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""])) if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容 content = content.replace(match_prefix, "", 1).strip() - elif ( - context["origin_ctype"] == ContextType.VOICE - ): # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件 + elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件 pass else: return None - + content = content.strip() img_match_prefix = check_prefix(content, conf().get("image_create_prefix")) if img_match_prefix: content = content.replace(img_match_prefix, "", 1) @@ -143,18 +127,10 @@ def _compose_context(self, ctype: 
ContextType, content, **kwargs): else: context.type = ContextType.TEXT context.content = content.strip() - if ( - "desire_rtype" not in context - and conf().get("always_reply_voice") - and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE - ): + if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: context["desire_rtype"] = ReplyType.VOICE elif context.type == ContextType.VOICE: - if ( - "desire_rtype" not in context - and conf().get("voice_reply_voice") - and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE - ): + if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: context["desire_rtype"] = ReplyType.VOICE return context @@ -182,15 +158,8 @@ def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply: ) reply = e_context["reply"] if not e_context.is_pass(): - logger.debug( - "[WX] ready to handle context: type={}, content={}".format( - context.type, context.content - ) - ) - if ( - context.type == ContextType.TEXT - or context.type == ContextType.IMAGE_CREATE - ): # 文字和图片消息 + logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content)) + if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息 reply = super().build_reply_content(context.content, context) elif context.type == ContextType.VOICE: # 语音消息 cmsg = context["msg"] @@ -214,9 +183,7 @@ def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply: # logger.warning("[WX]delete temp file error: " + str(e)) if reply.type == ReplyType.TEXT: - new_context = self._compose_context( - ContextType.TEXT, reply.content, **context.kwargs - ) + new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs) if new_context: reply = self._generate_reply(new_context) else: @@ -246,48 +213,24 @@ def _decorate_reply(self, context: Context, reply: Reply) 
-> Reply: if reply.type == ReplyType.TEXT: reply_text = reply.content - if ( - desire_rtype == ReplyType.VOICE - and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE - ): + if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: reply = super().build_text_to_voice(reply.content) return self._decorate_reply(context, reply) if context.get("isgroup", False): - reply_text = ( - "@" - + context["msg"].actual_user_nickname - + " " - + reply_text.strip() - ) - reply_text = ( - conf().get("group_chat_reply_prefix", "") + reply_text - ) + reply_text = "@" + context["msg"].actual_user_nickname + " " + reply_text.strip() + reply_text = conf().get("group_chat_reply_prefix", "") + reply_text else: - reply_text = ( - conf().get("single_chat_reply_prefix", "") + reply_text - ) + reply_text = conf().get("single_chat_reply_prefix", "") + reply_text reply.content = reply_text elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: reply.content = "[" + str(reply.type) + "]\n" + reply.content - elif ( - reply.type == ReplyType.IMAGE_URL - or reply.type == ReplyType.VOICE - or reply.type == ReplyType.IMAGE - ): + elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE: pass else: logger.error("[WX] unknown reply type: {}".format(reply.type)) return - if ( - desire_rtype - and desire_rtype != reply.type - and reply.type not in [ReplyType.ERROR, ReplyType.INFO] - ): - logger.warning( - "[WX] desire_rtype: {}, but reply type: {}".format( - context.get("desire_rtype"), reply.type - ) - ) + if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]: + logger.warning("[WX] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type)) return reply def _send_reply(self, context: Context, reply: Reply): @@ -300,9 +243,7 @@ def _send_reply(self, context: Context, reply: Reply): ) reply = e_context["reply"] if not e_context.is_pass() 
and reply and reply.type: - logger.debug( - "[WX] ready to send reply: {}, context: {}".format(reply, context) - ) + logger.debug("[WX] ready to send reply: {}, context: {}".format(reply, context)) self._send(reply, context) def _send(self, reply: Reply, context: Context, retry_cnt=0): @@ -328,9 +269,7 @@ def func(worker: Future): try: worker_exception = worker.exception() if worker_exception: - self._fail_callback( - session_id, exception=worker_exception, **kwargs - ) + self._fail_callback(session_id, exception=worker_exception, **kwargs) else: self._success_callback(session_id, **kwargs) except CancelledError as e: @@ -366,24 +305,14 @@ def consume(self): if not context_queue.empty(): context = context_queue.get() logger.debug("[WX] consume context: {}".format(context)) - future: Future = self.handler_pool.submit( - self._handle, context - ) - future.add_done_callback( - self._thread_pool_callback(session_id, context=context) - ) + future: Future = self.handler_pool.submit(self._handle, context) + future.add_done_callback(self._thread_pool_callback(session_id, context=context)) if session_id not in self.futures: self.futures[session_id] = [] self.futures[session_id].append(future) - elif ( - semaphore._initial_value == semaphore._value + 1 - ): # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕 - self.futures[session_id] = [ - t for t in self.futures[session_id] if not t.done() - ] - assert ( - len(self.futures[session_id]) == 0 - ), "thread pool error" + elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕 + self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()] + assert len(self.futures[session_id]) == 0, "thread pool error" del self.sessions[session_id] else: semaphore.release() @@ -397,9 +326,7 @@ def cancel_session(self, session_id): future.cancel() cnt = self.sessions[session_id][0].qsize() if cnt > 0: - logger.info( - "Cancel {} messages in session {}".format(cnt, session_id) - ) + logger.info("Cancel {} messages in 
session {}".format(cnt, session_id)) self.sessions[session_id][0] = Dequeue() def cancel_all_session(self): @@ -409,9 +336,7 @@ def cancel_all_session(self): future.cancel() cnt = self.sessions[session_id][0].qsize() if cnt > 0: - logger.info( - "Cancel {} messages in session {}".format(cnt, session_id) - ) + logger.info("Cancel {} messages in session {}".format(cnt, session_id)) self.sessions[session_id][0] = Dequeue() diff --git a/channel/terminal/terminal_channel.py b/channel/terminal/terminal_channel.py index e2060789c..9a413dcff 100644 --- a/channel/terminal/terminal_channel.py +++ b/channel/terminal/terminal_channel.py @@ -77,9 +77,7 @@ def startup(self): if check_prefix(prompt, trigger_prefixs) is None: prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀 - context = self._compose_context( - ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt) - ) + context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt)) if context: self.produce(context) else: diff --git a/channel/wechat/wechat_channel.py b/channel/wechat/wechat_channel.py index cf200b17e..16d788c44 100644 --- a/channel/wechat/wechat_channel.py +++ b/channel/wechat/wechat_channel.py @@ -56,10 +56,7 @@ def wrapper(self, cmsg: ChatMessage): return self.receivedMsgs[msgId] = cmsg create_time = cmsg.create_time # 消息时间戳 - if ( - conf().get("hot_reload") == True - and int(create_time) < int(time.time()) - 60 - ): # 跳过1分钟前的历史消息 + if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息 logger.debug("[WX]history message {} skipped".format(msgId)) return return func(self, cmsg) @@ -88,15 +85,9 @@ def qrCallback(uuid, status, qrcode): url = f"https://login.weixin.qq.com/l/{uuid}" qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url) - qr_api2 = ( - "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format( - url - ) - ) + qr_api2 = 
"https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url) qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url) - qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format( - url - ) + qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url) print("You can also scan QRCode in any website below:") print(qr_api3) print(qr_api4) @@ -134,18 +125,12 @@ def startup(self): logger.error("Hot reload failed, try to login without hot reload") itchat.logout() os.remove(status_path) - itchat.auto_login( - enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback - ) + itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback) else: raise e self.user_id = itchat.instance.storageClass.userName self.name = itchat.instance.storageClass.nickName - logger.info( - "Wechat login success, user_id: {}, nickname: {}".format( - self.user_id, self.name - ) - ) + logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name)) # start message listener itchat.run() @@ -173,16 +158,10 @@ def handle_single(self, cmsg: ChatMessage): elif cmsg.ctype == ContextType.PATPAT: logger.debug("[WX]receive patpat msg: {}".format(cmsg.content)) elif cmsg.ctype == ContextType.TEXT: - logger.debug( - "[WX]receive text msg: {}, cmsg={}".format( - json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg - ) - ) + logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) else: logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg)) - context = self._compose_context( - cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg - ) + context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg) if context: self.produce(context) @@ -202,9 +181,7 @@ def handle_group(self, cmsg: ChatMessage): pass else: logger.debug("[WX]receive group msg: {}".format(cmsg.content)) - context = self._compose_context( - cmsg.ctype, 
cmsg.content, isgroup=True, msg=cmsg - ) + context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg) if context: self.produce(context) diff --git a/channel/wechat/wechat_message.py b/channel/wechat/wechat_message.py index 18884259c..63c225471 100644 --- a/channel/wechat/wechat_message.py +++ b/channel/wechat/wechat_message.py @@ -27,37 +27,23 @@ def __init__(self, itchat_msg, is_group=False): self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径 self._prepare_fn = lambda: itchat_msg.download(self.content) elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000: - if is_group and ( - "加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"] - ): + if is_group and ("加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]): self.ctype = ContextType.JOIN_GROUP self.content = itchat_msg["Content"] # 这里只能得到nickname, actual_user_id还是机器人的id if "加入了群聊" in itchat_msg["Content"]: - self.actual_user_nickname = re.findall( - r"\"(.*?)\"", itchat_msg["Content"] - )[-1] + self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1] elif "加入群聊" in itchat_msg["Content"]: - self.actual_user_nickname = re.findall( - r"\"(.*?)\"", itchat_msg["Content"] - )[0] + self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0] elif "拍了拍我" in itchat_msg["Content"]: self.ctype = ContextType.PATPAT self.content = itchat_msg["Content"] if is_group: - self.actual_user_nickname = re.findall( - r"\"(.*?)\"", itchat_msg["Content"] - )[0] + self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0] else: - raise NotImplementedError( - "Unsupported note message: " + itchat_msg["Content"] - ) + raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"]) else: - raise NotImplementedError( - "Unsupported message type: Type:{} MsgType:{}".format( - itchat_msg["Type"], itchat_msg["MsgType"] - ) - ) + raise NotImplementedError("Unsupported message type: 
Type:{} MsgType:{}".format(itchat_msg["Type"], itchat_msg["MsgType"])) self.from_user_id = itchat_msg["FromUserName"] self.to_user_id = itchat_msg["ToUserName"] diff --git a/channel/wechat/wechaty_channel.py b/channel/wechat/wechaty_channel.py index 7383a206c..051a9cf10 100644 --- a/channel/wechat/wechaty_channel.py +++ b/channel/wechat/wechaty_channel.py @@ -60,13 +60,9 @@ def send(self, reply: Reply, context: Context): receiver_id = context["receiver"] loop = asyncio.get_event_loop() if context["isgroup"]: - receiver = asyncio.run_coroutine_threadsafe( - self.bot.Room.find(receiver_id), loop - ).result() + receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id), loop).result() else: - receiver = asyncio.run_coroutine_threadsafe( - self.bot.Contact.find(receiver_id), loop - ).result() + receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id), loop).result() msg = None if reply.type == ReplyType.TEXT: msg = reply.content @@ -83,9 +79,7 @@ def send(self, reply: Reply, context: Context): voiceLength = int(any_to_sil(file_path, sil_file)) if voiceLength >= 60000: voiceLength = 60000 - logger.info( - "[WX] voice too long, length={}, set to 60s".format(voiceLength) - ) + logger.info("[WX] voice too long, length={}, set to 60s".format(voiceLength)) # 发送语音 t = int(time.time()) msg = FileBox.from_file(sil_file, name=str(t) + ".sil") @@ -98,9 +92,7 @@ def send(self, reply: Reply, context: Context): os.remove(sil_file) except Exception as e: pass - logger.info( - "[WX] sendVoice={}, receiver={}".format(reply.content, receiver) - ) + logger.info("[WX] sendVoice={}, receiver={}".format(reply.content, receiver)) elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 img_url = reply.content t = int(time.time()) @@ -111,9 +103,7 @@ def send(self, reply: Reply, context: Context): image_storage = reply.content image_storage.seek(0) t = int(time.time()) - msg = FileBox.from_base64( - base64.b64encode(image_storage.read()), str(t) + ".png" - ) + 
msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + ".png") asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() logger.info("[WX] sendImage, receiver={}".format(receiver)) diff --git a/channel/wechat/wechaty_message.py b/channel/wechat/wechaty_message.py index f7d27faf8..cdb41ddf2 100644 --- a/channel/wechat/wechaty_message.py +++ b/channel/wechat/wechaty_message.py @@ -45,16 +45,12 @@ async def __init__(self, wechaty_msg: Message): def func(): loop = asyncio.get_event_loop() - asyncio.run_coroutine_threadsafe( - voice_file.to_file(self.content), loop - ).result() + asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content), loop).result() self._prepare_fn = func else: - raise NotImplementedError( - "Unsupported message type: {}".format(wechaty_msg.type()) - ) + raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type())) from_contact = wechaty_msg.talker() # 获取消息的发送者 self.from_user_id = from_contact.contact_id @@ -73,9 +69,7 @@ def func(): self.to_user_id = to_contact.contact_id self.to_user_nickname = to_contact.name - if ( - self.is_group or wechaty_msg.is_self() - ): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。 + if self.is_group or wechaty_msg.is_self(): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。 self.other_user_id = self.to_user_id self.other_user_nickname = self.to_user_nickname else: @@ -86,7 +80,7 @@ def func(): self.is_at = await wechaty_msg.mention_self() if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容 name = wechaty_msg.wechaty.user_self().name - pattern = f"@{name}(\u2005|\u0020)" + pattern = f"@{re.escape(name)}(\u2005|\u0020)" if re.search(pattern, self.content): logger.debug(f"wechaty message {self.msg_id} include at") self.is_at = True diff --git a/channel/wechatmp/README.md b/channel/wechatmp/README.md index 219d27675..98ff769c9 100644 --- a/channel/wechatmp/README.md +++ b/channel/wechatmp/README.md @@ -1,21 +1,24 @@ # 微信公众号channel 
鉴于个人微信号在服务器上通过itchat登录有封号风险,这里新增了微信公众号channel,提供无风险的服务。 -目前支持订阅号和服务号两种类型的公众号。个人主体的微信订阅号由于无法通过微信认证,接口存在限制,目前仅支持最基本的文本交互和语音输入。通过微信认证的订阅号或者服务号可以回复图片和语音。 + +目前支持订阅号和服务号两种类型的公众号,它们都支持文本交互,语音和图片输入。其中个人主体的微信订阅号由于无法通过微信认证,存在回复时间限制,每天的图片和声音回复次数也有限制。 + ## 使用方法(订阅号,服务号类似) 在开始部署前,你需要一个拥有公网IP的服务器,以提供微信服务器和我们自己服务器的连接。或者你需要进行内网穿透,否则微信服务器无法将消息发送给我们的服务器。 -此外,需要在我们的服务器上安装python的web框架web.py。 +此外,需要在我们的服务器上安装python的web框架web.py和wechatpy。 以ubuntu为例(在ubuntu 22.04上测试): ``` pip3 install web.py +pip3 install wechatpy ``` 然后在[微信公众平台](https://mp.weixin.qq.com)注册一个自己的公众号,类型选择订阅号,主体为个人即可。 -然后根据[接入指南](https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/Access_Overview.html)的说明,在[微信公众平台](https://mp.weixin.qq.com)的“设置与开发”-“基本配置”-“服务器配置”中填写服务器地址`URL`和令牌`Token`。这里的`URL`是`example.com/wx`的形式,不可以使用IP,`Token`是你自己编的一个特定的令牌。消息加解密方式目前选择的是明文模式。 +然后根据[接入指南](https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/Access_Overview.html)的说明,在[微信公众平台](https://mp.weixin.qq.com)的“设置与开发”-“基本配置”-“服务器配置”中填写服务器地址`URL`和令牌`Token`。这里的`URL`是`example.com/wx`的形式,不可以使用IP,`Token`是你自己编的一个特定的令牌。消息加解密方式如果选择了需要加密的模式,需要在配置中填写`wechatmp_aes_key`。 相关的服务器验证代码已经写好,你不需要再添加任何代码。你只需要在本项目根目录的`config.json`中添加 ``` @@ -24,6 +27,7 @@ pip3 install web.py "wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443 "wechatmp_app_id": "xxxx", # 微信公众平台的appID "wechatmp_app_secret": "xxxx", # 微信公众平台的appsecret +"wechatmp_aes_key": "", # 微信公众平台的EncodingAESKey,加密模式需要 "single_chat_prefix": [""], # 推荐设置,任意对话都可以触发回复,不添加前缀 "single_chat_reply_prefix": "", # 推荐设置,回复不设置前缀 "plugin_trigger_prefix": "&", # 推荐设置,在手机微信客户端中,$%^等符号与中文连在一起时会自动显示一段较大的间隔,用户体验不好。请不要使用管理员指令前缀"#",这会造成未知问题。 @@ -40,12 +44,14 @@ sudo iptables-save > /etc/iptables/rules.v4 程序启动并监听端口后,在刚才的“服务器配置”中点击`提交`即可验证你的服务器。 随后在[微信公众平台](https://mp.weixin.qq.com)启用服务器,关闭手动填写规则的自动回复,即可实现ChatGPT的自动回复。 -如果在启用后如果遇到如下报错: +之后需要在公众号开发信息下将本机IP加入到IP白名单。 + +不然在启用后,发送语音、图片等消息可能会遇到如下报错: ``` 'errcode': 40164, 'errmsg': 'invalid ip xx.xx.xx.xx not in whitelist rid ``` -需要在公众号开发信息下将IP加入到IP白名单。 + ## 
个人微信公众号的限制 由于人微信公众号不能通过微信认证,所以没有客服接口,因此公众号无法主动发出消息,只能被动回复。而微信官方对被动回复有5秒的时间限制,最多重试2次,因此最多只有15秒的自动回复时间窗口。因此如果问题比较复杂或者我们的服务器比较忙,ChatGPT的回答就没办法及时回复给用户。为了解决这个问题,这里做了回答缓存,它需要你在回复超时后,再次主动发送任意文字(例如1)来尝试拿到回答缓存。为了优化使用体验,目前设置了两分钟(120秒)的timeout,用户在至多两分钟后即可得到查询到回复或者错误原因。 @@ -91,7 +97,7 @@ python3 -m pip install pyttsx3 ## TODO - [x] 语音输入 - - [ ] 图片输入 + - [x] 图片输入 - [x] 使用临时素材接口提供认证公众号的图片和语音回复 - [x] 使用永久素材接口提供未认证公众号的图片和语音回复 - [ ] 高并发支持 diff --git a/channel/wechatmp/active_reply.py b/channel/wechatmp/active_reply.py index d8a8ddee1..37356959b 100644 --- a/channel/wechatmp/active_reply.py +++ b/channel/wechatmp/active_reply.py @@ -1,13 +1,16 @@ import time import web +from wechatpy import parse_message +from wechatpy.replies import create_reply + -from channel.wechatmp.wechatmp_message import parse_xml -from channel.wechatmp.passive_reply_message import TextMsg from bridge.context import * -from bridge.reply import ReplyType +from bridge.reply import * from channel.wechatmp.common import * from channel.wechatmp.wechatmp_channel import WechatMPChannel +from channel.wechatmp.wechatmp_message import WeChatMPMessage + from common.log import logger from config import conf @@ -19,18 +22,25 @@ def GET(self): def POST(self): # Make sure to return the instance that first created, @singleton will do that. 
- channel = WechatMPChannel() try: - webData = web.data() - # logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8")) - wechatmp_msg = parse_xml(webData) - if ( - wechatmp_msg.msg_type == "text" - or wechatmp_msg.msg_type == "voice" - # or wechatmp_msg.msg_type == "image" - ): + args = web.input() + verify_server(args) + channel = WechatMPChannel() + message = web.data() + encrypt_func = lambda x: x + if args.get("encrypt_type") == "aes": + logger.debug("[wechatmp] Receive encrypted post data:\n" + message.decode("utf-8")) + if not channel.crypto: + raise Exception("Crypto not initialized, Please set wechatmp_aes_key in config.json") + message = channel.crypto.decrypt_message(message, args.msg_signature, args.timestamp, args.nonce) + encrypt_func = lambda x: channel.crypto.encrypt_message(x, args.nonce, args.timestamp) + else: + logger.debug("[wechatmp] Receive post data:\n" + message.decode("utf-8")) + msg = parse_message(message) + if msg.type in ["text", "voice", "image"]: + wechatmp_msg = WeChatMPMessage(msg, client=channel.client) from_user = wechatmp_msg.from_user_id - message = wechatmp_msg.content + content = wechatmp_msg.content message_id = wechatmp_msg.msg_id logger.info( @@ -39,38 +49,29 @@ def POST(self): web.ctx.env.get("REMOTE_PORT"), from_user, message_id, - message, + content, ) ) - if (wechatmp_msg.msg_type == "voice" and conf().get("voice_reply_voice") == True): - rtype = ReplyType.VOICE + if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False): + context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg) else: - rtype = None - context = channel._compose_context( - ContextType.TEXT, message, isgroup=False, desire_rtype=rtype, msg=wechatmp_msg - ) + context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg) if context: # set private openai_api_key # if from_user is not 
changed in itchat, this can be placed at chat_channel user_data = conf().get_user_data(from_user) - context["openai_api_key"] = user_data.get( - "openai_api_key" - ) # None or user openai_api_key + context["openai_api_key"] = user_data.get("openai_api_key") # None or user openai_api_key channel.produce(context) # The reply will be sent by channel.send() in another thread return "success" - - elif wechatmp_msg.msg_type == "event": - logger.info( - "[wechatmp] Event {} from {}".format( - wechatmp_msg.Event, wechatmp_msg.from_user_id - ) - ) - content = subscribe_msg() - replyMsg = TextMsg( - wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content - ) - return replyMsg.send() + elif msg.type == "event": + logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source)) + if msg.event in ["subscribe", "subscribe_scan"]: + reply_text = subscribe_msg() + replyPost = create_reply(reply_text, msg) + return encrypt_func(replyPost.render()) + else: + return "success" else: logger.info("暂且不处理") return "success" diff --git a/channel/wechatmp/common.py b/channel/wechatmp/common.py index 5efccfce1..b6f206c5a 100644 --- a/channel/wechatmp/common.py +++ b/channel/wechatmp/common.py @@ -1,6 +1,10 @@ -import hashlib import textwrap +import web +from wechatpy.crypto import WeChatCrypto +from wechatpy.exceptions import InvalidSignatureException +from wechatpy.utils import check_signature + from config import conf MAX_UTF8_LEN = 2048 @@ -12,38 +16,28 @@ class WeChatAPIException(Exception): def verify_server(data): try: - if len(data) == 0: - return "None" signature = data.signature timestamp = data.timestamp nonce = data.nonce - echostr = data.echostr + echostr = data.get("echostr", None) token = conf().get("wechatmp_token") # 请按照公众平台官网\基本配置中信息填写 - - data_list = [token, timestamp, nonce] - data_list.sort() - sha1 = hashlib.sha1() - # map(sha1.update, data_list) #python2 - sha1.update("".join(data_list).encode("utf-8")) - hashcode = sha1.hexdigest() - print("handle/GET func: 
hashcode, signature: ", hashcode, signature) - if hashcode == signature: - return echostr - else: - return "" - except Exception as Argument: - return Argument + check_signature(token, signature, timestamp, nonce) + return echostr + except InvalidSignatureException: + raise web.Forbidden("Invalid signature") + except Exception as e: + raise web.Forbidden(str(e)) def subscribe_msg(): - trigger_prefix = conf().get("single_chat_prefix", [""]) + trigger_prefix = conf().get("single_chat_prefix", [""])[0] msg = textwrap.dedent( f"""\ 感谢您的关注! 这里是ChatGPT,可以自由对话。 资源有限,回复较慢,请勿着急。 支持语音对话。 - 暂时不支持图片输入。 + 支持图片输入。 支持图片输出,画字开头的消息将按要求创作图片。 支持tool、角色扮演和文字冒险等丰富的插件。 输入'{trigger_prefix}#帮助' 查看详细指令。""" @@ -59,7 +53,7 @@ def split_string_by_utf8_length(string, max_length, max_split=0): if max_split > 0 and len(result) >= max_split: result.append(encoded[start:].decode("utf-8")) break - end = start + max_length + end = min(start + max_length, len(encoded)) # 如果当前字节不是 UTF-8 编码的开始字节,则向前查找直到找到开始字节为止 while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000: end -= 1 diff --git a/channel/wechatmp/passive_reply.py b/channel/wechatmp/passive_reply.py index eca94ba36..6f3fb0a72 100644 --- a/channel/wechatmp/passive_reply.py +++ b/channel/wechatmp/passive_reply.py @@ -1,14 +1,16 @@ -import time + import asyncio +import time import web +from wechatpy import parse_message +from wechatpy.replies import ImageReply, VoiceReply, create_reply -from channel.wechatmp.wechatmp_message import parse_xml -from channel.wechatmp.passive_reply_message import TextMsg, VoiceMsg, ImageMsg from bridge.context import * -from bridge.reply import ReplyType +from bridge.reply import * from channel.wechatmp.common import * from channel.wechatmp.wechatmp_channel import WechatMPChannel +from channel.wechatmp.wechatmp_message import WeChatMPMessage from common.log import logger from config import conf @@ -20,39 +22,44 @@ def GET(self): def POST(self): try: + args = web.input() + verify_server(args) 
request_time = time.time() channel = WechatMPChannel() - webData = web.data() - logger.debug("[wechatmp] Receive post data:\n" + webData.decode("utf-8")) - wechatmp_msg = parse_xml(webData) - if wechatmp_msg.msg_type == "text" or wechatmp_msg.msg_type == "voice": + message = web.data() + encrypt_func = lambda x: x + if args.get("encrypt_type") == "aes": + logger.debug("[wechatmp] Receive encrypted post data:\n" + message.decode("utf-8")) + if not channel.crypto: + raise Exception("Crypto not initialized, Please set wechatmp_aes_key in config.json") + message = channel.crypto.decrypt_message(message, args.msg_signature, args.timestamp, args.nonce) + encrypt_func = lambda x: channel.crypto.encrypt_message(x, args.nonce, args.timestamp) + else: + logger.debug("[wechatmp] Receive post data:\n" + message.decode("utf-8")) + msg = parse_message(message) + if msg.type in ["text", "voice", "image"]: + wechatmp_msg = WeChatMPMessage(msg, client=channel.client) from_user = wechatmp_msg.from_user_id - to_user = wechatmp_msg.to_user_id - message = wechatmp_msg.content + content = wechatmp_msg.content message_id = wechatmp_msg.msg_id supported = True - if "【收到不支持的消息类型,暂无法显示】" in message: + if "【收到不支持的消息类型,暂无法显示】" in content: supported = False # not supported, used to refresh # New request if ( from_user not in channel.cache_dict and from_user not in channel.running - or message.startswith("#") - and message_id not in channel.request_cnt # insert the godcmd + or content.startswith("#") + and message_id not in channel.request_cnt # insert the godcmd ): # The first query begin - if (wechatmp_msg.msg_type == "voice" and conf().get("voice_reply_voice") == True): - rtype = ReplyType.VOICE + if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False): + context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg) else: - rtype = None - context = channel._compose_context( - 
ContextType.TEXT, message, isgroup=False, desire_rtype=rtype, msg=wechatmp_msg - ) - logger.debug( - "[wechatmp] context: {} {}".format(context, wechatmp_msg) - ) + context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg) + logger.debug("[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported)) if supported and context: # set private openai_api_key @@ -62,43 +69,38 @@ def POST(self): channel.running.add(from_user) channel.produce(context) else: - trigger_prefix = conf().get("single_chat_prefix", [""]) + trigger_prefix = conf().get("single_chat_prefix", [""])[0] if trigger_prefix or not supported: if trigger_prefix: - content = textwrap.dedent( + reply_text = textwrap.dedent( f"""\ 请输入'{trigger_prefix}'接你想说的话跟我说话。 例如: {trigger_prefix}你好,很高兴见到你。""" ) else: - content = textwrap.dedent( + reply_text = textwrap.dedent( """\ 你好,很高兴见到你。 请跟我说话吧。""" ) else: logger.error(f"[wechatmp] unknown error") - content = textwrap.dedent( + reply_text = textwrap.dedent( """\ 未知错误,请稍后再试""" ) - replyPost = TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content).send() - return replyPost + replyPost = create_reply(reply_text, msg) + return encrypt_func(replyPost.render()) # Wechat official server will request 3 times (5 seconds each), with the same message_id. # Because the interval is 5 seconds, here assumed that do not have multithreading problems. 
request_cnt = channel.request_cnt.get(message_id, 0) + 1 channel.request_cnt[message_id] = request_cnt logger.info( - "[wechatmp] Request {} from {} {}\n{}\n{}:{}".format( - request_cnt, - from_user, - message_id, - message, - web.ctx.env.get("REMOTE_ADDR"), - web.ctx.env.get("REMOTE_PORT"), + "[wechatmp] Request {} from {} {} {}:{}\n{}".format( + request_cnt, from_user, message_id, web.ctx.env.get("REMOTE_ADDR"), web.ctx.env.get("REMOTE_PORT"), content ) ) @@ -118,76 +120,91 @@ def POST(self): time.sleep(2) # and do nothing, waiting for the next request return "success" - else: # request_cnt == 3: + else: # request_cnt == 3: # return timeout message reply_text = "【正在思考中,回复任意文字尝试获取回复】" - replyPost = TextMsg(from_user, to_user, reply_text).send() - return replyPost + replyPost = create_reply(reply_text, msg) + return encrypt_func(replyPost.render()) # reply is ready channel.request_cnt.pop(message_id) # no return because of bandwords or other reasons - if ( - from_user not in channel.cache_dict - and from_user not in channel.running - ): + if from_user not in channel.cache_dict and from_user not in channel.running: return "success" # Only one request can access to the cached data try: - (reply_type, content) = channel.cache_dict.pop(from_user) + (reply_type, reply_content) = channel.cache_dict.pop(from_user) except KeyError: return "success" - if (reply_type == "text"): - if len(content.encode("utf8")) <= MAX_UTF8_LEN: - reply_text = content + if reply_type == "text": + if len(reply_content.encode("utf8")) <= MAX_UTF8_LEN: + reply_text = reply_content else: continue_text = "\n【未完待续,回复任意文字以继续】" splits = split_string_by_utf8_length( - content, + reply_content, MAX_UTF8_LEN - len(continue_text.encode("utf-8")), max_split=1, ) reply_text = splits[0] + continue_text channel.cache_dict[from_user] = ("text", splits[1]) - + logger.info( "[wechatmp] Request {} do send to {} {}: {}\n{}".format( request_cnt, from_user, message_id, - message, + content, reply_text, ) ) - 
replyPost = TextMsg(from_user, to_user, reply_text).send() - return replyPost + replyPost = create_reply(reply_text, msg) + return encrypt_func(replyPost.render()) - elif (reply_type == "voice"): - media_id = content + elif reply_type == "voice": + media_id = reply_content asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop) - replyPost = VoiceMsg(from_user, to_user, media_id).send() - return replyPost + logger.info( + "[wechatmp] Request {} do send to {} {}: {} voice media_id {}".format( + request_cnt, + from_user, + message_id, + content, + media_id, + ) + ) + replyPost = VoiceReply(message=msg) + replyPost.media_id = media_id + return encrypt_func(replyPost.render()) - elif (reply_type == "image"): - media_id = content + elif reply_type == "image": + media_id = reply_content asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop) - replyPost = ImageMsg(from_user, to_user, media_id).send() - return replyPost - - elif wechatmp_msg.msg_type == "event": - logger.info( - "[wechatmp] Event {} from {}".format( - wechatmp_msg.content, wechatmp_msg.from_user_id + logger.info( + "[wechatmp] Request {} do send to {} {}: {} image media_id {}".format( + request_cnt, + from_user, + message_id, + content, + media_id, + ) ) - ) - content = subscribe_msg() - replyMsg = TextMsg( - wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content - ) - return replyMsg.send() + replyPost = ImageReply(message=msg) + replyPost.media_id = media_id + return encrypt_func(replyPost.render()) + + elif msg.type == "event": + logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source)) + if msg.event in ["subscribe", "subscribe_scan"]: + reply_text = subscribe_msg() + replyPost = create_reply(reply_text, msg) + return encrypt_func(replyPost.render()) + else: + return "success" + else: logger.info("暂且不处理") return "success" diff --git a/channel/wechatmp/passive_reply_message.py 
b/channel/wechatmp/passive_reply_message.py deleted file mode 100644 index ef58d7093..000000000 --- a/channel/wechatmp/passive_reply_message.py +++ /dev/null @@ -1,78 +0,0 @@ -# -*- coding: utf-8 -*-# -# filename: reply.py -import time - - -class Msg(object): - def __init__(self): - pass - - def send(self): - return "success" - - -class TextMsg(Msg): - def __init__(self, toUserName, fromUserName, content): - self.__dict = dict() - self.__dict["ToUserName"] = toUserName - self.__dict["FromUserName"] = fromUserName - self.__dict["CreateTime"] = int(time.time()) - self.__dict["Content"] = content - - def send(self): - XmlForm = """ - - - - {CreateTime} - - - - """ - return XmlForm.format(**self.__dict) - - -class VoiceMsg(Msg): - def __init__(self, toUserName, fromUserName, mediaId): - self.__dict = dict() - self.__dict["ToUserName"] = toUserName - self.__dict["FromUserName"] = fromUserName - self.__dict["CreateTime"] = int(time.time()) - self.__dict["MediaId"] = mediaId - - def send(self): - XmlForm = """ - - - - {CreateTime} - - - - - - """ - return XmlForm.format(**self.__dict) - - -class ImageMsg(Msg): - def __init__(self, toUserName, fromUserName, mediaId): - self.__dict = dict() - self.__dict["ToUserName"] = toUserName - self.__dict["FromUserName"] = fromUserName - self.__dict["CreateTime"] = int(time.time()) - self.__dict["MediaId"] = mediaId - - def send(self): - XmlForm = """ - - - - {CreateTime} - - - - - - """ - return XmlForm.format(**self.__dict) diff --git a/channel/wechatmp/wechatmp_channel.py b/channel/wechatmp/wechatmp_channel.py index 9780048b9..aa1fc74d7 100644 --- a/channel/wechatmp/wechatmp_channel.py +++ b/channel/wechatmp/wechatmp_channel.py @@ -1,22 +1,26 @@ # -*- coding: utf-8 -*- +import asyncio +import imghdr import io import os +import threading import time -import imghdr + import requests +import web +from wechatpy.crypto import WeChatCrypto +from wechatpy.exceptions import WeChatClientException + from bridge.context import * from 
bridge.reply import * from channel.chat_channel import ChatChannel -from channel.wechatmp.wechatmp_client import WechatMPClient from channel.wechatmp.common import * +from channel.wechatmp.wechatmp_client import WechatMPClient from common.log import logger from common.singleton import singleton from config import conf +from voice.audio_convert import any_to_mp3 -import asyncio -from threading import Thread - -import web # If using SSL, uncomment the following lines, and modify the certificate path. # from cheroot.server import HTTPServer # from cheroot.ssl.builtin import BuiltinSSLAdapter @@ -31,7 +35,14 @@ def __init__(self, passive_reply=True): super().__init__() self.passive_reply = passive_reply self.NOT_SUPPORT_REPLYTYPE = [] - self.client = WechatMPClient() + appid = conf().get("wechatmp_app_id") + secret = conf().get("wechatmp_app_secret") + token = conf().get("wechatmp_token") + aes_key = conf().get("wechatmp_aes_key") + self.client = WechatMPClient(appid, secret) + self.crypto = None + if aes_key: + self.crypto = WeChatCrypto(token, aes_key, appid) if self.passive_reply: # Cache the reply to the user's first message self.cache_dict = dict() @@ -41,11 +52,10 @@ def __init__(self, passive_reply=True): self.request_cnt = dict() # The permanent media need to be deleted to avoid media number limit self.delete_media_loop = asyncio.new_event_loop() - t = Thread(target=self.start_loop, args=(self.delete_media_loop,)) + t = threading.Thread(target=self.start_loop, args=(self.delete_media_loop,)) t.setDaemon(True) t.start() - def startup(self): if self.passive_reply: urls = ("/wx", "channel.wechatmp.passive_reply.Query") @@ -62,7 +72,7 @@ def start_loop(self, loop): async def delete_media(self, media_id): logger.debug("[wechatmp] permanent media {} will be deleted in 10s".format(media_id)) await asyncio.sleep(10) - self.client.delete_permanent_media(media_id) + self.client.material.delete(media_id) logger.info("[wechatmp] permanent media {} has been 
deleted".format(media_id)) def send(self, reply: Reply, context: Context): @@ -70,97 +80,134 @@ def send(self, reply: Reply, context: Context): if self.passive_reply: if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR: reply_text = reply.content - logger.info("[wechatmp] reply to {} cached:\n{}".format(receiver, reply_text)) + logger.info("[wechatmp] text cached, receiver {}\n{}".format(receiver, reply_text)) self.cache_dict[receiver] = ("text", reply_text) elif reply.type == ReplyType.VOICE: - voice_file_path = reply.content - logger.info("[wechatmp] voice file path {}".format(voice_file_path)) - with open(voice_file_path, 'rb') as f: - filename = receiver + "-" + context["msg"].msg_id + ".mp3" - media_id = self.client.upload_permanent_media("voice", (filename, f, "audio/mpeg")) - # 根据文件大小估计一个微信自动审核的时间,审核结束前返回将会导致语音无法播放,这个估计有待验证 - f_size = os.fstat(f.fileno()).st_size - print(f_size) - time.sleep(1.0 + 2 * f_size / 1024 / 1024) - logger.info("[wechatmp] voice reply to {} uploaded: {}".format(receiver, media_id)) - self.cache_dict[receiver] = ("voice", media_id) + try: + voice_file_path = reply.content + with open(voice_file_path, "rb") as f: + # support: <2M, <60s, mp3/wma/wav/amr + response = self.client.material.add("voice", f) + logger.debug("[wechatmp] upload voice response: {}".format(response)) + # 根据文件大小估计一个微信自动审核的时间,审核结束前返回将会导致语音无法播放,这个估计有待验证 + f_size = os.fstat(f.fileno()).st_size + time.sleep(1.0 + 2 * f_size / 1024 / 1024) + # todo check media_id + except WeChatClientException as e: + logger.error("[wechatmp] upload voice failed: {}".format(e)) + return + media_id = response["media_id"] + logger.info("[wechatmp] voice uploaded, receiver {}, media_id {}".format(receiver, media_id)) + self.cache_dict[receiver] = ("voice", media_id) + elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 img_url = reply.content pic_res = requests.get(img_url, stream=True) - print(pic_res.headers) image_storage = io.BytesIO() for 
block in pic_res.iter_content(1024): image_storage.write(block) image_storage.seek(0) image_type = imghdr.what(image_storage) - filename = receiver + "-" + context["msg"].msg_id + "." + image_type + filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type content_type = "image/" + image_type - media_id = self.client.upload_permanent_media("image", (filename, image_storage, content_type)) - logger.info("[wechatmp] image reply to {} uploaded: {}".format(receiver, media_id)) + try: + response = self.client.material.add("image", (filename, image_storage, content_type)) + logger.debug("[wechatmp] upload image response: {}".format(response)) + except WeChatClientException as e: + logger.error("[wechatmp] upload image failed: {}".format(e)) + return + media_id = response["media_id"] + logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id)) self.cache_dict[receiver] = ("image", media_id) elif reply.type == ReplyType.IMAGE: # 从文件读取图片 image_storage = reply.content image_storage.seek(0) image_type = imghdr.what(image_storage) - filename = receiver + "-" + context["msg"].msg_id + "." + image_type + filename = receiver + "-" + str(context["msg"].msg_id) + "." 
+ image_type content_type = "image/" + image_type - media_id = self.client.upload_permanent_media("image", (filename, image_storage, content_type)) - logger.info("[wechatmp] image reply to {} uploaded: {}".format(receiver, media_id)) + try: + response = self.client.material.add("image", (filename, image_storage, content_type)) + logger.debug("[wechatmp] upload image response: {}".format(response)) + except WeChatClientException as e: + logger.error("[wechatmp] upload image failed: {}".format(e)) + return + media_id = response["media_id"] + logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id)) self.cache_dict[receiver] = ("image", media_id) else: if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR: reply_text = reply.content - self.client.send_text(receiver, reply_text) - logger.info("[wechatmp] Do send to {}: {}".format(receiver, reply_text)) + texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN) + if len(texts) > 1: + logger.info("[wechatmp] text too long, split into {} parts".format(len(texts))) + for text in texts: + self.client.message.send_text(receiver, text) + logger.info("[wechatmp] Do send text to {}: {}".format(receiver, reply_text)) elif reply.type == ReplyType.VOICE: - voice_file_path = reply.content - logger.info("[wechatmp] voice file path {}".format(voice_file_path)) - with open(voice_file_path, 'rb') as f: - filename = receiver + "-" + context["msg"].msg_id + ".mp3" - media_id = self.client.upload_media("voice", (filename, f, "audio/mpeg")) - self.client.send_voice(receiver, media_id) - logger.info("[wechatmp] Do send voice to {}".format(receiver)) + try: + file_path = reply.content + file_name = os.path.basename(file_path) + file_type = os.path.splitext(file_name)[1] + if file_type == ".mp3": + file_type = "audio/mpeg" + elif file_type == ".amr": + file_type = "audio/amr" + else: + mp3_file = os.path.splitext(file_path)[0] + ".mp3" + any_to_mp3(file_path, 
mp3_file) + file_path = mp3_file + file_name = os.path.basename(file_path) + file_type = "audio/mpeg" + logger.info("[wechatmp] file_name: {}, file_type: {} ".format(file_name, file_type)) + # support: <2M, <60s, AMR\MP3 + response = self.client.media.upload("voice", (file_name, open(file_path, "rb"), file_type)) + logger.debug("[wechatmp] upload voice response: {}".format(response)) + except WeChatClientException as e: + logger.error("[wechatmp] upload voice failed: {}".format(e)) + return + self.client.message.send_voice(receiver, response["media_id"]) + logger.info("[wechatmp] Do send voice to {}".format(receiver)) elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 img_url = reply.content pic_res = requests.get(img_url, stream=True) - print(pic_res.headers) image_storage = io.BytesIO() for block in pic_res.iter_content(1024): image_storage.write(block) image_storage.seek(0) image_type = imghdr.what(image_storage) - filename = receiver + "-" + context["msg"].msg_id + "." + image_type + filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type content_type = "image/" + image_type - # content_type = pic_res.headers.get('content-type') - media_id = self.client.upload_media("image", (filename, image_storage, content_type)) - self.client.send_image(receiver, media_id) - logger.info("[wechatmp] sendImage url={}, receiver={}".format(img_url, receiver)) + try: + response = self.client.media.upload("image", (filename, image_storage, content_type)) + logger.debug("[wechatmp] upload image response: {}".format(response)) + except WeChatClientException as e: + logger.error("[wechatmp] upload image failed: {}".format(e)) + return + self.client.message.send_image(receiver, response["media_id"]) + logger.info("[wechatmp] Do send image to {}".format(receiver)) elif reply.type == ReplyType.IMAGE: # 从文件读取图片 image_storage = reply.content image_storage.seek(0) image_type = imghdr.what(image_storage) - filename = receiver + "-" + context["msg"].msg_id + "." 
+ image_type + filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type content_type = "image/" + image_type - media_id = self.client.upload_media("image", (filename, image_storage, content_type)) - self.client.send_image(receiver, media_id) - logger.info("[wechatmp] sendImage, receiver={}".format(receiver)) + try: + response = self.client.media.upload("image", (filename, image_storage, content_type)) + logger.debug("[wechatmp] upload image response: {}".format(response)) + except WeChatClientException as e: + logger.error("[wechatmp] upload image failed: {}".format(e)) + return + self.client.message.send_image(receiver, response["media_id"]) + logger.info("[wechatmp] Do send image to {}".format(receiver)) return def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数 - logger.debug( - "[wechatmp] Success to generate reply, msgId={}".format( - context["msg"].msg_id - ) - ) + logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context["msg"].msg_id)) if self.passive_reply: self.running.remove(session_id) def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 - logger.exception( - "[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format( - context["msg"].msg_id, exception - ) - ) + logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context["msg"].msg_id, exception)) if self.passive_reply: assert session_id not in self.cache_dict self.running.remove(session_id) diff --git a/channel/wechatmp/wechatmp_client.py b/channel/wechatmp/wechatmp_client.py index 96ebddb74..19dca3219 100644 --- a/channel/wechatmp/wechatmp_client.py +++ b/channel/wechatmp/wechatmp_client.py @@ -1,180 +1,49 @@ -import time -import json -import requests import threading -from channel.wechatmp.common import * -from common.log import logger -from config import conf - - -class WechatMPClient: - def __init__(self): - self.app_id = conf().get("wechatmp_app_id") - 
self.app_secret = conf().get("wechatmp_app_secret") - self.access_token = None - self.access_token_expires_time = 0 - self.access_token_lock = threading.Lock() - self.get_access_token() - - - def wechatmp_request(self, method, url, **kwargs): - r = requests.request(method=method, url=url, **kwargs) - r.raise_for_status() - r.encoding = "utf-8" - ret = r.json() - if "errcode" in ret and ret["errcode"] != 0: - if ret["errcode"] == 45009: - self.clear_quota_v2() - raise WeChatAPIException("{}".format(ret)) - return ret - - def get_access_token(self): - # return the access_token - if self.access_token: - if self.access_token_expires_time - time.time() > 60: - return self.access_token - - # Get new access_token - # Do not request access_token in parallel! Only the last obtained is valid. - if self.access_token_lock.acquire(blocking=False): - # Wait for other threads that have previously obtained access_token to complete the request - # This happens every 2 hours, so it doesn't affect the experience very much - time.sleep(1) - self.access_token = None - url = "https://api.weixin.qq.com/cgi-bin/token" - params = { - "grant_type": "client_credential", - "appid": self.app_id, - "secret": self.app_secret, - } - ret = self.wechatmp_request(method="get", url=url, params=params) - self.access_token = ret["access_token"] - self.access_token_expires_time = int(time.time()) + ret["expires_in"] - logger.info("[wechatmp] access_token: {}".format(self.access_token)) - self.access_token_lock.release() - else: - # Wait for token update - while self.access_token_lock.locked(): - time.sleep(0.1) - return self.access_token - - - def send_text(self, receiver, reply_text): - url = "https://api.weixin.qq.com/cgi-bin/message/custom/send" - params = {"access_token": self.get_access_token()} - json_data = { - "touser": receiver, - "msgtype": "text", - "text": {"content": reply_text}, - } - self.wechatmp_request( - method="post", - url=url, - params=params, - data=json.dumps(json_data, 
ensure_ascii=False).encode("utf8"), - ) - - - def send_voice(self, receiver, media_id): - url="https://api.weixin.qq.com/cgi-bin/message/custom/send" - params = {"access_token": self.get_access_token()} - json_data = { - "touser": receiver, - "msgtype": "voice", - "voice": { - "media_id": media_id - } - } - self.wechatmp_request( - method="post", - url=url, - params=params, - data=json.dumps(json_data, ensure_ascii=False).encode("utf8"), - ) - - def send_image(self, receiver, media_id): - url="https://api.weixin.qq.com/cgi-bin/message/custom/send" - params = {"access_token": self.get_access_token()} - json_data = { - "touser": receiver, - "msgtype": "image", - "image": { - "media_id": media_id - } - } - self.wechatmp_request( - method="post", - url=url, - params=params, - data=json.dumps(json_data, ensure_ascii=False).encode("utf8"), - ) - - - def upload_media(self, media_type, media_file): - url="https://api.weixin.qq.com/cgi-bin/media/upload" - params={ - "access_token": self.get_access_token(), - "type": media_type - } - files={"media": media_file} - ret = self.wechatmp_request( - method="post", - url=url, - params=params, - files=files - ) - logger.debug("[wechatmp] media {} uploaded".format(media_file)) - return ret["media_id"] +import time +from wechatpy.client import WeChatClient +from wechatpy.exceptions import APILimitedException - def upload_permanent_media(self, media_type, media_file): - url="https://api.weixin.qq.com/cgi-bin/material/add_material" - params={ - "access_token": self.get_access_token(), - "type": media_type - } - files={"media": media_file} - ret = self.wechatmp_request( - method="post", - url=url, - params=params, - files=files - ) - logger.debug("[wechatmp] permanent media {} uploaded".format(media_file)) - return ret["media_id"] +from channel.wechatmp.common import * +from common.log import logger - def delete_permanent_media(self, media_id): - url="https://api.weixin.qq.com/cgi-bin/material/del_material" - params={ - "access_token": 
self.get_access_token() - } - self.wechatmp_request( - method="post", - url=url, - params=params, - data=json.dumps({"media_id": media_id}, ensure_ascii=False).encode("utf8") - ) - logger.debug("[wechatmp] permanent media {} deleted".format(media_id)) +class WechatMPClient(WeChatClient): + def __init__(self, appid, secret, access_token=None, session=None, timeout=None, auto_retry=True): + super(WechatMPClient, self).__init__(appid, secret, access_token, session, timeout, auto_retry) + self.fetch_access_token_lock = threading.Lock() + self.clear_quota_lock = threading.Lock() + self.last_clear_quota_time = -1 def clear_quota(self): - url="https://api.weixin.qq.com/cgi-bin/clear_quota" - params = { - "access_token": self.get_access_token() - } - self.wechatmp_request( - method="post", - url=url, - params=params, - data={"appid": self.app_id} - ) - logger.debug("[wechatmp] API quata has been cleard") + return self.post("clear_quota", data={"appid": self.appid}) def clear_quota_v2(self): - url="https://api.weixin.qq.com/cgi-bin/clear_quota/v2" - self.wechatmp_request( - method="post", - url=url, - data={"appid": self.app_id, "appsecret": self.app_secret} - ) - logger.debug("[wechatmp] API quata has been cleard") + return self.post("clear_quota/v2", params={"appid": self.appid, "appsecret": self.secret}) + + def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token + with self.fetch_access_token_lock: + access_token = self.session.get(self.access_token_key) + if access_token: + if not self.expires_at: + return access_token + timestamp = time.time() + if self.expires_at - timestamp > 60: + return access_token + return super().fetch_access_token() + + def _request(self, method, url_or_endpoint, **kwargs): # 重载父类方法,遇到API限流时,清除quota后重试 + try: + return super()._request(method, url_or_endpoint, **kwargs) + except APILimitedException as e: + logger.error("[wechatmp] API quata has been used up. 
{}".format(e)) + if self.last_clear_quota_time == -1 or time.time() - self.last_clear_quota_time > 60: + with self.clear_quota_lock: + if self.last_clear_quota_time == -1 or time.time() - self.last_clear_quota_time > 60: + self.last_clear_quota_time = time.time() + response = self.clear_quota_v2() + logger.debug("[wechatmp] API quata has been cleard, {}".format(response)) + return super()._request(method, url_or_endpoint, **kwargs) + else: + logger.error("[wechatmp] last clear quota time is {}, less than 60s, skip clear quota") + raise e diff --git a/channel/wechatmp/wechatmp_message.py b/channel/wechatmp/wechatmp_message.py index d385897c3..27c9fbb85 100644 --- a/channel/wechatmp/wechatmp_message.py +++ b/channel/wechatmp/wechatmp_message.py @@ -1,50 +1,56 @@ # -*- coding: utf-8 -*-# -# filename: receive.py -import xml.etree.ElementTree as ET from bridge.context import ContextType from channel.chat_message import ChatMessage from common.log import logger - - -def parse_xml(web_data): - if len(web_data) == 0: - return None - xmlData = ET.fromstring(web_data) - return WeChatMPMessage(xmlData) +from common.tmp_dir import TmpDir class WeChatMPMessage(ChatMessage): - def __init__(self, xmlData): - super().__init__(xmlData) - self.to_user_id = xmlData.find("ToUserName").text - self.from_user_id = xmlData.find("FromUserName").text - self.create_time = xmlData.find("CreateTime").text - self.msg_type = xmlData.find("MsgType").text - try: - self.msg_id = xmlData.find("MsgId").text - except: - self.msg_id = self.from_user_id + self.create_time + def __init__(self, msg, client=None): + super().__init__(msg) + self.msg_id = msg.id + self.create_time = msg.time self.is_group = False - # reply to other_user_id - self.other_user_id = self.from_user_id - - if self.msg_type == "text": + if msg.type == "text": self.ctype = ContextType.TEXT - self.content = xmlData.find("Content").text - elif self.msg_type == "voice": - self.ctype = ContextType.TEXT - self.content = 
xmlData.find("Recognition").text # 接收语音识别结果 - # other voice_to_text method not implemented yet - if self.content == None: - self.content = "你好" - elif self.msg_type == "image": - # not implemented yet - self.pic_url = xmlData.find("PicUrl").text - self.media_id = xmlData.find("MediaId").text - elif self.msg_type == "event": - self.content = xmlData.find("Event").text - else: # video, shortvideo, location, link - # not implemented - pass + self.content = msg.content + elif msg.type == "voice": + if msg.recognition == None: + self.ctype = ContextType.VOICE + self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径 + + def download_voice(): + # 如果响应状态码是200,则将响应内容写入本地文件 + response = client.media.download(msg.media_id) + if response.status_code == 200: + with open(self.content, "wb") as f: + f.write(response.content) + else: + logger.info(f"[wechatmp] Failed to download voice file, {response.content}") + + self._prepare_fn = download_voice + else: + self.ctype = ContextType.TEXT + self.content = msg.recognition + elif msg.type == "image": + self.ctype = ContextType.IMAGE + self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径 + + def download_image(): + # 如果响应状态码是200,则将响应内容写入本地文件 + response = client.media.download(msg.media_id) + if response.status_code == 200: + with open(self.content, "wb") as f: + f.write(response.content) + else: + logger.info(f"[wechatmp] Failed to download image file, {response.content}") + + self._prepare_fn = download_image + else: + raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type)) + + self.from_user_id = msg.source + self.to_user_id = msg.target + self.other_user_id = msg.source diff --git a/common/time_check.py b/common/time_check.py index 808f71ab3..5c2dacba6 100644 --- a/common/time_check.py +++ b/common/time_check.py @@ -13,23 +13,15 @@ def _time_checker(self, *args, **kwargs): if chat_time_module: chat_start_time = _config.get("chat_start_time", "00:00") 
chat_stopt_time = _config.get("chat_stop_time", "24:00") - time_regex = re.compile( - r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$" - ) # 时间匹配,包含24:00 + time_regex = re.compile(r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$") # 时间匹配,包含24:00 starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式 stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式 chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间 # 时间格式检查 - if not ( - starttime_format_check and stoptime_format_check and chat_time_check - ): - logger.warn( - "时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format( - starttime_format_check, stoptime_format_check - ) - ) + if not (starttime_format_check and stoptime_format_check and chat_time_check): + logger.warn("时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(starttime_format_check, stoptime_format_check)) if chat_start_time > "23:59": logger.error("启动时间可能存在问题,请修改!") diff --git a/config.py b/config.py index 2c6739b64..0ba5aa64c 100644 --- a/config.py +++ b/config.py @@ -68,6 +68,11 @@ "chat_time_module": False, # 是否开启服务时间限制 "chat_start_time": "00:00", # 服务开始时间 "chat_stop_time": "24:00", # 服务结束时间 + # 翻译api + "translate": "baidu", # 翻译api,支持baidu + # baidu翻译api的配置 + "baidu_translate_app_id": "", # 百度翻译api的appid + "baidu_translate_app_key": "", # 百度翻译api的秘钥 # itchat的配置 "hot_reload": False, # 是否开启热重载 # wechaty的配置 @@ -75,8 +80,9 @@ # wechatmp的配置 "wechatmp_token": "", # 微信公众平台的Token "wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443 - "wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要 - "wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要 + "wechatmp_app_id": "", # 微信公众平台的appID + "wechatmp_app_secret": "", # 微信公众平台的appsecret + "wechatmp_aes_key": "", # 微信公众平台的EncodingAESKey,加密模式需要 # chatgpt指令自定义触发词 "clear_memory_commands": ["#清除记忆"], # 重置会话指令,必须以#开头 # channel配置 @@ -159,9 +165,7 @@ def load_config(): for name, value in os.environ.items(): name = name.lower() if name in 
available_setting: - logger.info( - "[INIT] override config by environ args: {}={}".format(name, value) - ) + logger.info("[INIT] override config by environ args: {}={}".format(name, value)) try: config[name] = eval(value) except: diff --git a/docker/Dockerfile.alpine b/docker/Dockerfile.alpine index 324a76ebf..61f80c28f 100644 --- a/docker/Dockerfile.alpine +++ b/docker/Dockerfile.alpine @@ -22,8 +22,8 @@ RUN apk add --no-cache \ && cd ${BUILD_PREFIX} \ && cp config-template.json ${BUILD_PREFIX}/config.json \ && /usr/local/bin/python -m pip install --no-cache --upgrade pip \ - && pip install --no-cache -r requirements.txt \ - && pip install --no-cache -r requirements-optional.txt \ + && pip install --no-cache -r requirements.txt --extra-index-url https://alpine-wheels.github.io/index\ + && pip install --no-cache -r requirements-optional.txt --extra-index-url https://alpine-wheels.github.io/index\ && apk del curl wget WORKDIR ${BUILD_PREFIX} diff --git a/docker/Dockerfile.latest b/docker/Dockerfile.latest index c9a5a55e6..432075cd1 100644 --- a/docker/Dockerfile.latest +++ b/docker/Dockerfile.latest @@ -13,8 +13,8 @@ RUN apk add --no-cache bash ffmpeg espeak \ && cd ${BUILD_PREFIX} \ && cp config-template.json config.json \ && /usr/local/bin/python -m pip install --no-cache --upgrade pip \ - && pip install --no-cache -r requirements.txt \ - && pip install --no-cache -r requirements-optional.txt + && pip install --no-cache -r requirements.txt --extra-index-url https://alpine-wheels.github.io/index\ + && pip install --no-cache -r requirements-optional.txt --extra-index-url https://alpine-wheels.github.io/index WORKDIR ${BUILD_PREFIX} diff --git a/plugins/banwords/banwords.py b/plugins/banwords/banwords.py index 4f7f75cde..118b9631c 100644 --- a/plugins/banwords/banwords.py +++ b/plugins/banwords/banwords.py @@ -50,9 +50,7 @@ def __init__(self): self.reply_action = conf.get("reply_action", "ignore") logger.info("[Banwords] inited") except Exception as e: - 
logger.warn( - "[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords ." - ) + logger.warn("[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords .") raise e def on_handle_context(self, e_context: EventContext): @@ -72,9 +70,7 @@ def on_handle_context(self, e_context: EventContext): return elif self.action == "replace": if self.searchr.ContainsAny(content): - reply = Reply( - ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content) - ) + reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content)) e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS return @@ -94,9 +90,7 @@ def on_decorate_reply(self, e_context: EventContext): return elif self.reply_action == "replace": if self.searchr.ContainsAny(content): - reply = Reply( - ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content) - ) + reply = Reply(ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content)) e_context["reply"] = reply e_context.action = EventAction.CONTINUE return diff --git a/plugins/bdunit/bdunit.py b/plugins/bdunit/bdunit.py index 53aaec270..c99f2b045 100644 --- a/plugins/bdunit/bdunit.py +++ b/plugins/bdunit/bdunit.py @@ -82,9 +82,7 @@ def get_token(self): Returns: string: access_token """ - url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format( - self.api_key, self.secret_key - ) + url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(self.api_key, self.secret_key) payload = "" headers = {"Content-Type": "application/json", "Accept": "application/json"} @@ -100,10 +98,7 @@ def getUnit(self, query): :returns: UNIT 解析结果。如果解析失败,返回 None """ - url = ( - "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" - + self.access_token - ) + url = 
"https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + self.access_token request = { "query": query, "user_id": str(get_mac())[:32], @@ -130,10 +125,7 @@ def getUnit2(self, query): :param query: 用户的指令字符串 :returns: UNIT 解析结果。如果解析失败,返回 None """ - url = ( - "https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token=" - + self.access_token - ) + url = "https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token=" + self.access_token request = {"query": query, "user_id": str(get_mac())[:32]} body = { "log_id": str(uuid.uuid1()), @@ -176,11 +168,7 @@ def hasIntent(self, parsed, intent): if parsed and "result" in parsed and "response_list" in parsed["result"]: response_list = parsed["result"]["response_list"] for response in response_list: - if ( - "schema" in response - and "intent" in response["schema"] - and response["schema"]["intent"] == intent - ): + if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent: return True return False else: @@ -204,12 +192,7 @@ def getSlots(self, parsed, intent=""): logger.warning(e) return [] for response in response_list: - if ( - "schema" in response - and "intent" in response["schema"] - and "slots" in response["schema"] - and response["schema"]["intent"] == intent - ): + if "schema" in response and "intent" in response["schema"] and "slots" in response["schema"] and response["schema"]["intent"] == intent: return response["schema"]["slots"] return [] else: @@ -245,11 +228,7 @@ def getSayByConfidence(self, parsed): if ( "schema" in response and "intent_confidence" in response["schema"] - and ( - not answer - or response["schema"]["intent_confidence"] - > answer["schema"]["intent_confidence"] - ) + and (not answer or response["schema"]["intent_confidence"] > answer["schema"]["intent_confidence"]) ): answer = response return answer["action_list"][0]["say"] @@ -273,11 +252,7 @@ def getSay(self, parsed, intent=""): logger.warning(e) return "" for response in 
response_list: - if ( - "schema" in response - and "intent" in response["schema"] - and response["schema"]["intent"] == intent - ): + if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent: try: return response["action_list"][0]["say"] except Exception as e: diff --git a/plugins/dungeon/dungeon.py b/plugins/dungeon/dungeon.py index 2e3cdf1ad..5b129d606 100644 --- a/plugins/dungeon/dungeon.py +++ b/plugins/dungeon/dungeon.py @@ -84,9 +84,7 @@ def on_handle_context(self, e_context: EventContext): if len(clist) > 1: story = clist[1] else: - story = ( - "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。" - ) + story = "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。" self.games[sessionid] = StoryTeller(bot, sessionid, story) reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story) e_context["reply"] = reply @@ -102,11 +100,7 @@ def get_help_text(self, **kwargs): if kwargs.get("verbose") != True: return help_text trigger_prefix = conf().get("plugin_trigger_prefix", "$") - help_text = ( - f"{trigger_prefix}开始冒险 " - + "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n" - + f"{trigger_prefix}停止冒险: 结束游戏。\n" - ) + help_text = f"{trigger_prefix}开始冒险 " + "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n" + f"{trigger_prefix}停止冒险: 结束游戏。\n" if kwargs.get("verbose") == True: help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'" return help_text diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 9f0b3a583..97595e4d4 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -140,9 +140,7 @@ def get_help_text(isadmin, isgroup): if plugins[plugin].enabled and not plugins[plugin].hidden: namecn = plugins[plugin].namecn help_text += "\n%s:" % namecn - help_text += ( - PluginManager().instances[plugin].get_help_text(verbose=False).strip() - ) + help_text += 
PluginManager().instances[plugin].get_help_text(verbose=False).strip() if ADMIN_COMMANDS and isadmin: help_text += "\n\n管理员指令:\n" @@ -191,9 +189,7 @@ def __init__(self): COMMANDS["reset"]["alias"].append(custom_command) self.password = gconf["password"] - self.admin_users = gconf[ - "admin_users" - ] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用 + self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用 self.isrunning = True # 机器人是否运行中 self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context @@ -215,7 +211,7 @@ def on_handle_context(self, e_context: EventContext): reply.content = f"空指令,输入#help查看指令列表\n" e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS - return + return # msg = e_context['context']['msg'] channel = e_context["channel"] user = e_context["context"]["receiver"] @@ -248,11 +244,7 @@ def on_handle_context(self, e_context: EventContext): if not plugincls.enabled: continue if query_name == name or query_name == plugincls.namecn: - ok, result = True, PluginManager().instances[ - name - ].get_help_text( - isgroup=isgroup, isadmin=isadmin, verbose=True - ) + ok, result = True, PluginManager().instances[name].get_help_text(isgroup=isgroup, isadmin=isadmin, verbose=True) break if not ok: result = "插件不存在或未启用" @@ -285,11 +277,7 @@ def on_handle_context(self, e_context: EventContext): if isgroup: ok, result = False, "群聊不可执行管理员指令" else: - cmd = next( - c - for c, info in ADMIN_COMMANDS.items() - if cmd in info["alias"] - ) + cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info["alias"]) if cmd == "stop": self.isrunning = False ok, result = True, "服务已暂停" @@ -325,18 +313,14 @@ def on_handle_context(self, e_context: EventContext): PluginManager().activate_plugins() if len(new_plugins) > 0: result += "\n发现新插件:\n" - result += "\n".join( - [f"{p.name}_v{p.version}" for p in new_plugins] - ) + result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins]) else: result += ", 未发现新插件" elif cmd == "setpri": if 
len(args) != 2: ok, result = False, "请提供插件名和优先级" else: - ok = PluginManager().set_plugin_priority( - args[0], int(args[1]) - ) + ok = PluginManager().set_plugin_priority(args[0], int(args[1])) if ok: result = "插件" + args[0] + "优先级已设置为" + args[1] else: diff --git a/plugins/hello/hello.py b/plugins/hello/hello.py index 254b17254..fc8fe7051 100644 --- a/plugins/hello/hello.py +++ b/plugins/hello/hello.py @@ -33,9 +33,7 @@ def on_handle_context(self, e_context: EventContext): if e_context["context"].type == ContextType.JOIN_GROUP: e_context["context"].type = ContextType.TEXT msg: ChatMessage = e_context["context"]["msg"] - e_context[ - "context" - ].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。' + e_context["context"].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。' e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 return @@ -53,9 +51,7 @@ def on_handle_context(self, e_context: EventContext): reply.type = ReplyType.TEXT msg: ChatMessage = e_context["context"]["msg"] if e_context["context"]["isgroup"]: - reply.content = ( - f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}" - ) + reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}" else: reply.content = f"Hello, {msg.from_user_nickname}" e_context["reply"] = reply diff --git a/plugins/keyword/config.json.template b/plugins/keyword/config.json.template index 9a8332f3e..dbd5efe34 100644 --- a/plugins/keyword/config.json.template +++ b/plugins/keyword/config.json.template @@ -1,5 +1,5 @@ { "keyword": { - "关键字匹配": "测试成功" + "关键字匹配": "测试成功" } -} \ No newline at end of file +} diff --git a/plugins/keyword/keyword.py b/plugins/keyword/keyword.py index 376f748e9..97ebe26ac 100644 --- a/plugins/keyword/keyword.py +++ b/plugins/keyword/keyword.py @@ -41,9 +41,7 @@ def __init__(self): self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context logger.info("[keyword] inited.") except Exception as e: - logger.warn( - "[keyword] init 
failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword ." - ) + logger.warn("[keyword] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword .") raise e def on_handle_context(self, e_context: EventContext): diff --git a/plugins/plugin_manager.py b/plugins/plugin_manager.py index b014e5f06..d2ee75e39 100644 --- a/plugins/plugin_manager.py +++ b/plugins/plugin_manager.py @@ -31,23 +31,14 @@ def wrapper(plugincls): plugincls.desc = kwargs.get("desc") plugincls.author = kwargs.get("author") plugincls.path = self.current_plugin_path - plugincls.version = ( - kwargs.get("version") if kwargs.get("version") != None else "1.0" - ) - plugincls.namecn = ( - kwargs.get("namecn") if kwargs.get("namecn") != None else name - ) - plugincls.hidden = ( - kwargs.get("hidden") if kwargs.get("hidden") != None else False - ) + plugincls.version = kwargs.get("version") if kwargs.get("version") != None else "1.0" + plugincls.namecn = kwargs.get("namecn") if kwargs.get("namecn") != None else name + plugincls.hidden = kwargs.get("hidden") if kwargs.get("hidden") != None else False plugincls.enabled = True if self.current_plugin_path == None: raise Exception("Plugin path not set") self.plugins[name.upper()] = plugincls - logger.info( - "Plugin %s_v%s registered, path=%s" - % (name, plugincls.version, plugincls.path) - ) + logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path)) return wrapper @@ -62,9 +53,7 @@ def load_config(self): if os.path.exists("./plugins/plugins.json"): with open("./plugins/plugins.json", "r", encoding="utf-8") as f: pconf = json.load(f) - pconf["plugins"] = SortedDict( - lambda k, v: v["priority"], pconf["plugins"], reverse=True - ) + pconf["plugins"] = SortedDict(lambda k, v: v["priority"], pconf["plugins"], reverse=True) else: modified = True pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)} @@ -90,26 +79,16 @@ def 
scan_plugins(self): if plugin_path in self.loaded: if self.loaded[plugin_path] == None: logger.info("reload module %s" % plugin_name) - self.loaded[plugin_path] = importlib.reload( - sys.modules[import_path] - ) - dependent_module_names = [ - name - for name in sys.modules.keys() - if name.startswith(import_path + ".") - ] + self.loaded[plugin_path] = importlib.reload(sys.modules[import_path]) + dependent_module_names = [name for name in sys.modules.keys() if name.startswith(import_path + ".")] for name in dependent_module_names: logger.info("reload module %s" % name) importlib.reload(sys.modules[name]) else: - self.loaded[plugin_path] = importlib.import_module( - import_path - ) + self.loaded[plugin_path] = importlib.import_module(import_path) self.current_plugin_path = None except Exception as e: - logger.exception( - "Failed to import plugin %s: %s" % (plugin_name, e) - ) + logger.exception("Failed to import plugin %s: %s" % (plugin_name, e)) continue pconf = self.pconf news = [self.plugins[name] for name in self.plugins] @@ -119,9 +98,7 @@ def scan_plugins(self): rawname = plugincls.name if rawname not in pconf["plugins"]: modified = True - logger.info( - "Plugin %s not found in pconfig, adding to pconfig..." % name - ) + logger.info("Plugin %s not found in pconfig, adding to pconfig..." 
% name) pconf["plugins"][rawname] = { "enabled": plugincls.enabled, "priority": plugincls.priority, @@ -136,9 +113,7 @@ def scan_plugins(self): def refresh_order(self): for event in self.listening_plugins.keys(): - self.listening_plugins[event].sort( - key=lambda name: self.plugins[name].priority, reverse=True - ) + self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True) def activate_plugins(self): # 生成新开启的插件实例 failed_plugins = [] @@ -184,13 +159,8 @@ def load_plugins(self): def emit_event(self, e_context: EventContext, *args, **kwargs): if e_context.event in self.listening_plugins: for name in self.listening_plugins[e_context.event]: - if ( - self.plugins[name].enabled - and e_context.action == EventAction.CONTINUE - ): - logger.debug( - "Plugin %s triggered by event %s" % (name, e_context.event) - ) + if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE: + logger.debug("Plugin %s triggered by event %s" % (name, e_context.event)) instance = self.instances[name] instance.handlers[e_context.event](e_context, *args, **kwargs) return e_context @@ -262,9 +232,7 @@ def install_plugin(self, repo: str): source = json.load(f) if repo in source["repo"]: repo = source["repo"][repo]["url"] - match = re.match( - r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo - ) + match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo) if not match: return False, "安装插件失败,source中的仓库地址不合法" else: diff --git a/plugins/role/role.py b/plugins/role/role.py index 9788cc1cd..69c523347 100644 --- a/plugins/role/role.py +++ b/plugins/role/role.py @@ -69,13 +69,9 @@ def __init__(self): logger.info("[Role] inited") except Exception as e: if isinstance(e, FileNotFoundError): - logger.warn( - f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ." 
- ) + logger.warn(f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .") else: - logger.warn( - "[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ." - ) + logger.warn("[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .") raise e def get_role(self, name, find_closest=True, min_sim=0.35): @@ -143,9 +139,7 @@ def on_handle_context(self, e_context: EventContext): else: help_text = f"未知角色类型。\n" help_text += "目前的角色类型有: \n" - help_text += ( - ",".join([self.tags[tag][0] for tag in self.tags]) + "\n" - ) + help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "\n" else: help_text = f"请输入角色类型。\n" help_text += "目前的角色类型有: \n" @@ -158,9 +152,7 @@ def on_handle_context(self, e_context: EventContext): return logger.debug("[Role] on_handle_context. content: %s" % content) if desckey is not None: - if len(clist) == 1 or ( - len(clist) > 1 and clist[1].lower() in ["help", "帮助"] - ): + if len(clist) == 1 or (len(clist) > 1 and clist[1].lower() in ["help", "帮助"]): reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True)) e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS @@ -178,9 +170,7 @@ def on_handle_context(self, e_context: EventContext): self.roles[role][desckey], self.roles[role].get("wrapper", "%s"), ) - reply = Reply( - ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey] - ) + reply = Reply(ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey]) e_context["reply"] = reply e_context.action = EventAction.BREAK_PASS elif customize == True: @@ -199,17 +189,10 @@ def get_help_text(self, verbose=False, **kwargs): if not verbose: return help_text trigger_prefix = conf().get("plugin_trigger_prefix", "$") - help_text = ( - f"使用方法:\n{trigger_prefix}角色" - + " 预设角色名: 设定角色为{预设角色名}。\n" - + f"{trigger_prefix}role" - + " 预设角色名: 同上,但使用英文设定。\n" - ) 
+ help_text = f"使用方法:\n{trigger_prefix}角色" + " 预设角色名: 设定角色为{预设角色名}。\n" + f"{trigger_prefix}role" + " 预设角色名: 同上,但使用英文设定。\n" help_text += f"{trigger_prefix}设定扮演" + " 角色设定: 设定自定义角色人设为{角色设定}。\n" help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n" - help_text += ( - f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n" - ) + help_text += f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n" help_text += "\n目前的角色类型有: \n" help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "。\n" help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n" diff --git a/plugins/tool/README.md b/plugins/tool/README.md index 3ce92da47..9f91b288a 100644 --- a/plugins/tool/README.md +++ b/plugins/tool/README.md @@ -60,7 +60,7 @@ > 该tool每天返回内容相同 -#### 6.3. finance-news +#### 6.3. finance-news ###### 获取实时的金融财政新闻 > 该工具需要解决browser tool 的google-chrome依赖安装 diff --git a/plugins/tool/tool.py b/plugins/tool/tool.py index 2fb8a12bd..a3a3e76c6 100644 --- a/plugins/tool/tool.py +++ b/plugins/tool/tool.py @@ -82,9 +82,7 @@ def on_handle_context(self, e_context: EventContext): return elif content_list[1].startswith("reset"): logger.debug("[tool]: remind") - e_context[ - "context" - ].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符" + e_context["context"].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符" e_context.action = EventAction.BREAK return @@ -93,18 +91,14 @@ def on_handle_context(self, e_context: EventContext): # Don't modify bot name all_sessions = Bridge().get_bot("chat").sessions - user_session = all_sessions.session_query( - query, e_context["context"]["session_id"] - ).messages + user_session = all_sessions.session_query(query, e_context["context"]["session_id"]).messages # chatgpt-tool-hub will reply you with many tools logger.debug("[tool]: just-go") try: _reply = self.app.ask(query, user_session) e_context.action = EventAction.BREAK_PASS - all_sessions.session_reply( - _reply, e_context["context"]["session_id"] - ) + 
all_sessions.session_reply(_reply, e_context["context"]["session_id"]) except Exception as e: logger.exception(e) logger.error(str(e)) @@ -213,4 +207,4 @@ def _reset_app(self) -> App: # filter not support tool tool_list = self._filter_tool_list(tool_config.get("tools", [])) - return app.create_app(tools_list=tool_list, **app_kwargs) \ No newline at end of file + return app.create_app(tools_list=tool_list, **app_kwargs) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..abdab5733 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,8 @@ +[tool.black] +line-length = 176 +target-version = ['py37'] +include = '\.pyi?$' +extend-exclude = '.+/(dist|.venv|venv|build|lib)/.+' + +[tool.isort] +profile = "black" \ No newline at end of file diff --git a/requirements-optional.txt b/requirements-optional.txt index cfb52c953..0bcd1965d 100644 --- a/requirements-optional.txt +++ b/requirements-optional.txt @@ -7,6 +7,8 @@ gTTS>=2.3.1 # google text to speech pyttsx3>=2.90 # pytsx text to speech baidu_aip>=4.16.10 # baidu voice # azure-cognitiveservices-speech # azure voice +numpy<=1.24.2 +langid # language detect #install plugin dulwich @@ -18,6 +20,7 @@ pysilk_mod>=1.6.0 # needed by send voice # wechatmp web.py +wechatpy # chatgpt-tool-hub plugin diff --git a/requirements.txt b/requirements.txt index 2980e2ded..e7cf147fd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,3 +17,6 @@ lxml==4.9.2 pre-commit webdriver_manager selenium>=3.13.0 + +Pillow +pre-commit diff --git a/translate/baidu/baidu_translate.py b/translate/baidu/baidu_translate.py new file mode 100644 index 000000000..bf0a72143 --- /dev/null +++ b/translate/baidu/baidu_translate.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- + +import random +from hashlib import md5 + +import requests + +from config import conf +from translate.translator import Translator + + +class BaiduTranslator(Translator): + def __init__(self) -> None: + super().__init__() + endpoint = "http://api.fanyi.baidu.com" 
+ path = "/api/trans/vip/translate" + self.url = endpoint + path + self.appid = conf().get("baidu_translate_app_id") + self.appkey = conf().get("baidu_translate_app_key") + + # For list of language codes, please refer to `https://api.fanyi.baidu.com/doc/21`, need to convert to ISO 639-1 codes + def translate(self, query: str, from_lang: str = "", to_lang: str = "en") -> str: + if not from_lang: + from_lang = "auto" # baidu suppport auto detect + salt = random.randint(32768, 65536) + sign = self.make_md5(self.appid + query + str(salt) + self.appkey) + headers = {"Content-Type": "application/x-www-form-urlencoded"} + payload = {"appid": self.appid, "q": query, "from": from_lang, "to": to_lang, "salt": salt, "sign": sign} + + retry_cnt = 3 + while retry_cnt: + r = requests.post(self.url, params=payload, headers=headers) + result = r.json() + if errcode := result.get("error_code", "52000") != "52000": + if errcode == "52001" or errcode == "52002": + retry_cnt -= 1 + continue + else: + raise Exception(result["error_msg"]) + else: + break + text = "\n".join([item["dst"] for item in result["trans_result"]]) + return text + + def make_md5(self, s, encoding="utf-8"): + return md5(s.encode(encoding)).hexdigest() diff --git a/translate/factory.py b/translate/factory.py new file mode 100644 index 000000000..ba80aa59d --- /dev/null +++ b/translate/factory.py @@ -0,0 +1,6 @@ +def create_translator(voice_type): + if voice_type == "baidu": + from translate.baidu.baidu_translate import BaiduTranslator + + return BaiduTranslator() + raise RuntimeError diff --git a/translate/translator.py b/translate/translator.py new file mode 100644 index 000000000..b394f4e4d --- /dev/null +++ b/translate/translator.py @@ -0,0 +1,12 @@ +""" +Voice service abstract class +""" + + +class Translator(object): + # please use https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes to specify language + def translate(self, query: str, from_lang: str = "", to_lang: str = "en") -> str: + """ + Translate text 
from one language to another + """ + raise NotImplementedError diff --git a/voice/audio_convert.py b/voice/audio_convert.py index ce0601da1..610170038 100644 --- a/voice/audio_convert.py +++ b/voice/audio_convert.py @@ -34,6 +34,20 @@ def get_pcm_from_wav(wav_path): return wav.readframes(wav.getnframes()) +def any_to_mp3(any_path, mp3_path): + """ + 把任意格式转成mp3文件 + """ + if any_path.endswith(".mp3"): + shutil.copy2(any_path, mp3_path) + return + if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"): + sil_to_wav(any_path, any_path) + any_path = mp3_path + audio = AudioSegment.from_file(any_path) + audio.export(mp3_path, format="mp3") + + def any_to_wav(any_path, wav_path): """ 把任意格式转成wav文件 @@ -41,11 +55,7 @@ def any_to_wav(any_path, wav_path): if any_path.endswith(".wav"): shutil.copy2(any_path, wav_path) return - if ( - any_path.endswith(".sil") - or any_path.endswith(".silk") - or any_path.endswith(".slk") - ): + if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"): return sil_to_wav(any_path, wav_path) audio = AudioSegment.from_file(any_path) audio.export(wav_path, format="wav") @@ -55,11 +65,7 @@ def any_to_sil(any_path, sil_path): """ 把任意格式转成sil文件 """ - if ( - any_path.endswith(".sil") - or any_path.endswith(".silk") - or any_path.endswith(".slk") - ): + if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"): shutil.copy2(any_path, sil_path) return 10000 audio = AudioSegment.from_file(any_path) diff --git a/voice/azure/azure_voice.py b/voice/azure/azure_voice.py index 3ee95043e..1a0a8ed3f 100644 --- a/voice/azure/azure_voice.py +++ b/voice/azure/azure_voice.py @@ -6,6 +6,7 @@ import time import azure.cognitiveservices.speech as speechsdk +from langid import classify from bridge.reply import Reply, ReplyType from common.log import logger @@ -30,7 +31,15 @@ def __init__(self): config = None if not os.path.exists(config_path): # 如果没有配置文件,创建本地配置文件 config = { - 
"speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", + "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", # 识别不出时的默认语音 + "auto_detect": True, # 是否自动检测语言 + "speech_synthesis_zh": "zh-CN-XiaozhenNeural", + "speech_synthesis_en": "en-US-JacobNeural", + "speech_synthesis_ja": "ja-JP-AoiNeural", + "speech_synthesis_ko": "ko-KR-SoonBokNeural", + "speech_synthesis_de": "de-DE-LouisaNeural", + "speech_synthesis_fr": "fr-FR-BrigitteNeural", + "speech_synthesis_es": "es-ES-LaiaNeural", "speech_recognition_language": "zh-CN", } with open(config_path, "w") as fw: @@ -38,59 +47,47 @@ def __init__(self): else: with open(config_path, "r") as fr: config = json.load(fr) + self.config = config self.api_key = conf().get("azure_voice_api_key") self.api_region = conf().get("azure_voice_region") - self.speech_config = speechsdk.SpeechConfig( - subscription=self.api_key, region=self.api_region - ) - self.speech_config.speech_synthesis_voice_name = config[ - "speech_synthesis_voice_name" - ] - self.speech_config.speech_recognition_language = config[ - "speech_recognition_language" - ] + self.speech_config = speechsdk.SpeechConfig(subscription=self.api_key, region=self.api_region) + self.speech_config.speech_synthesis_voice_name = self.config["speech_synthesis_voice_name"] + self.speech_config.speech_recognition_language = self.config["speech_recognition_language"] except Exception as e: logger.warn("AzureVoice init failed: %s, ignore " % e) def voiceToText(self, voice_file): audio_config = speechsdk.AudioConfig(filename=voice_file) - speech_recognizer = speechsdk.SpeechRecognizer( - speech_config=self.speech_config, audio_config=audio_config - ) + speech_recognizer = speechsdk.SpeechRecognizer(speech_config=self.speech_config, audio_config=audio_config) result = speech_recognizer.recognize_once() if result.reason == speechsdk.ResultReason.RecognizedSpeech: - logger.info( - "[Azure] voiceToText voice file name={} text={}".format( - voice_file, result.text - ) - ) + 
logger.info("[Azure] voiceToText voice file name={} text={}".format(voice_file, result.text)) reply = Reply(ReplyType.TEXT, result.text) else: - logger.error( - "[Azure] voiceToText error, result={}, canceldetails={}".format( - result, result.cancellation_details - ) - ) + logger.error("[Azure] voiceToText error, result={}, canceldetails={}".format(result, result.cancellation_details)) reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败") return reply def textToVoice(self, text): - fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav" + if self.config.get("auto_detect"): + lang = classify(text)[0] + key = "speech_synthesis_" + lang + if key in self.config: + logger.info("[Azure] textToVoice auto detect language={}, voice={}".format(lang, self.config[key])) + self.speech_config.speech_synthesis_voice_name = self.config[key] + else: + self.speech_config.speech_synthesis_voice_name = self.config["speech_synthesis_voice_name"] + else: + self.speech_config.speech_synthesis_voice_name = self.config["speech_synthesis_voice_name"] + # Avoid the same filename under multithreading + fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".wav" audio_config = speechsdk.AudioConfig(filename=fileName) - speech_synthesizer = speechsdk.SpeechSynthesizer( - speech_config=self.speech_config, audio_config=audio_config - ) + speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config) result = speech_synthesizer.speak_text(text) if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted: - logger.info( - "[Azure] textToVoice text={} voice file name={}".format(text, fileName) - ) + logger.info("[Azure] textToVoice text={} voice file name={}".format(text, fileName)) reply = Reply(ReplyType.VOICE, fileName) else: - logger.error( - "[Azure] textToVoice error, result={}, canceldetails={}".format( - result, result.cancellation_details - ) - ) + logger.error("[Azure] textToVoice 
error, result={}, canceldetails={}".format(result, result.cancellation_details)) reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败") return reply diff --git a/voice/azure/config.json.template b/voice/azure/config.json.template index 2dc2176f9..8f3f546f7 100644 --- a/voice/azure/config.json.template +++ b/voice/azure/config.json.template @@ -1,4 +1,12 @@ { "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", + "auto_detect": true, + "speech_synthesis_zh": "zh-CN-YunxiNeural", + "speech_synthesis_en": "en-US-JacobNeural", + "speech_synthesis_ja": "ja-JP-AoiNeural", + "speech_synthesis_ko": "ko-KR-SoonBokNeural", + "speech_synthesis_de": "de-DE-LouisaNeural", + "speech_synthesis_fr": "fr-FR-BrigitteNeural", + "speech_synthesis_es": "es-ES-LaiaNeural", "speech_recognition_language": "zh-CN" } diff --git a/voice/baidu/baidu_voice.py b/voice/baidu/baidu_voice.py index ccde6c4d1..406157b96 100644 --- a/voice/baidu/baidu_voice.py +++ b/voice/baidu/baidu_voice.py @@ -82,12 +82,11 @@ def textToVoice(self, text): {"spd": self.spd, "pit": self.pit, "vol": self.vol, "per": self.per}, ) if not isinstance(result, dict): - fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3" + # Avoid the same filename under multithreading + fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3" with open(fileName, "wb") as f: f.write(result) - logger.info( - "[Baidu] textToVoice text={} voice file name={}".format(text, fileName) - ) + logger.info("[Baidu] textToVoice text={} voice file name={}".format(text, fileName)) reply = Reply(ReplyType.VOICE, fileName) else: logger.error("[Baidu] textToVoice error={}".format(result)) diff --git a/voice/voice_factory.py b/voice/factory.py similarity index 100% rename from voice/voice_factory.py rename to voice/factory.py diff --git a/voice/google/google_voice.py b/voice/google/google_voice.py index 4f7b8ade3..6dcadad3d 100644 --- a/voice/google/google_voice.py +++ b/voice/google/google_voice.py 
@@ -24,11 +24,7 @@ def voiceToText(self, voice_file): audio = self.recognizer.record(source) try: text = self.recognizer.recognize_google(audio, language="zh-CN") - logger.info( - "[Google] voiceToText text={} voice file name={}".format( - text, voice_file - ) - ) + logger.info("[Google] voiceToText text={} voice file name={}".format(text, voice_file)) reply = Reply(ReplyType.TEXT, text) except speech_recognition.UnknownValueError: reply = Reply(ReplyType.ERROR, "抱歉,我听不懂") @@ -39,12 +35,11 @@ def voiceToText(self, voice_file): def textToVoice(self, text): try: - mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3" + # Avoid the same filename under multithreading + mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3" tts = gTTS(text=text, lang="zh") tts.save(mp3File) - logger.info( - "[Google] textToVoice text={} voice file name={}".format(text, mp3File) - ) + logger.info("[Google] textToVoice text={} voice file name={}".format(text, mp3File)) reply = Reply(ReplyType.VOICE, mp3File) except Exception as e: reply = Reply(ReplyType.ERROR, str(e)) diff --git a/voice/openai/openai_voice.py b/voice/openai/openai_voice.py index 06c221b21..b02d92651 100644 --- a/voice/openai/openai_voice.py +++ b/voice/openai/openai_voice.py @@ -22,11 +22,7 @@ def voiceToText(self, voice_file): result = openai.Audio.transcribe("whisper-1", file) text = result["text"] reply = Reply(ReplyType.TEXT, text) - logger.info( - "[Openai] voiceToText text={} voice file name={}".format( - text, voice_file - ) - ) + logger.info("[Openai] voiceToText text={} voice file name={}".format(text, voice_file)) except Exception as e: reply = Reply(ReplyType.ERROR, str(e)) finally: diff --git a/voice/pytts/pytts_voice.py b/voice/pytts/pytts_voice.py index 072e28b41..bd70086d3 100644 --- a/voice/pytts/pytts_voice.py +++ b/voice/pytts/pytts_voice.py @@ -2,6 +2,8 @@ pytts voice service (offline) """ +import os +import sys import time import 
pyttsx3 @@ -20,19 +22,42 @@ def __init__(self): self.engine.setProperty("rate", 125) # 音量 self.engine.setProperty("volume", 1.0) - for voice in self.engine.getProperty("voices"): - if "Chinese" in voice.name: - self.engine.setProperty("voice", voice.id) + if sys.platform == "win32": + for voice in self.engine.getProperty("voices"): + if "Chinese" in voice.name: + self.engine.setProperty("voice", voice.id) + else: + self.engine.setProperty("voice", "zh") + # If the problem of espeak is fixed, using runAndWait() and remove this startLoop() + # TODO: check if this is work on win32 + self.engine.startLoop(useDriverLoop=False) def textToVoice(self, text): try: - wavFile = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav" + # Avoid the same filename under multithreading + wavFileName = "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".wav" + wavFile = TmpDir().path() + wavFileName + logger.info("[Pytts] textToVoice text={} voice file name={}".format(text, wavFile)) + self.engine.save_to_file(text, wavFile) - self.engine.runAndWait() - logger.info( - "[Pytts] textToVoice text={} voice file name={}".format(text, wavFile) - ) + + if sys.platform == "win32": + self.engine.runAndWait() + else: + # In ubuntu, runAndWait do not really wait until the file created. + # It will return once the task queue is empty, but the task is still running in coroutine. + # And if you call runAndWait() and time.sleep() twice, it will stuck, so do not use this. + # If you want to fix this, add self._proxy.setBusy(True) in line 127 in espeak.py, at the beginning of the function save_to_file. + # self.engine.runAndWait() + + # Before espeak fix this problem, we iterate the generator and control the waiting by ourself. + # But this is not the canonical way to use it, for example if the file already exists it also cannot wait. 
+ self.engine.iterate() + while self.engine.isBusy() or wavFileName not in os.listdir(TmpDir().path()): + time.sleep(0.1) + reply = Reply(ReplyType.VOICE, wavFile) + except Exception as e: reply = Reply(ReplyType.ERROR, str(e)) finally: