diff --git a/src/dockerflow/checks/registry.py b/src/dockerflow/checks/registry.py index 6670104..36b3155 100644 --- a/src/dockerflow/checks/registry.py +++ b/src/dockerflow/checks/registry.py @@ -17,19 +17,6 @@ _REGISTERED_CHECKS = {} -def _iscoroutinefunction_or_partial(obj): - """ - Determine if the provided object is a coroutine function or a partial function - that wraps a coroutine function. - - This function should be removed when we drop support for Python 3.7, as this is - handled directly by `inspect.iscoroutinefunction` in Python 3.8. - """ - while isinstance(obj, functools.partial): - obj = obj.func - return inspect.iscoroutinefunction(obj) - - def register(func=None, name=None): """ Register a check callback to be executed from @@ -43,7 +30,7 @@ def register(func=None, name=None): logger.debug("Register Dockerflow check %s", name) - if _iscoroutinefunction_or_partial(func): + if inspect.iscoroutinefunction(func): @functools.wraps(func) async def decorated_function_asyc(*args, **kwargs): @@ -116,7 +103,7 @@ class ChecksResults: async def _run_check_async(check): name, check_fn = check - if _iscoroutinefunction_or_partial(check_fn): + if inspect.iscoroutinefunction(check_fn): errors = await check_fn() else: loop = asyncio.get_event_loop() diff --git a/src/dockerflow/fastapi/views.py b/src/dockerflow/fastapi/views.py index f2dc1cd..96af198 100644 --- a/src/dockerflow/fastapi/views.py +++ b/src/dockerflow/fastapi/views.py @@ -11,12 +11,12 @@ def lbheartbeat(): return {"status": "ok"} -def heartbeat(request: Request, response: Response): +async def heartbeat(request: Request, response: Response): FAILED_STATUS_CODE = int( getattr(request.app.state, "DOCKERFLOW_HEARTBEAT_FAILED_STATUS_CODE", "500") ) - check_results = checks.run_checks( + check_results = await checks.run_checks_async( checks.get_checks().items(), ) diff --git a/tests/fastapi/test_fastapi.py b/tests/fastapi/test_fastapi.py index 4ec97de..73a008c 100644 --- a/tests/fastapi/test_fastapi.py +++ b/tests/fastapi/test_fastapi.py @@ -194,6 +194,53 @@ def return_error(): }, } +def test_heartbeat_sync(client): + @checks.register + def sync_ok(): + return [] + + response = client.get("/__heartbeat__") + assert response.status_code == 200 + assert response.json() == { + "status": "ok", + "checks": {"sync_ok": "ok"}, + "details": {}, + } + + +def test_heartbeat_async(client): + @checks.register + async def async_ok(): + return [] + + response = client.get("/__heartbeat__") + assert response.status_code == 200 + assert response.json() == { + "status": "ok", + "checks": {"async_ok": "ok"}, + "details": {}, + } + + +def test_heartbeat_mixed_sync(client): + @checks.register + def sync_ok(): + return [] + @checks.register + async def async_ok(): + return [] + + response = client.get("/__heartbeat__") + assert response.status_code == 200 + assert response.json() == { + "status": "ok", + "checks": { + "sync_ok": "ok", + "async_ok": "ok", + }, + "details": {}, + } + def test_heartbeat_head(client): @checks.register