Skip to content

Commit

Permalink
Refactor appsync resolver
Browse files Browse the repository at this point in the history
  • Loading branch information
Michal Ploski committed Mar 23, 2023
1 parent d0fe867 commit bc45703
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 58 deletions.
138 changes: 95 additions & 43 deletions aws_lambda_powertools/event_handler/appsync.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,32 @@
import logging
from itertools import groupby
from typing import Any, Callable, List, Optional, Type, TypeVar, Union
from typing import Any, Callable, List, Optional, Type, Union

from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent
from aws_lambda_powertools.utilities.typing import LambdaContext

logger = logging.getLogger(__name__)

AppSyncResolverEventT = TypeVar("AppSyncResolverEventT", bound=AppSyncResolverEvent)

class RouterContext:
def __init__(self):
super().__init__()
self.context = {}

class BaseRouter:
current_event: Union[AppSyncResolverEventT, List[AppSyncResolverEventT]] # type: ignore[valid-type]
lambda_context: LambdaContext
context: dict
def append_context(self, **additional_context):
"""Append key=value data as routing context"""
self.context.update(**additional_context)

def clear_context(self):
"""Resets routing context"""
self.context.clear()


class ResolverRegistry:
def __init__(self):
super().__init__()
self._resolvers: dict = {}
self._batch_resolvers: dict = {}

def resolver(self, type_name: str = "*", field_name: Optional[str] = None):
"""Registers the resolver for field_name
Expand All @@ -29,23 +39,33 @@ def resolver(self, type_name: str = "*", field_name: Optional[str] = None):
Field name
"""

def register_resolver(func):
def register(func):
logger.debug(f"Adding resolver `{func.__name__}` for field `{type_name}.{field_name}`")
self._resolvers[f"{type_name}.{field_name}"] = {"func": func}
return func

return register_resolver
return register

def append_context(self, **additional_context):
"""Append key=value data as routing context"""
self.context.update(**additional_context)
def batch_resolver(self, type_name: str = "*", field_name: Optional[str] = None):
"""Registers the resolver for field_name
def clear_context(self):
"""Resets routing context"""
self.context.clear()
Parameters
----------
type_name : str
Type name
field_name : str
Field name
"""

def register(func):
logger.debug(f"Adding batch resolver `{func.__name__}` for field `{type_name}.{field_name}`")
self._batch_resolvers[f"{type_name}.{field_name}"] = {"func": func}
return func

class AppSyncResolver(BaseRouter):
return register


class AppSyncResolver(ResolverRegistry, RouterContext):
"""
AppSync resolver decorator
Expand Down Expand Up @@ -78,16 +98,20 @@ def common_field() -> str:

def __init__(self):
super().__init__()
self.context = {} # early init as customers might add context before event resolution
self.current_batch_event: List[AppSyncResolverEvent] = []
self.current_event: Optional[AppSyncResolverEvent] = None

def resolve(
self, event: dict, context: LambdaContext, data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent
self,
event: Union[dict, List[dict]],
context: LambdaContext,
data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent,
) -> Any:
"""Resolve field_name
Parameters
----------
event : dict
event : dict | List[dict]
Lambda event
context : LambdaContext
Lambda context
Expand Down Expand Up @@ -152,33 +176,38 @@ def lambda_handler(event, context):
ValueError
If we could not find a field resolver
"""
# Maintenance: revisit generics/overload to fix [attr-defined] in mypy usage

BaseRouter.lambda_context = context

# If event is a list it means that AppSync sent batch request
if isinstance(event, list):
event_groups = [
{"field_name": field_name, "events": list(events)}
for field_name, events in groupby(event, key=lambda x: x["info"]["fieldName"])
]
if len(event_groups) > 1:
ValueError("batch with different field names. It shouldn't happen!")

appconfig_events = [data_model(event) for event in event_groups[0]["events"]]
BaseRouter.current_event = appconfig_events
resolver = self._get_resolver(appconfig_events[0].type_name, event_groups[0]["field_name"])
response = resolver()
else:
appconfig_event = data_model(event)
BaseRouter.current_event = appconfig_event
resolver = self._get_resolver(appconfig_event.type_name, appconfig_event.field_name)
response = resolver(**appconfig_event.arguments)

self.lambda_context = context

response = (
self._call_batch_resolver(event, data_model)
if isinstance(event, list)
else self._call_resolver(event, data_model)
)
self.clear_context()

return response

def _call_resolver(self, event: dict, data_model: Type[AppSyncResolverEvent]) -> Any:
self.current_event = data_model(event)
resolver = self._get_resolver(self.current_event.type_name, self.current_event.field_name)
return resolver(**self.current_event.arguments)

def _call_batch_resolver(self, event: List[dict], data_model: Type[AppSyncResolverEvent]) -> List[Any]:
event_groups = [
{"field_name": field_name, "events": list(events)}
for field_name, events in groupby(event, key=lambda x: x["info"]["fieldName"])
]
if len(event_groups) > 1:
ValueError("batch with different field names. It shouldn't happen!")

self.current_batch_event = [data_model(event) for event in event_groups[0]["events"]]
resolver = self._get_batch_resolver(
self.current_batch_event[0].type_name, self.current_batch_event[0].field_name
)

return [resolver(event=appconfig_event) for appconfig_event in self.current_batch_event]

def _get_resolver(self, type_name: str, field_name: str) -> Callable:
"""Get resolver for field_name
Expand All @@ -200,8 +229,32 @@ def _get_resolver(self, type_name: str, field_name: str) -> Callable:
raise ValueError(f"No resolver found for '{full_name}'")
return resolver["func"]

