From c0973fa83bd2af1f5a5a15588006a848dd49c3e9 Mon Sep 17 00:00:00 2001
From: rf_tar_railt <3165388245@qq.com>
Date: Sat, 15 Jun 2024 20:03:23 +0800
Subject: [PATCH] :bug: version 0.12.2 fix bot disconnect
---
nonebot/adapters/satori/adapter.py | 13 ++++++++++++-
pyproject.toml | 2 +-
2 files changed, 13 insertions(+), 2 deletions(-)
diff --git a/nonebot/adapters/satori/adapter.py b/nonebot/adapters/satori/adapter.py
index 81dcf11..353082e 100644
--- a/nonebot/adapters/satori/adapter.py
+++ b/nonebot/adapters/satori/adapter.py
@@ -1,5 +1,6 @@
import json
import asyncio
+from collections import defaultdict
from typing_extensions import override
from typing import Any, Literal, Optional
@@ -49,6 +50,7 @@ def __init__(self, driver: Driver, **kwargs: Any):
self.satori_config: Config = get_plugin_config(Config)
self.tasks: list[asyncio.Task] = [] # 存储 ws 任务
self.sequences: dict[str, int] = {} # 存储 连接序列号
+ self._bots: defaultdict[str, set[str]] = defaultdict(set) # 存储 identity 和 bot_id 的映射
self.setup()
@classmethod
@@ -89,6 +91,9 @@ async def shutdown(self) -> None:
*(asyncio.wait_for(task, timeout=10) for task in self.tasks),
return_exceptions=True,
)
+ self.tasks.clear()
+ self.sequences.clear()
+ self._bots.clear()
@staticmethod
def payload_to_json(payload: Payload) -> str:
@@ -137,12 +142,14 @@ async def _authenticate(self, info: ClientInfo, ws: WebSocket) -> Optional[Liter
continue
if login.self_id not in self.bots:
bot = Bot(self, login.self_id, login, info)
+ self._bots[info.identity].add(bot.self_id)
self.bot_connect(bot)
log(
"INFO",
f"Bot {escape_tag(bot.self_id)} connected",
)
else:
+ self._bots[info.identity].add(login.self_id)
bot = self.bots[login.self_id]
bot._update(login)
if not self.bots:
@@ -196,10 +203,11 @@ async def ws(self, info: ClientInfo) -> None:
if heartbeat_task:
heartbeat_task.cancel()
heartbeat_task = None
- bots = list(self.bots.values())
+ bots = [self.bots[bot_id] for bot_id in self._bots[info.identity]]
for bot in bots:
self.bot_disconnect(bot)
bots.clear()
+ self._bots[info.identity].clear()
except Exception as e:
log(
"ERROR",
@@ -232,6 +240,7 @@ async def _loop(self, info: ClientInfo, ws: WebSocket):
else:
if isinstance(event, LoginAddedEvent):
bot = Bot(self, event.self_id, event.login, info)
+ self._bots[info.identity].add(bot.self_id)
self.bot_connect(bot)
log(
"INFO",
@@ -239,6 +248,7 @@ async def _loop(self, info: ClientInfo, ws: WebSocket):
)
elif isinstance(event, LoginRemovedEvent):
self.bot_disconnect(self.bots[event.self_id])
+ self._bots[info.identity].discard(event.self_id)
log(
"INFO",
f"Bot {escape_tag(event.self_id)} disconnected",
@@ -246,6 +256,7 @@ async def _loop(self, info: ClientInfo, ws: WebSocket):
continue
elif isinstance(event, LoginUpdatedEvent):
self.bots[event.self_id]._update(event.login)
+ self._bots[info.identity].add(event.self_id)
if not (bot := self.bots.get(event.self_id)):
log(
"WARNING",
diff --git a/pyproject.toml b/pyproject.toml
index 4540fa5..7a8c320 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "nonebot-adapter-satori"
-version = "0.12.1"
+version = "0.12.2"
description = "Satori Protocol Adapter for Nonebot2"
authors = [
{name = "RF-Tar-Railt",email = "rf_tar_railt@qq.com"},