Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: 重构驱动器 lifespan 方法 #1860

Merged
merged 3 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions nonebot/drivers/_lifespan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Any, List, Union, Callable, Awaitable, cast

from nonebot.utils import run_sync, is_coroutine_callable

SYNC_LIFESPAN_FUNC = Callable[[], Any]
ASYNC_LIFESPAN_FUNC = Callable[[], Awaitable[Any]]
LIFESPAN_FUNC = Union[SYNC_LIFESPAN_FUNC, ASYNC_LIFESPAN_FUNC]


class Lifespan:
def __init__(self) -> None:
self._startup_funcs: List[LIFESPAN_FUNC] = []
self._shutdown_funcs: List[LIFESPAN_FUNC] = []

def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._startup_funcs.append(func)
return func

def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
self._shutdown_funcs.append(func)
return func

@staticmethod
async def _run_lifespan_func(
funcs: List[LIFESPAN_FUNC],
) -> None:
for func in funcs:
if is_coroutine_callable(func):
await cast(ASYNC_LIFESPAN_FUNC, func)()
else:
await run_sync(cast(SYNC_LIFESPAN_FUNC, func))()

async def startup(self) -> None:
if self._startup_funcs:
await self._run_lifespan_func(self._startup_funcs)

async def shutdown(self) -> None:
if self._shutdown_funcs:
await self._run_lifespan_func(self._shutdown_funcs)

async def __aenter__(self) -> None:
await self.startup()

async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.shutdown()
2 changes: 1 addition & 1 deletion nonebot/drivers/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

try:
import aiohttp
except ImportError as e: # pragma: no cover
except ModuleNotFoundError as e: # pragma: no cover
raise ImportError(
"Please install aiohttp first to use this driver. `pip install nonebot2[aiohttp]`"
) from e
Expand Down
27 changes: 19 additions & 8 deletions nonebot/drivers/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging
import contextlib
from functools import wraps
from typing import Any, Dict, List, Tuple, Union, Callable, Optional
from typing import Any, Dict, List, Tuple, Union, Optional

from pydantic import BaseSettings

Expand All @@ -32,12 +32,14 @@
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup

from ._lifespan import LIFESPAN_FUNC, Lifespan

try:
import uvicorn
from fastapi.responses import Response
from fastapi import FastAPI, Request, UploadFile, status
from starlette.websockets import WebSocket, WebSocketState, WebSocketDisconnect
except ImportError as e: # pragma: no cover
except ModuleNotFoundError as e: # pragma: no cover
raise ImportError(
"Please install FastAPI by using `pip install nonebot2[fastapi]`"
) from e
Expand Down Expand Up @@ -92,7 +94,10 @@ def __init__(self, env: Env, config: NoneBotConfig):

self.fastapi_config: Config = Config(**config.dict())

self._lifespan = Lifespan()

self._server_app = FastAPI(
lifespan=self._lifespan_manager,
openapi_url=self.fastapi_config.fastapi_openapi_url,
docs_url=self.fastapi_config.fastapi_docs_url,
redoc_url=self.fastapi_config.fastapi_redoc_url,
Expand Down Expand Up @@ -148,14 +153,20 @@ async def _handle(websocket: WebSocket) -> None:
)

@overrides(ReverseDriver)
def on_startup(self, func: Callable) -> Callable:
"""参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_"""
return self.server_app.on_event("startup")(func)
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self._lifespan.on_startup(func)

@overrides(ReverseDriver)
def on_shutdown(self, func: Callable) -> Callable:
"""参考文档: `Events <https://fastapi.tiangolo.com/advanced/events/#shutdown-event>`_"""
return self.server_app.on_event("shutdown")(func)
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
return self._lifespan.on_shutdown(func)

@contextlib.asynccontextmanager
async def _lifespan_manager(self, app: FastAPI):
await self._lifespan.startup()
try:
yield
finally:
await self._lifespan.shutdown()

