Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: 添加 State 响应器触发消息注入 #1315

Merged
merged 10 commits into from
Oct 12, 2022
8 changes: 8 additions & 0 deletions nonebot/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,11 @@
"""正则匹配 group 元组存储 key"""
REGEX_DICT: Literal["_matched_dict"] = "_matched_dict"
"""正则匹配 group 字典存储 key"""
STARTSWITH_KEY: Literal["_startswith"] = "_startswith"
"""响应触发前缀 key"""
ENDSWITH_KEY: Literal["_endswith"] = "_endswith"
"""响应触发后缀 key"""
FULLMATCH_KEY: Literal["_fullmatch"] = "_fullmatch"
"""响应触发完整消息 key"""
KEYWORD_KEY: Literal["_keyword"] = "_keyword"
"""响应触发关键字 key"""
40 changes: 40 additions & 0 deletions nonebot/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@
SHELL_ARGS,
SHELL_ARGV,
CMD_ARG_KEY,
KEYWORD_KEY,
RAW_CMD_KEY,
REGEX_GROUP,
ENDSWITH_KEY,
CMD_START_KEY,
FULLMATCH_KEY,
REGEX_MATCHED,
STARTSWITH_KEY,
)


Expand Down Expand Up @@ -153,6 +157,42 @@ def RegexDict() -> Dict[str, Any]:
return Depends(_regex_dict, use_cache=False)


def _startswith(state: T_State) -> str:
return state[STARTSWITH_KEY]


def Startswith() -> str:
"""响应触发前缀"""
return Depends(_startswith, use_cache=False)


def _endswith(state: T_State) -> str:
return state[ENDSWITH_KEY]


def Endswith() -> str:
"""响应触发后缀"""
return Depends(_endswith, use_cache=False)


def _fullmatch(state: T_State) -> str:
return state[FULLMATCH_KEY]


def Fullmatch() -> str:
"""响应触发完整消息"""
return Depends(_fullmatch, use_cache=False)


def _keyword(state: T_State) -> str:
return state[KEYWORD_KEY]


def Keyword() -> str:
"""响应触发关键字"""
return Depends(_keyword, use_cache=False)


def Received(id: Optional[str] = None, default: Any = None) -> Any:
"""`receive` 事件参数"""

Expand Down
58 changes: 37 additions & 21 deletions nonebot/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from typing import (
IO,
TYPE_CHECKING,
Any,
List,
Type,
Tuple,
Expand Down Expand Up @@ -48,10 +47,14 @@
SHELL_ARGS,
SHELL_ARGV,
CMD_ARG_KEY,
KEYWORD_KEY,
RAW_CMD_KEY,
REGEX_GROUP,
ENDSWITH_KEY,
CMD_START_KEY,
FULLMATCH_KEY,
REGEX_MATCHED,
STARTSWITH_KEY,
)

T = TypeVar("T")
Expand Down Expand Up @@ -136,20 +139,21 @@ def __eq__(self, other: object) -> bool:
def __hash__(self) -> int:
return hash((frozenset(self.msg), self.ignorecase))

