From c0973fa83bd2af1f5a5a15588006a848dd49c3e9 Mon Sep 17 00:00:00 2001 From: rf_tar_railt <3165388245@qq.com> Date: Sat, 15 Jun 2024 20:03:23 +0800 Subject: [PATCH] :bug: version 0.12.2 fix bot disconnect --- nonebot/adapters/satori/adapter.py | 13 ++++++++++++- pyproject.toml | 2 +- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/nonebot/adapters/satori/adapter.py b/nonebot/adapters/satori/adapter.py index 81dcf11..353082e 100644 --- a/nonebot/adapters/satori/adapter.py +++ b/nonebot/adapters/satori/adapter.py @@ -1,5 +1,6 @@ import json import asyncio +from collections import defaultdict from typing_extensions import override from typing import Any, Literal, Optional @@ -49,6 +50,7 @@ def __init__(self, driver: Driver, **kwargs: Any): self.satori_config: Config = get_plugin_config(Config) self.tasks: list[asyncio.Task] = [] # 存储 ws 任务 self.sequences: dict[str, int] = {} # 存储 连接序列号 + self._bots: defaultdict[str, set[str]] = defaultdict(set) # 存储 identity 和 bot_id 的映射 self.setup() @classmethod @@ -89,6 +91,9 @@ async def shutdown(self) -> None: *(asyncio.wait_for(task, timeout=10) for task in self.tasks), return_exceptions=True, ) + self.tasks.clear() + self.sequences.clear() + self._bots.clear() @staticmethod def payload_to_json(payload: Payload) -> str: @@ -137,12 +142,14 @@ async def _authenticate(self, info: ClientInfo, ws: WebSocket) -> Optional[Liter continue if login.self_id not in self.bots: bot = Bot(self, login.self_id, login, info) + self._bots[info.identity].add(bot.self_id) self.bot_connect(bot) log( "INFO", f"Bot {escape_tag(bot.self_id)} connected", ) else: + self._bots[info.identity].add(login.self_id) bot = self.bots[login.self_id] bot._update(login) if not self.bots: @@ -196,10 +203,11 @@ async def ws(self, info: ClientInfo) -> None: if heartbeat_task: heartbeat_task.cancel() heartbeat_task = None - bots = list(self.bots.values()) + bots = [self.bots[bot_id] for bot_id in self._bots[info.identity]] for bot in bots: self.bot_disconnect(bot) bots.clear() + self._bots[info.identity].clear() except Exception as e: log( "ERROR", @@ -232,6 +240,7 @@ async def _loop(self, info: ClientInfo, ws: WebSocket): else: if isinstance(event, LoginAddedEvent): bot = Bot(self, event.self_id, event.login, info) + self._bots[info.identity].add(bot.self_id) self.bot_connect(bot) log( "INFO", @@ -239,6 +248,7 @@ async def _loop(self, info: ClientInfo, ws: WebSocket): ) elif isinstance(event, LoginRemovedEvent): self.bot_disconnect(self.bots[event.self_id]) + self._bots[info.identity].discard(event.self_id) log( "INFO", f"Bot {escape_tag(event.self_id)} disconnected", @@ -246,6 +256,7 @@ async def _loop(self, info: ClientInfo, ws: WebSocket): continue elif isinstance(event, LoginUpdatedEvent): self.bots[event.self_id]._update(event.login) + self._bots[info.identity].add(event.self_id) if not (bot := self.bots.get(event.self_id)): log( "WARNING", diff --git a/pyproject.toml b/pyproject.toml index 4540fa5..7a8c320 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "nonebot-adapter-satori" -version = "0.12.1" +version = "0.12.2" description = "Satori Protocol Adapter for Nonebot2" authors = [ {name = "RF-Tar-Railt",email = "rf_tar_railt@qq.com"},