@overrides(ReverseDriver)
def run(
Expand Down
2 changes: 1 addition & 1 deletion nonebot/drivers/httpx.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

try:
import httpx
except ImportError as e: # pragma: no cover
except ModuleNotFoundError as e: # pragma: no cover
raise ImportError(
"Please install httpx by using `pip install nonebot2[httpx]`"
) from e
Expand Down
64 changes: 24 additions & 40 deletions nonebot/drivers/none.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import signal
import asyncio
import threading
from typing import Set, Union, Callable, Awaitable, cast

from nonebot.log import logger
from nonebot.consts import WINDOWS
Expand All @@ -22,7 +21,8 @@
from nonebot.drivers import Driver as BaseDriver
from nonebot.utils import run_sync, is_coroutine_callable

HOOK_FUNC = Union[Callable[[], None], Callable[[], Awaitable[None]]]
from ._lifespan import LIFESPAN_FUNC, Lifespan

HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
Expand All @@ -36,8 +36,9 @@ class Driver(BaseDriver):

def __init__(self, env: Env, config: Config):
super().__init__(env, config)
self.startup_funcs: Set[HOOK_FUNC] = set()
self.shutdown_funcs: Set[HOOK_FUNC] = set()

self._lifespan = Lifespan()

self.should_exit: asyncio.Event = asyncio.Event()
self.force_exit: bool = False

Expand All @@ -54,20 +55,18 @@ def logger(self):
return logger

@overrides(BaseDriver)
def on_startup(self, func: HOOK_FUNC) -> HOOK_FUNC:
def on_startup(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""
注册一个启动时执行的函数
"""
self.startup_funcs.add(func)
return func
return self._lifespan.on_startup(func)

@overrides(BaseDriver)
def on_shutdown(self, func: HOOK_FUNC) -> HOOK_FUNC:
def on_shutdown(self, func: LIFESPAN_FUNC) -> LIFESPAN_FUNC:
"""
注册一个停止时执行的函数
"""
self.shutdown_funcs.add(func)
return func
return self._lifespan.on_shutdown(func)

@overrides(BaseDriver)
def run(self, *args, **kwargs):
Expand All @@ -85,21 +84,13 @@ async def _serve(self):
await self._shutdown()

async def _startup(self):
# run startup
cors = [
cast(Callable[..., Awaitable[None]], startup)()
if is_coroutine_callable(startup)
else run_sync(startup)()
for startup in self.startup_funcs
]
if cors:
try:
await asyncio.gather(*cors)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running startup function. "
"Ignored!</bg #f8bbd0></r>"
)
try:
await self._lifespan.startup()
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running startup function. "
"Ignored!</bg #f8bbd0></r>"
)

logger.info("Application startup completed.")

Expand All @@ -110,21 +101,14 @@ async def _shutdown(self):
logger.info("Shutting down")

logger.info("Waiting for application shutdown.")
# run shutdown
cors = [
cast(Callable[..., Awaitable[None]], shutdown)()
if is_coroutine_callable(shutdown)
else run_sync(shutdown)()
for shutdown in self.shutdown_funcs
]
if cors:
try:
await asyncio.gather(*cors)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running shutdown function. "
"Ignored!</bg #f8bbd0></r>"
)

try:
await self._lifespan.shutdown()
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running shutdown function. "
"Ignored!</bg #f8bbd0></r>"
)

for task in asyncio.all_tasks():
if task is not asyncio.current_task() and not task.done():
Expand Down
2 changes: 1 addition & 1 deletion nonebot/drivers/quart.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from quart import Quart, Request, Response
from quart.datastructures import FileStorage
from quart import Websocket as QuartWebSocket
except ImportError as e: # pragma: no cover
except ModuleNotFoundError as e: # pragma: no cover
raise ImportError(
"Please install Quart by using `pip install nonebot2[quart]`"
) from e
Expand Down
2 changes: 1 addition & 1 deletion nonebot/drivers/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
try:
from websockets.exceptions import ConnectionClosed
from websockets.legacy.client import Connect, WebSocketClientProtocol
except ImportError as e: # pragma: no cover
except ModuleNotFoundError as e: # pragma: no cover
raise ImportError(
"Please install websockets by using `pip install nonebot2[websockets]`"
) from e
Expand Down
34 changes: 34 additions & 0 deletions tests/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from nonebot import _resolve_combine_expr
from nonebot.dependencies import Dependent
from nonebot.exception import WebSocketClosed
from nonebot.drivers._lifespan import Lifespan
from nonebot.drivers import (
URL,
Driver,
Expand All @@ -36,6 +37,39 @@ def load_driver(request: pytest.FixtureRequest) -> Driver:
return DriverClass(Env(environment=global_driver.env), global_driver.config)


@pytest.mark.asyncio
async def test_lifespan():
lifespan = Lifespan()

start_log = []
shutdown_log = []

@lifespan.on_startup
async def _startup1():
assert start_log == []
start_log.append(1)

@lifespan.on_startup
async def _startup2():
assert start_log == [1]
start_log.append(2)

@lifespan.on_shutdown
async def _shutdown1():
assert shutdown_log == []
shutdown_log.append(1)

@lifespan.on_shutdown
async def _shutdown2():
assert shutdown_log == [1]
shutdown_log.append(2)

async with lifespan:
assert start_log == [1, 2]

assert shutdown_log == [1, 2]


@pytest.mark.asyncio
@pytest.mark.parametrize(
"driver",
Expand Down