From bc45703beda8155ca094dbe95ed76464afd6bf6d Mon Sep 17 00:00:00 2001 From: Michal Ploski Date: Thu, 23 Mar 2023 14:29:18 +0100 Subject: [PATCH] Refactor appsync resolver --- .../event_handler/appsync.py | 138 ++++++++++++------ .../src/custom_models.py | 3 +- .../handlers/appsync_resolver_handler.py | 13 +- .../functional/event_handler/test_appsync.py | 14 +- 4 files changed, 110 insertions(+), 58 deletions(-) diff --git a/aws_lambda_powertools/event_handler/appsync.py b/aws_lambda_powertools/event_handler/appsync.py index 4132cc3524..7ed80f6880 100644 --- a/aws_lambda_powertools/event_handler/appsync.py +++ b/aws_lambda_powertools/event_handler/appsync.py @@ -1,22 +1,32 @@ import logging from itertools import groupby -from typing import Any, Callable, List, Optional, Type, TypeVar, Union +from typing import Any, Callable, List, Optional, Type, Union from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent from aws_lambda_powertools.utilities.typing import LambdaContext logger = logging.getLogger(__name__) -AppSyncResolverEventT = TypeVar("AppSyncResolverEventT", bound=AppSyncResolverEvent) +class RouterContext: + def __init__(self): + super().__init__() + self.context = {} -class BaseRouter: - current_event: Union[AppSyncResolverEventT, List[AppSyncResolverEventT]] # type: ignore[valid-type] - lambda_context: LambdaContext - context: dict + def append_context(self, **additional_context): + """Append key=value data as routing context""" + self.context.update(**additional_context) + def clear_context(self): + """Resets routing context""" + self.context.clear() + + +class ResolverRegistry: def __init__(self): + super().__init__() self._resolvers: dict = {} + self._batch_resolvers: dict = {} def resolver(self, type_name: str = "*", field_name: Optional[str] = None): """Registers the resolver for field_name @@ -29,23 +39,33 @@ def resolver(self, type_name: str = "*", field_name: Optional[str] = None): Field name """ - def register_resolver(func): + def register(func): logger.debug(f"Adding resolver `{func.__name__}` for field `{type_name}.{field_name}`") self._resolvers[f"{type_name}.{field_name}"] = {"func": func} return func - return register_resolver + return register - def append_context(self, **additional_context): - """Append key=value data as routing context""" - self.context.update(**additional_context) + def batch_resolver(self, type_name: str = "*", field_name: Optional[str] = None): + """Registers the resolver for field_name - def clear_context(self): - """Resets routing context""" - self.context.clear() + Parameters + ---------- + type_name : str + Type name + field_name : str + Field name + """ + def register(func): + logger.debug(f"Adding batch resolver `{func.__name__}` for field `{type_name}.{field_name}`") + self._batch_resolvers[f"{type_name}.{field_name}"] = {"func": func} + return func -class AppSyncResolver(BaseRouter): + return register + + +class AppSyncResolver(ResolverRegistry, RouterContext): """ AppSync resolver decorator @@ -78,16 +98,20 @@ def common_field() -> str: def __init__(self): super().__init__() - self.context = {} # early init as customers might add context before event resolution + self.current_batch_event: List[AppSyncResolverEvent] = [] + self.current_event: Optional[AppSyncResolverEvent] = None def resolve( - self, event: dict, context: LambdaContext, data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent + self, + event: Union[dict, List[dict]], + context: LambdaContext, + data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent, ) -> Any: """Resolve field_name Parameters ---------- - event : dict + event : dict | List[dict] Lambda event context : LambdaContext Lambda context @@ -152,33 +176,38 @@ def lambda_handler(event, context): ValueError If we could not find a field resolver """ - # Maintenance: revisit generics/overload to fix [attr-defined] in mypy usage - - BaseRouter.lambda_context = context - - # If event is a list it means that AppSync sent batch request - if isinstance(event, list): - event_groups = [ - {"field_name": field_name, "events": list(events)} - for field_name, events in groupby(event, key=lambda x: x["info"]["fieldName"]) - ] - if len(event_groups) > 1: - ValueError("batch with different field names. It shouldn't happen!") - - appconfig_events = [data_model(event) for event in event_groups[0]["events"]] - BaseRouter.current_event = appconfig_events - resolver = self._get_resolver(appconfig_events[0].type_name, event_groups[0]["field_name"]) - response = resolver() - else: - appconfig_event = data_model(event) - BaseRouter.current_event = appconfig_event - resolver = self._get_resolver(appconfig_event.type_name, appconfig_event.field_name) - response = resolver(**appconfig_event.arguments) + self.lambda_context = context + + response = ( + self._call_batch_resolver(event, data_model) + if isinstance(event, list) + else self._call_resolver(event, data_model) + ) self.clear_context() return response + def _call_resolver(self, event: dict, data_model: Type[AppSyncResolverEvent]) -> Any: + self.current_event = data_model(event) + resolver = self._get_resolver(self.current_event.type_name, self.current_event.field_name) + return resolver(**self.current_event.arguments) + + def _call_batch_resolver(self, event: List[dict], data_model: Type[AppSyncResolverEvent]) -> List[Any]: + event_groups = [ + {"field_name": field_name, "events": list(events)} + for field_name, events in groupby(event, key=lambda x: x["info"]["fieldName"]) + ] + if len(event_groups) > 1: + ValueError("batch with different field names. It shouldn't happen!") + + self.current_batch_event = [data_model(event) for event in event_groups[0]["events"]] + resolver = self._get_batch_resolver( + self.current_batch_event[0].type_name, self.current_batch_event[0].field_name + ) + + return [resolver(event=appconfig_event) for appconfig_event in self.current_batch_event] + def _get_resolver(self, type_name: str, field_name: str) -> Callable: """Get resolver for field_name @@ -200,8 +229,32 @@ def _get_resolver(self, type_name: str, field_name: str) -> Callable: raise ValueError(f"No resolver found for '{full_name}'") return resolver["func"] + def _get_batch_resolver(self, type_name: str, field_name: str) -> Callable: + """Get resolver for field_name + + Parameters + ---------- + type_name : str + Type name + field_name : str + Field name + + Returns + ------- + Callable + callable function and configuration + """ + full_name = f"{type_name}.{field_name}" + resolver = self._batch_resolvers.get(full_name, self._batch_resolvers.get(f"*.{field_name}")) + if not resolver: + raise ValueError(f"No batch resolver found for '{full_name}'") + return resolver["func"] + def __call__( - self, event: dict, context: LambdaContext, data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent + self, + event: Union[dict, List[dict]], + context: LambdaContext, + data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent, ) -> Any: """Implicit lambda handler which internally calls `resolve`""" return self.resolve(event, context, data_model) @@ -222,7 +275,6 @@ def include_router(self, router: "Router") -> None: self._resolvers.update(router._resolvers) -class Router(BaseRouter): +class Router(RouterContext, ResolverRegistry): def __init__(self): super().__init__() - self.context = {} # early init as customers might add context before event resolution diff --git a/examples/event_handler_graphql/src/custom_models.py b/examples/event_handler_graphql/src/custom_models.py index 6d82e1ba9b..84982b0f81 100644 --- a/examples/event_handler_graphql/src/custom_models.py +++ b/examples/event_handler_graphql/src/custom_models.py @@ -42,7 +42,8 @@ def api_key(self) -> str: @app.resolver(type_name="Query", field_name="listLocations") def list_locations(page: int = 0, size: int = 10) -> List[Location]: # additional properties/methods will now be available under current_event - logger.debug(f"Request country origin: {app.current_event.country_viewer}") # type: ignore[attr-defined] + if app.current_event: + logger.debug(f"Request country origin: {app.current_event.country_viewer}") # type: ignore[attr-defined] return [{"id": scalar_types_utils.make_id(), "name": "Perry, James and Carroll"}] diff --git a/tests/e2e/event_handler/handlers/appsync_resolver_handler.py b/tests/e2e/event_handler/handlers/appsync_resolver_handler.py index 1a67f714e3..c904a86e49 100644 --- a/tests/e2e/event_handler/handlers/appsync_resolver_handler.py +++ b/tests/e2e/event_handler/handlers/appsync_resolver_handler.py @@ -1,8 +1,9 @@ -from typing import List +from typing import List, Optional from pydantic import BaseModel from aws_lambda_powertools.event_handler import AppSyncResolver +from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent from aws_lambda_powertools.utilities.typing import LambdaContext app = AppSyncResolver() @@ -86,13 +87,9 @@ def all_posts() -> List[dict]: return list(posts.values()) -@app.resolver(type_name="Post", field_name="relatedPosts") -def related_posts() -> List[dict]: - posts = [] - for resolver_event in app.current_event: - if resolver_event.source: - posts.append(posts_related[resolver_event.source["post_id"]]) - return posts +@app.batch_resolver(type_name="Post", field_name="relatedPosts") +def related_posts(event: AppSyncResolverEvent) -> Optional[list]: + return posts_related[event.source["post_id"]] if event.source else None def lambda_handler(event, context: LambdaContext) -> dict: diff --git a/tests/functional/event_handler/test_appsync.py b/tests/functional/event_handler/test_appsync.py index 9cb833a161..54a389dd6b 100644 --- a/tests/functional/event_handler/test_appsync.py +++ b/tests/functional/event_handler/test_appsync.py @@ -1,5 +1,6 @@ import asyncio import sys +from typing import Optional import pytest @@ -199,22 +200,23 @@ def test_resolve_batch_processing(): "fieldName": "listLocations", "arguments": {}, "source": { - "id": "3", + "id": [3, 4], }, }, ] app = AppSyncResolver() - @app.resolver(field_name="listLocations") - def create_something(): # noqa AA03 VNE003 - return [event.source["id"] for event in app.current_event] + @app.batch_resolver(field_name="listLocations") + def create_something(event: AppSyncResolverEvent) -> Optional[list]: # noqa AA03 VNE003 + return event.source["id"] if event.source else None # Call the implicit handler result = app.resolve(event, LambdaContext()) - assert result == ["1", "2", "3"] + assert result == [appsync_event["source"]["id"] for appsync_event in event] - assert len(app.current_event) == len(event) + assert app.current_batch_event and len(app.current_batch_event) == len(event) + assert not app.current_event def test_resolver_include_resolver():