From 8b01fc5f02390d4cab2cd64b72381e2d785aec4f Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Fri, 29 Oct 2021 02:58:57 -0700 Subject: [PATCH] feat(appsync): add Router to allow large resolver composition (#776) --- .../event_handler/appsync.py | 75 ++++++++++++------- .../functional/event_handler/test_appsync.py | 27 +++++++ 2 files changed, 75 insertions(+), 27 deletions(-) diff --git a/aws_lambda_powertools/event_handler/appsync.py b/aws_lambda_powertools/event_handler/appsync.py index 69b90c4cbb6..6a4bf989169 100644 --- a/aws_lambda_powertools/event_handler/appsync.py +++ b/aws_lambda_powertools/event_handler/appsync.py @@ -1,4 +1,5 @@ import logging +from abc import ABC from typing import Any, Callable, Optional, Type, TypeVar from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent @@ -9,7 +10,33 @@ AppSyncResolverEventT = TypeVar("AppSyncResolverEventT", bound=AppSyncResolverEvent) -class AppSyncResolver: +class BaseRouter(ABC): + current_event: AppSyncResolverEventT # type: ignore[valid-type] + lambda_context: LambdaContext + + def __init__(self): + self._resolvers: dict = {} + + def resolver(self, type_name: str = "*", field_name: Optional[str] = None): + """Registers the resolver for field_name + + Parameters + ---------- + type_name : str + Type name + field_name : str + Field name + """ + + def register_resolver(func): + logger.debug(f"Adding resolver `{func.__name__}` for field `{type_name}.{field_name}`") + self._resolvers[f"{type_name}.{field_name}"] = {"func": func} + return func + + return register_resolver + + +class AppSyncResolver(BaseRouter): """ AppSync resolver decorator @@ -40,29 +67,8 @@ def common_field() -> str: return str(uuid.uuid4()) """ - current_event: AppSyncResolverEventT # type: ignore[valid-type] - lambda_context: LambdaContext - def __init__(self): - self._resolvers: dict = {} - - def resolver(self, type_name: str = "*", field_name: Optional[str] = None): - """Registers the resolver for field_name - - Parameters - ---------- - type_name : str - Type name - field_name : str - Field name - """ - - def register_resolver(func): - logger.debug(f"Adding resolver `{func.__name__}` for field `{type_name}.{field_name}`") - self._resolvers[f"{type_name}.{field_name}"] = {"func": func} - return func - - return register_resolver + super().__init__() def resolve( self, event: dict, context: LambdaContext, data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent @@ -136,10 +142,10 @@ def lambda_handler(event, context): ValueError If we could not find a field resolver """ - self.current_event = data_model(event) - self.lambda_context = context - resolver = self._get_resolver(self.current_event.type_name, self.current_event.field_name) - return resolver(**self.current_event.arguments) + BaseRouter.current_event = data_model(event) + BaseRouter.lambda_context = context + resolver = self._get_resolver(BaseRouter.current_event.type_name, BaseRouter.current_event.field_name) + return resolver(**BaseRouter.current_event.arguments) def _get_resolver(self, type_name: str, field_name: str) -> Callable: """Get resolver for field_name @@ -167,3 +173,18 @@ def __call__( ) -> Any: """Implicit lambda handler which internally calls `resolve`""" return self.resolve(event, context, data_model) + + def include_router(self, router: "Router") -> None: + """Adds all resolvers defined in a router + + Parameters + ---------- + router : Router + A router containing a dict of field resolvers + """ + self._resolvers.update(router._resolvers) + + +class Router(BaseRouter): + def __init__(self): + super().__init__() diff --git a/tests/functional/event_handler/test_appsync.py b/tests/functional/event_handler/test_appsync.py index 26a3ffdcb1f..79173e55825 100644 --- a/tests/functional/event_handler/test_appsync.py +++ b/tests/functional/event_handler/test_appsync.py @@ -4,6 +4,7 @@ import pytest from aws_lambda_powertools.event_handler import AppSyncResolver +from aws_lambda_powertools.event_handler.appsync import Router from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent from aws_lambda_powertools.utilities.typing import LambdaContext from tests.functional.utils import load_event @@ -161,3 +162,29 @@ def create_something(id: str): # noqa AA03 VNE003 assert result == "my identifier" assert app.current_event.country_viewer == "US" + + +def test_resolver_include_resolver(): + # GIVEN + app = AppSyncResolver() + router = Router() + + @router.resolver(type_name="Query", field_name="listLocations") + def get_locations(name: str): + return "get_locations#" + name + + @app.resolver(field_name="listLocations2") + def get_locations2(name: str): + return "get_locations2#" + name + + app.include_router(router) + + # WHEN + mock_event1 = {"typeName": "Query", "fieldName": "listLocations", "arguments": {"name": "value"}} + mock_event2 = {"typeName": "Query", "fieldName": "listLocations2", "arguments": {"name": "value"}} + result1 = app.resolve(mock_event1, LambdaContext()) + result2 = app.resolve(mock_event2, LambdaContext()) + + # THEN + assert result1 == "get_locations#value" + assert result2 == "get_locations2#value"