From bb63ff155250a854783bcec85ba4702cf4ae3ca6 Mon Sep 17 00:00:00 2001 From: Mourits de Beer <31511766+ff137@users.noreply.github.com> Date: Tue, 5 Nov 2024 20:11:18 +0200 Subject: [PATCH] :sparkles: Handle NotFound and UnprocessableEntity errors in middleware (#3327) * :sparkles: Handle NotFound and UnprocessableEntity exceptions at info log level Signed-off-by: ff137 * :art: replace traceback print with exception log Signed-off-by: ff137 * :art: replace deprecated .warn with .warning Signed-off-by: ff137 * :art: setting ledger to read-only should not print error log Signed-off-by: ff137 * :sparkles: extract the marshmallow validation error message from the nested exception Signed-off-by: ff137 * :art: modify import for consistency Signed-off-by: ff137 * :white_check_mark: test coverage for new method Signed-off-by: ff137 * :white_check_mark: test coverage for exception handling changes Signed-off-by: ff137 * :art: refactor for reduced complexity Signed-off-by: ff137 * :art: reorder exceptions by severity Signed-off-by: ff137 * :art: update log level Signed-off-by: ff137 * :art: Signed-off-by: ff137 --------- Signed-off-by: ff137 Co-authored-by: jamshale <31809382+jamshale@users.noreply.github.com> --- acapy_agent/admin/server.py | 94 ++++++----- acapy_agent/admin/tests/test_admin_server.py | 158 ++++++++++++++++++ acapy_agent/askar/profile.py | 2 +- acapy_agent/askar/profile_anon.py | 2 +- .../protocols/out_of_band/v1_0/manager.py | 2 +- acapy_agent/utils/extract_validation_error.py | 16 ++ .../tests/test_extract_validation_error.py | 77 +++++++++ 7 files changed, 308 insertions(+), 43 deletions(-) create mode 100644 acapy_agent/utils/extract_validation_error.py create mode 100644 acapy_agent/utils/tests/test_extract_validation_error.py diff --git a/acapy_agent/admin/server.py b/acapy_agent/admin/server.py index 883db0dda4..84650ed11b 100644 --- a/acapy_agent/admin/server.py +++ b/acapy_agent/admin/server.py @@ -13,8 +13,6 @@ from aiohttp_apispec import setup_aiohttp_apispec, validation_middleware from uuid_utils import uuid4 -from acapy_agent.wallet import singletons - from ..config.injection_context import InjectionContext from ..config.logging import context_wallet_id from ..core.event_bus import Event, EventBus @@ -31,9 +29,11 @@ from ..transport.outbound.status import OutboundSendStatus from ..transport.queue.basic import BasicMessageQueue from ..utils import general as general_utils +from ..utils.extract_validation_error import extract_validation_error_message from ..utils.stats import Collector from ..utils.task_queue import TaskQueue from ..version import __version__ +from ..wallet import singletons from ..wallet.anoncreds_upgrade import check_upgrade_completion_loop from .base_server import BaseAdminServer from .error import AdminSetupError @@ -68,6 +68,8 @@ anoncreds_wallets = singletons.IsAnoncredsSingleton().wallets in_progress_upgrades = singletons.UpgradeInProgressSingleton() +status_paths = ("/status/live", "/status/ready") + class AdminResponder(BaseResponder): """Handle outgoing messages from message handlers.""" @@ -134,44 +136,56 @@ def send_fn(self) -> Coroutine: async def ready_middleware(request: web.BaseRequest, handler: Coroutine): """Only continue if application is ready to take work.""" - if str(request.rel_url).rstrip("/") in ( - "/status/live", - "/status/ready", - ) or request.app._state.get("ready"): - try: - return await handler(request) - except (LedgerConfigError, LedgerTransactionError) as e: - # fatal, signal server shutdown - LOGGER.error("Shutdown with %s", str(e)) - request.app._state["ready"] = False - request.app._state["alive"] = False - raise - except web.HTTPFound as e: - # redirect, typically / -> /api/doc - LOGGER.info("Handler redirect to: %s", e.location) - raise - except (web.HTTPUnauthorized, jwt.InvalidTokenError, InvalidTokenError) as e: - LOGGER.info( - "Unauthorized access during %s %s: %s", request.method, request.path, e - ) - raise web.HTTPUnauthorized(reason=str(e)) from e - except (web.HTTPBadRequest, MultitenantManagerError) as e: - LOGGER.info("Bad request during %s %s: %s", request.method, request.path, e) - raise web.HTTPBadRequest(reason=str(e)) from e - except asyncio.CancelledError: - # redirection spawns new task and cancels old - LOGGER.debug("Task cancelled") - raise - except Exception as e: - # some other error? - LOGGER.error("Handler error with exception: %s", str(e)) - import traceback - - print("\n=================") - traceback.print_exc() - raise - - raise web.HTTPServiceUnavailable(reason="Shutdown in progress") + is_status_check = str(request.rel_url).rstrip("/") in status_paths + is_app_ready = request.app._state.get("ready") + + if not (is_status_check or is_app_ready): + raise web.HTTPServiceUnavailable(reason="Shutdown in progress") + + try: + return await handler(request) + except web.HTTPFound as e: + # redirect, typically / -> /api/doc + LOGGER.info("Handler redirect to: %s", e.location) + raise + except asyncio.CancelledError: + # redirection spawns new task and cancels old + LOGGER.debug("Task cancelled") + raise + except (web.HTTPUnauthorized, jwt.InvalidTokenError, InvalidTokenError) as e: + LOGGER.info( + "Unauthorized access during %s %s: %s", request.method, request.path, e + ) + raise web.HTTPUnauthorized(reason=str(e)) from e + except (web.HTTPBadRequest, MultitenantManagerError) as e: + LOGGER.info("Bad request during %s %s: %s", request.method, request.path, e) + raise web.HTTPBadRequest(reason=str(e)) from e + except (web.HTTPNotFound, StorageNotFoundError) as e: + LOGGER.info( + "Not Found error occurred during %s %s: %s", + request.method, + request.path, + e, + ) + raise web.HTTPNotFound(reason=str(e)) from e + except web.HTTPUnprocessableEntity as e: + validation_error_message = extract_validation_error_message(e) + LOGGER.info( + "Unprocessable Entity occurred during %s %s: %s", + request.method, + request.path, + validation_error_message, + ) + raise + except (LedgerConfigError, LedgerTransactionError) as e: + # fatal, signal server shutdown + LOGGER.critical("Shutdown with %s", str(e)) + request.app._state["ready"] = False + request.app._state["alive"] = False + raise + except Exception as e: + LOGGER.exception("Handler error with exception:", exc_info=e) + raise @web.middleware diff --git a/acapy_agent/admin/tests/test_admin_server.py b/acapy_agent/admin/tests/test_admin_server.py index 896215183d..39c7c32d91 100644 --- a/acapy_agent/admin/tests/test_admin_server.py +++ b/acapy_agent/admin/tests/test_admin_server.py @@ -3,9 +3,11 @@ from typing import Optional from unittest import IsolatedAsyncioTestCase +import jwt import pytest from aiohttp import ClientSession, DummyCookieJar, TCPConnector, web from aiohttp.test_utils import unused_port +from marshmallow import ValidationError from acapy_agent.tests import mock from acapy_agent.wallet import singletons @@ -16,7 +18,9 @@ from ...core.goal_code_registry import GoalCodeRegistry from ...core.in_memory import InMemoryProfile from ...core.protocol_registry import ProtocolRegistry +from ...multitenant.error import MultitenantManagerError from ...storage.base import BaseStorage +from ...storage.error import StorageNotFoundError from ...storage.record import StorageRecord from ...storage.type import RECORD_TYPE_ACAPY_UPGRADING from ...utils.stats import Collector @@ -108,6 +112,160 @@ async def test_ready_middleware(self): with self.assertRaises(KeyError): await test_module.ready_middleware(request, handler) + async def test_ready_middleware_http_unauthorized(self): + """Test handling of web.HTTPUnauthorized and related exceptions.""" + with mock.patch.object(test_module, "LOGGER", mock.MagicMock()) as mock_logger: + mock_logger.info = mock.MagicMock() + + request = mock.MagicMock( + method="GET", + path="/unauthorized", + app=mock.MagicMock(_state={"ready": True}), + ) + + # Test web.HTTPUnauthorized + handler = mock.CoroutineMock( + side_effect=web.HTTPUnauthorized(reason="Unauthorized") + ) + with self.assertRaises(web.HTTPUnauthorized): + await test_module.ready_middleware(request, handler) + mock_logger.info.assert_called_with( + "Unauthorized access during %s %s: %s", + request.method, + request.path, + handler.side_effect, + ) + + # Test jwt.InvalidTokenError + handler = mock.CoroutineMock( + side_effect=jwt.InvalidTokenError("Invalid token") + ) + with self.assertRaises(web.HTTPUnauthorized): + await test_module.ready_middleware(request, handler) + mock_logger.info.assert_called_with( + "Unauthorized access during %s %s: %s", + request.method, + request.path, + handler.side_effect, + ) + + # Test InvalidTokenError + handler = mock.CoroutineMock( + side_effect=test_module.InvalidTokenError("Token error") + ) + with self.assertRaises(web.HTTPUnauthorized): + await test_module.ready_middleware(request, handler) + mock_logger.info.assert_called_with( + "Unauthorized access during %s %s: %s", + request.method, + request.path, + handler.side_effect, + ) + + async def test_ready_middleware_http_bad_request(self): + """Test handling of web.HTTPBadRequest and MultitenantManagerError.""" + with mock.patch.object(test_module, "LOGGER", mock.MagicMock()) as mock_logger: + mock_logger.info = mock.MagicMock() + + request = mock.MagicMock( + method="POST", + path="/bad-request", + app=mock.MagicMock(_state={"ready": True}), + ) + + # Test web.HTTPBadRequest + handler = mock.CoroutineMock( + side_effect=web.HTTPBadRequest(reason="Bad request") + ) + with self.assertRaises(web.HTTPBadRequest): + await test_module.ready_middleware(request, handler) + mock_logger.info.assert_called_with( + "Bad request during %s %s: %s", + request.method, + request.path, + handler.side_effect, + ) + + # Test MultitenantManagerError + handler = mock.CoroutineMock( + side_effect=MultitenantManagerError("Multitenant error") + ) + with self.assertRaises(web.HTTPBadRequest): + await test_module.ready_middleware(request, handler) + mock_logger.info.assert_called_with( + "Bad request during %s %s: %s", + request.method, + request.path, + handler.side_effect, + ) + + async def test_ready_middleware_http_not_found(self): + """Test handling of web.HTTPNotFound and StorageNotFoundError.""" + with mock.patch.object(test_module, "LOGGER", mock.MagicMock()) as mock_logger: + mock_logger.info = mock.MagicMock() + + request = mock.MagicMock( + method="GET", + path="/not-found", + app=mock.MagicMock(_state={"ready": True}), + ) + + # Test web.HTTPNotFound + handler = mock.CoroutineMock(side_effect=web.HTTPNotFound(reason="Not found")) + with self.assertRaises(web.HTTPNotFound): + await test_module.ready_middleware(request, handler) + mock_logger.info.assert_called_with( + "Not Found error occurred during %s %s: %s", + request.method, + request.path, + handler.side_effect, + ) + + # Test StorageNotFoundError + handler = mock.CoroutineMock( + side_effect=StorageNotFoundError("Item not found") + ) + with self.assertRaises(web.HTTPNotFound): + await test_module.ready_middleware(request, handler) + mock_logger.info.assert_called_with( + "Not Found error occurred during %s %s: %s", + request.method, + request.path, + handler.side_effect, + ) + + async def test_ready_middleware_http_unprocessable_entity(self): + """Test handling of web.HTTPUnprocessableEntity with nested ValidationError.""" + with mock.patch.object(test_module, "LOGGER", mock.MagicMock()) as mock_logger: + mock_logger.info = mock.MagicMock() + # Mock the extract_validation_error_message function + with mock.patch.object( + test_module, "extract_validation_error_message" + ) as mock_extract: + mock_extract.return_value = {"field": ["Invalid input"]} + + request = mock.MagicMock( + method="POST", + path="/unprocessable", + app=mock.MagicMock(_state={"ready": True}), + ) + + # Create a HTTPUnprocessableEntity exception with a nested ValidationError + validation_error = ValidationError({"field": ["Invalid input"]}) + http_error = web.HTTPUnprocessableEntity(reason="Unprocessable Entity") + http_error.__cause__ = validation_error + + handler = mock.CoroutineMock(side_effect=http_error) + with self.assertRaises(web.HTTPUnprocessableEntity): + await test_module.ready_middleware(request, handler) + mock_extract.assert_called_once_with(http_error) + mock_logger.info.assert_called_with( + "Unprocessable Entity occurred during %s %s: %s", + request.method, + request.path, + mock_extract.return_value, + ) + def get_admin_server( self, settings: Optional[dict] = None, context: Optional[InjectionContext] = None ) -> AdminServer: diff --git a/acapy_agent/askar/profile.py b/acapy_agent/askar/profile.py index 4622c3b629..60803c3a09 100644 --- a/acapy_agent/askar/profile.py +++ b/acapy_agent/askar/profile.py @@ -75,7 +75,7 @@ def init_ledger_pool(self): read_only = bool(self.settings.get("ledger.read_only", False)) socks_proxy = self.settings.get("ledger.socks_proxy") if read_only: - LOGGER.error("Note: setting ledger to read-only mode") + LOGGER.warning("Note: setting ledger to read-only mode") genesis_transactions = self.settings.get("ledger.genesis_transactions") cache = self.context.injector.inject_or(BaseCache) self.ledger_pool = IndyVdrLedgerPool( diff --git a/acapy_agent/askar/profile_anon.py b/acapy_agent/askar/profile_anon.py index 6b526e634a..eea73d876f 100644 --- a/acapy_agent/askar/profile_anon.py +++ b/acapy_agent/askar/profile_anon.py @@ -77,7 +77,7 @@ def init_ledger_pool(self): read_only = bool(self.settings.get("ledger.read_only", False)) socks_proxy = self.settings.get("ledger.socks_proxy") if read_only: - LOGGER.error("Note: setting ledger to read-only mode") + LOGGER.warning("Note: setting ledger to read-only mode") genesis_transactions = self.settings.get("ledger.genesis_transactions") cache = self.context.injector.inject_or(BaseCache) self.ledger_pool = IndyVdrLedgerPool( diff --git a/acapy_agent/protocols/out_of_band/v1_0/manager.py b/acapy_agent/protocols/out_of_band/v1_0/manager.py index bcaa177971..5b0bbf90ec 100644 --- a/acapy_agent/protocols/out_of_band/v1_0/manager.py +++ b/acapy_agent/protocols/out_of_band/v1_0/manager.py @@ -461,7 +461,7 @@ async def handle_use_did_method( did_method = PEER4 if did_peer_4 else PEER2 my_info = await self.oob.fetch_invitation_reuse_did(did_method) if not my_info: - LOGGER.warn("No invitation DID found, creating new DID") + LOGGER.warning("No invitation DID found, creating new DID") if not my_info: did_metadata = ( diff --git a/acapy_agent/utils/extract_validation_error.py b/acapy_agent/utils/extract_validation_error.py new file mode 100644 index 0000000000..f06ecd848a --- /dev/null +++ b/acapy_agent/utils/extract_validation_error.py @@ -0,0 +1,16 @@ +"""Extract validation error messages from nested exceptions.""" + +from aiohttp.web import HTTPUnprocessableEntity +from marshmallow.exceptions import ValidationError + + +def extract_validation_error_message(exc: HTTPUnprocessableEntity) -> str: + """Extract marshmallow error message from a nested UnprocessableEntity exception.""" + visited = set() + current_exc = exc + while current_exc and current_exc not in visited: + visited.add(current_exc) + if isinstance(current_exc, ValidationError): + return current_exc.messages + current_exc = current_exc.__cause__ or current_exc.__context__ + return exc.reason diff --git a/acapy_agent/utils/tests/test_extract_validation_error.py b/acapy_agent/utils/tests/test_extract_validation_error.py new file mode 100644 index 0000000000..b52740d0e5 --- /dev/null +++ b/acapy_agent/utils/tests/test_extract_validation_error.py @@ -0,0 +1,77 @@ +import unittest + +from aiohttp.web import HTTPUnprocessableEntity +from marshmallow.exceptions import ValidationError + +from ...utils.extract_validation_error import extract_validation_error_message + + +class TestExtractValidationErrorMessage(unittest.TestCase): + def test_validation_error_extracted(self): + """Test that the validation error message is extracted when present.""" + validation_error = ValidationError({"field": ["Invalid input"]}) + http_error = HTTPUnprocessableEntity(reason="Unprocessable Entity") + http_error.__cause__ = validation_error + + result = extract_validation_error_message(http_error) + self.assertEqual(result, {"field": ["Invalid input"]}) + + def test_no_validation_error_returns_reason(self): + """Test that the reason is returned when no ValidationError is found.""" + http_error = HTTPUnprocessableEntity(reason="Unprocessable Entity") + + result = extract_validation_error_message(http_error) + self.assertEqual(result, "Unprocessable Entity") + + def test_deeply_nested_validation_error(self): + """Test extraction when ValidationError is nested deeply.""" + validation_error = ValidationError({"field": ["Invalid input"]}) + level_3 = Exception("Level 3") + level_3.__cause__ = validation_error + level_2 = Exception("Level 2") + level_2.__cause__ = level_3 + http_error = HTTPUnprocessableEntity(reason="Unprocessable Entity") + http_error.__cause__ = level_2 + + result = extract_validation_error_message(http_error) + self.assertEqual(result, {"field": ["Invalid input"]}) + + def test_multiple_exceptions_no_validation_error(self): + """Test that reason is returned when no ValidationError is in the chain.""" + level_2 = Exception("Level 2") + level_1 = Exception("Level 1") + level_1.__cause__ = level_2 + http_error = HTTPUnprocessableEntity(reason="Unprocessable Entity") + http_error.__cause__ = level_1 + + result = extract_validation_error_message(http_error) + self.assertEqual(result, "Unprocessable Entity") + + def test_validation_error_in_context(self): + """Test extraction when ValidationError is in __context__ instead of __cause__.""" + validation_error = ValidationError({"field": ["Invalid input"]}) + level_1 = Exception("Level 1") + level_1.__context__ = validation_error + http_error = HTTPUnprocessableEntity(reason="Unprocessable Entity") + http_error.__context__ = level_1 + + result = extract_validation_error_message(http_error) + self.assertEqual(result, {"field": ["Invalid input"]}) + + def test_exception_already_visited(self): + """Test that visited set prevents infinite loops.""" + validation_error = ValidationError({"field": ["Invalid input"]}) + http_error = HTTPUnprocessableEntity(reason="Unprocessable Entity") + # Create a loop in the exception chain + validation_error.__cause__ = http_error + http_error.__cause__ = validation_error + + result = extract_validation_error_message(http_error) + self.assertEqual(result, {"field": ["Invalid input"]}) + + def test_validation_error_as_initial_exception(self): + """Test when the initial exception is a ValidationError.""" + validation_error = ValidationError({"field": ["Invalid input"]}) + + result = extract_validation_error_message(validation_error) + self.assertEqual(result, {"field": ["Invalid input"]})