def _get_batch_resolver(self, type_name: str, field_name: str) -> Callable:
"""Get resolver for field_name
Parameters
----------
type_name : str
Type name
field_name : str
Field name
Returns
-------
Callable
callable function and configuration
"""
full_name = f"{type_name}.{field_name}"
resolver = self._batch_resolvers.get(full_name, self._batch_resolvers.get(f"*.{field_name}"))
if not resolver:
raise ValueError(f"No batch resolver found for '{full_name}'")
return resolver["func"]

def __call__(
self, event: dict, context: LambdaContext, data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent
self,
event: Union[dict, List[dict]],
context: LambdaContext,
data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent,
) -> Any:
"""Implicit lambda handler which internally calls `resolve`"""
return self.resolve(event, context, data_model)
Expand All @@ -222,7 +275,6 @@ def include_router(self, router: "Router") -> None:
self._resolvers.update(router._resolvers)


class Router(BaseRouter):
class Router(RouterContext, ResolverRegistry):
def __init__(self):
super().__init__()
self.context = {} # early init as customers might add context before event resolution
3 changes: 2 additions & 1 deletion examples/event_handler_graphql/src/custom_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def api_key(self) -> str:
@app.resolver(type_name="Query", field_name="listLocations")
def list_locations(page: int = 0, size: int = 10) -> List[Location]:
# additional properties/methods will now be available under current_event
logger.debug(f"Request country origin: {app.current_event.country_viewer}") # type: ignore[attr-defined]
if app.current_event:
logger.debug(f"Request country origin: {app.current_event.country_viewer}") # type: ignore[attr-defined]
return [{"id": scalar_types_utils.make_id(), "name": "Perry, James and Carroll"}]


Expand Down
13 changes: 5 additions & 8 deletions tests/e2e/event_handler/handlers/appsync_resolver_handler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import List
from typing import List, Optional

from pydantic import BaseModel

from aws_lambda_powertools.event_handler import AppSyncResolver
from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent
from aws_lambda_powertools.utilities.typing import LambdaContext

app = AppSyncResolver()
Expand Down Expand Up @@ -86,13 +87,9 @@ def all_posts() -> List[dict]:
return list(posts.values())


@app.resolver(type_name="Post", field_name="relatedPosts")
def related_posts() -> List[dict]:
posts = []
for resolver_event in app.current_event:
if resolver_event.source:
posts.append(posts_related[resolver_event.source["post_id"]])
return posts
@app.batch_resolver(type_name="Post", field_name="relatedPosts")
def related_posts(event: AppSyncResolverEvent) -> Optional[list]:
return posts_related[event.source["post_id"]] if event.source else None


def lambda_handler(event, context: LambdaContext) -> dict:
Expand Down
14 changes: 8 additions & 6 deletions tests/functional/event_handler/test_appsync.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import sys
from typing import Optional

import pytest

Expand Down Expand Up @@ -199,22 +200,23 @@ def test_resolve_batch_processing():
"fieldName": "listLocations",
"arguments": {},
"source": {
"id": "3",
"id": [3, 4],
},
},
]

app = AppSyncResolver()

@app.resolver(field_name="listLocations")
def create_something(): # noqa AA03 VNE003
return [event.source["id"] for event in app.current_event]
@app.batch_resolver(field_name="listLocations")
def create_something(event: AppSyncResolverEvent) -> Optional[list]: # noqa AA03 VNE003
return event.source["id"] if event.source else None

# Call the implicit handler
result = app.resolve(event, LambdaContext())
assert result == ["1", "2", "3"]
assert result == [appsync_event["source"]["id"] for appsync_event in event]

assert len(app.current_event) == len(event)
assert app.current_batch_event and len(app.current_batch_event) == len(event)
assert not app.current_event


def test_resolver_include_resolver():
Expand Down

0 comments on commit bc45703

Please sign in to comment.