async def __call__(self, event: Event) -> bool:
async def __call__(self, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
try:
text = event.get_plaintext()
except Exception:
return False
return bool(
re.match(
f"^(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})",
text,
re.IGNORECASE if self.ignorecase else 0,
)
)
if match := re.match(
f"^(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})",
text,
re.IGNORECASE if self.ignorecase else 0,
):
state[STARTSWITH_KEY] = match.group()
return True
return False


def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule:
Expand Down Expand Up @@ -192,20 +196,21 @@ def __eq__(self, other: object) -> bool:
def __hash__(self) -> int:
return hash((frozenset(self.msg), self.ignorecase))

async def __call__(self, event: Event) -> bool:
async def __call__(self, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
try:
text = event.get_plaintext()
except Exception:
return False
return bool(
re.search(
f"(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})$",
text,
re.IGNORECASE if self.ignorecase else 0,
)
)
if match := re.search(
f"(?:{'|'.join(re.escape(suffix) for suffix in self.msg)})$",
text,
re.IGNORECASE if self.ignorecase else 0,
):
state[ENDSWITH_KEY] = match.group()
return True
return False


def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule:
Expand Down Expand Up @@ -248,14 +253,20 @@ def __eq__(self, other: object) -> bool:
def __hash__(self) -> int:
return hash((frozenset(self.msg), self.ignorecase))

async def __call__(self, event: Event) -> bool:
async def __call__(self, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
try:
text = event.get_plaintext()
except Exception:
return False
return (text.casefold() if self.ignorecase else text) in self.msg
if not text:
return False
text = text.casefold() if self.ignorecase else text
if text in self.msg:
state[FULLMATCH_KEY] = text
return True
return False


def fullmatch(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule:
Expand Down Expand Up @@ -294,14 +305,19 @@ def __eq__(self, other: object) -> bool:
def __hash__(self) -> int:
return hash(frozenset(self.keywords))

async def __call__(self, event: Event) -> bool:
async def __call__(self, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
try:
text = event.get_plaintext()
except Exception:
return False
return bool(text and any(keyword in text for keyword in self.keywords))
if not text:
return False
if key := next((k for k in self.keywords if k in text), None):
state[KEYWORD_KEY] = key
return True
return False


def keyword(*keywords: str) -> Rule:
Expand Down
20 changes: 20 additions & 0 deletions tests/plugins/param/param_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
from nonebot.adapters import Message
from nonebot.params import (
Command,
Keyword,
Endswith,
Fullmatch,
RegexDict,
CommandArg,
RawCommand,
RegexGroup,
Startswith,
CommandStart,
RegexMatched,
ShellCommandArgs,
Expand Down Expand Up @@ -65,3 +69,19 @@ async def regex_group(regex_group: Tuple = RegexGroup()) -> Tuple:

async def regex_matched(regex_matched: str = RegexMatched()) -> str:
return regex_matched


async def startswith(startswith: str = Startswith()) -> str:
return startswith


async def endswith(endswith: str = Endswith()) -> str:
return endswith


async def fullmatch(fullmatch: str = Fullmatch()) -> str:
return fullmatch


async def keyword(keyword: str = Keyword()) -> str:
return keyword
36 changes: 36 additions & 0 deletions tests/test_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,23 @@ async def test_state(app: App, load_plugin):
SHELL_ARGS,
SHELL_ARGV,
CMD_ARG_KEY,
KEYWORD_KEY,
RAW_CMD_KEY,
REGEX_GROUP,
ENDSWITH_KEY,
CMD_START_KEY,
FULLMATCH_KEY,
REGEX_MATCHED,
STARTSWITH_KEY,
)
from plugins.param.param_state import (
state,
command,
keyword,
endswith,
fullmatch,
regex_dict,
startswith,
command_arg,
raw_command,
regex_group,
Expand All @@ -201,6 +209,10 @@ async def test_state(app: App, load_plugin):
REGEX_MATCHED: "[cq:test,arg=value]",
REGEX_GROUP: ("test", "arg=value"),
REGEX_DICT: {"type": "test", "arg": "value"},
STARTSWITH_KEY: "startswith",
ENDSWITH_KEY: "endswith",
FULLMATCH_KEY: "fullmatch",
KEYWORD_KEY: "keyword",
}

async with app.test_dependent(state, allow_types=[StateParam]) as ctx:
Expand Down Expand Up @@ -271,6 +283,30 @@ async def test_state(app: App, load_plugin):
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state[REGEX_DICT])

async with app.test_dependent(
startswith, allow_types=[StateParam, DependParam]
) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state[STARTSWITH_KEY])

async with app.test_dependent(
endswith, allow_types=[StateParam, DependParam]
) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state[ENDSWITH_KEY])

async with app.test_dependent(
fullmatch, allow_types=[StateParam, DependParam]
) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state[FULLMATCH_KEY])

async with app.test_dependent(
keyword, allow_types=[StateParam, DependParam]
) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state[KEYWORD_KEY])


@pytest.mark.asyncio
async def test_matcher(app: App, load_plugin):
Expand Down
Loading