Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: 补充依赖注入部分情况下类型错误时的日志提示 #2343

Merged
merged 1 commit into from
Sep 4, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 17 additions & 22 deletions nonebot/dependencies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,21 @@ def __repr__(self) -> str:
)

async def __call__(self, **kwargs: Any) -> R:
# do pre-check
await self.check(**kwargs)
try:
# do pre-check
await self.check(**kwargs)

# solve param values
values = await self.solve(**kwargs)
# solve param values
values = await self.solve(**kwargs)

# call function
if is_coroutine_callable(self.call):
return await cast(Callable[..., Awaitable[R]], self.call)(**values)
else:
return await run_sync(cast(Callable[..., R], self.call))(**values)
# call function
if is_coroutine_callable(self.call):
return await cast(Callable[..., Awaitable[R]], self.call)(**values)
else:
return await run_sync(cast(Callable[..., R], self.call))(**values)
except SkippedException as e:
logger.trace(f"{self} skipped due to {e}")
raise

@staticmethod
def parse_params(
Expand Down Expand Up @@ -195,19 +199,10 @@ def parse(
return cls(call, params, parameterless_params)

async def check(self, **params: Any) -> None:
try:
await asyncio.gather(
*(param._check(**params) for param in self.parameterless)
)
await asyncio.gather(
*(
cast(Param, param.field_info)._check(**params)
for param in self.params
)
)
except SkippedException as e:
logger.trace(f"{self} skipped due to {e}")
raise
await asyncio.gather(*(param._check(**params) for param in self.parameterless))
await asyncio.gather(
*(cast(Param, param.field_info)._check(**params) for param in self.params)
)

async def _solve_field(self, field: ModelField, params: Dict[str, Any]) -> Any:
param = cast(Param, field.field_info)
Expand Down