Skip to content

Commit

Permalink
feat(appsync): add Router to allow large resolver composition (#776)
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Brewer authored Oct 29, 2021
1 parent bb8e3b6 commit 8b01fc5
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 27 deletions.
75 changes: 48 additions & 27 deletions aws_lambda_powertools/event_handler/appsync.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from abc import ABC
from typing import Any, Callable, Optional, Type, TypeVar

from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent
Expand All @@ -9,7 +10,33 @@
AppSyncResolverEventT = TypeVar("AppSyncResolverEventT", bound=AppSyncResolverEvent)


class AppSyncResolver:
class BaseRouter(ABC):
current_event: AppSyncResolverEventT # type: ignore[valid-type]
lambda_context: LambdaContext

def __init__(self):
self._resolvers: dict = {}

def resolver(self, type_name: str = "*", field_name: Optional[str] = None):
"""Registers the resolver for field_name
Parameters
----------
type_name : str
Type name
field_name : str
Field name
"""

def register_resolver(func):
logger.debug(f"Adding resolver `{func.__name__}` for field `{type_name}.{field_name}`")
self._resolvers[f"{type_name}.{field_name}"] = {"func": func}
return func

return register_resolver


class AppSyncResolver(BaseRouter):
"""
AppSync resolver decorator
Expand Down Expand Up @@ -40,29 +67,8 @@ def common_field() -> str:
return str(uuid.uuid4())
"""

current_event: AppSyncResolverEventT # type: ignore[valid-type]
lambda_context: LambdaContext

def __init__(self):
self._resolvers: dict = {}

def resolver(self, type_name: str = "*", field_name: Optional[str] = None):
"""Registers the resolver for field_name
Parameters
----------
type_name : str
Type name
field_name : str
Field name
"""

def register_resolver(func):
logger.debug(f"Adding resolver `{func.__name__}` for field `{type_name}.{field_name}`")
self._resolvers[f"{type_name}.{field_name}"] = {"func": func}
return func

return register_resolver
super().__init__()

def resolve(
self, event: dict, context: LambdaContext, data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent
Expand Down Expand Up @@ -136,10 +142,10 @@ def lambda_handler(event, context):
ValueError
If we could not find a field resolver
"""
self.current_event = data_model(event)
self.lambda_context = context
resolver = self._get_resolver(self.current_event.type_name, self.current_event.field_name)
return resolver(**self.current_event.arguments)
BaseRouter.current_event = data_model(event)
BaseRouter.lambda_context = context
resolver = self._get_resolver(BaseRouter.current_event.type_name, BaseRouter.current_event.field_name)
return resolver(**BaseRouter.current_event.arguments)

def _get_resolver(self, type_name: str, field_name: str) -> Callable:
"""Get resolver for field_name
Expand Down Expand Up @@ -167,3 +173,18 @@ def __call__(
) -> Any:
"""Implicit lambda handler which internally calls `resolve`"""
return self.resolve(event, context, data_model)

def include_router(self, router: "Router") -> None:
"""Adds all resolvers defined in a router
Parameters
----------
router : Router
A router containing a dict of field resolvers
"""
self._resolvers.update(router._resolvers)


class Router(BaseRouter):
def __init__(self):
super().__init__()
27 changes: 27 additions & 0 deletions tests/functional/event_handler/test_appsync.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

from aws_lambda_powertools.event_handler import AppSyncResolver
from aws_lambda_powertools.event_handler.appsync import Router
from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent
from aws_lambda_powertools.utilities.typing import LambdaContext
from tests.functional.utils import load_event
Expand Down Expand Up @@ -161,3 +162,29 @@ def create_something(id: str): # noqa AA03 VNE003
assert result == "my identifier"

assert app.current_event.country_viewer == "US"


def test_resolver_include_resolver():
# GIVEN
app = AppSyncResolver()
router = Router()

@router.resolver(type_name="Query", field_name="listLocations")
def get_locations(name: str):
return "get_locations#" + name

@app.resolver(field_name="listLocations2")
def get_locations2(name: str):
return "get_locations2#" + name

app.include_router(router)

# WHEN
mock_event1 = {"typeName": "Query", "fieldName": "listLocations", "arguments": {"name": "value"}}
mock_event2 = {"typeName": "Query", "fieldName": "listLocations2", "arguments": {"name": "value"}}
result1 = app.resolve(mock_event1, LambdaContext())
result2 = app.resolve(mock_event2, LambdaContext())

# THEN
assert result1 == "get_locations#value"
assert result2 == "get_locations2#value"

0 comments on commit 8b01fc5

Please sign in to comment.