Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(event-handler): prefixes to strip for custom mappings #579

Merged
merged 12 commits into from
Aug 19, 2021
30 changes: 28 additions & 2 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ProxyEventType(Enum):
ALBEvent = "ALBEvent"


class CORSConfig(object):
class CORSConfig:
"""CORS Config

Examples
Expand Down Expand Up @@ -265,6 +265,7 @@ def __init__(
cors: Optional[CORSConfig] = None,
debug: Optional[bool] = None,
serializer: Optional[Callable[[Dict], str]] = None,
strip_prefixes: Optional[List[str]] = None,
):
"""
Parameters
Expand All @@ -276,6 +277,11 @@ def __init__(
debug: Optional[bool]
Enables debug mode, by default False. Can be also be enabled by "POWERTOOLS_EVENT_HANDLER_DEBUG"
environment variable
serializer : Callable, optional
function to serialize `obj` to a JSON formatted `str`, by default json.dumps
strip_prefixes: List[str], optional
optional list of prefixes to be removed from the request path before doing the routing. This is often used
with api gateways with multiple custom mappings.
"""
self._proxy_type = proxy_type
self._routes: List[Route] = []
Expand All @@ -285,6 +291,7 @@ def __init__(
self._debug = resolve_truthy_env_var_choice(
env=os.getenv(constants.EVENT_HANDLER_DEBUG_ENV, "false"), choice=debug
)
self._strip_prefixes = strip_prefixes

# Allow for a custom serializer or a concise json serialization
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)
Expand Down Expand Up @@ -521,7 +528,7 @@ def _to_proxy_event(self, event: Dict) -> BaseProxyEvent:
def _resolve(self) -> ResponseBuilder:
"""Resolves the response or return the not found response"""
method = self.current_event.http_method.upper()
path = self.current_event.path
path = self._remove_prefix(self.current_event.path)
for route in self._routes:
if method != route.method:
continue
Expand All @@ -533,6 +540,25 @@ def _resolve(self) -> ResponseBuilder:
logger.debug(f"No match found for path {path} and method {method}")
return self._not_found(method)

def _remove_prefix(self, path: str) -> str:
"""Remove the configured prefix from the path"""
if not isinstance(self._strip_prefixes, list):
return path

for prefix in self._strip_prefixes:
if self._path_starts_with(path, prefix):
return path[len(prefix) :]

return path

@staticmethod
def _path_starts_with(path: str, prefix: str):
"""Returns true if the `path` starts with a prefix plus a `/`"""
if not isinstance(prefix, str) or len(prefix) == 0:
return False

return path.startswith(prefix + "/")

def _not_found(self, method: str) -> ResponseBuilder:
"""Called when no matching route was found and includes support for the cors preflight response"""
headers = {}
Expand Down
54 changes: 54 additions & 0 deletions tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,3 +769,57 @@ def get_color() -> Dict:
body = response["body"]
expected = '{"color": 1, "variations": ["dark", "light"]}'
assert expected == body


@pytest.mark.parametrize(
"path",
[
pytest.param("/pay/foo", id="path matched pay prefix"),
pytest.param("/payment/foo", id="path matched payment prefix"),
pytest.param("/foo", id="path does not start with any of the prefixes"),
],
)
def test_remove_prefix(path: str):
# GIVEN events paths `/pay/foo`, `/payment/foo` or `/foo`
# AND a configured strip_prefixes of `/pay` and `/payment`
app = ApiGatewayResolver(strip_prefixes=["/pay", "/payment"])

@app.get("/pay/foo")
def pay_foo():
raise ValueError("should not be matching")

@app.get("/foo")
def foo():
...

# WHEN calling handler
response = app({"httpMethod": "GET", "path": path}, None)

# THEN a route for `/foo` should be found
assert response["statusCode"] == 200


@pytest.mark.parametrize(
"prefix",
[
pytest.param("/foo", id="String are not supported"),
pytest.param({"/foo"}, id="Sets are not supported"),
pytest.param({"foo": "/foo"}, id="Dicts are not supported"),
pytest.param(tuple("/foo"), id="Tuples are not supported"),
pytest.param([None, 1, "", False], id="List of invalid values"),
],
)
def test_ignore_invalid(prefix):
# GIVEN an invalid prefix
app = ApiGatewayResolver(strip_prefixes=prefix)

@app.get("/foo/status")
def foo():
...

# WHEN calling handler
response = app({"httpMethod": "GET", "path": "/foo/status"}, None)

# THEN a route for `/foo/status` should be found
# so no prefix was stripped from the request path
assert response["statusCode"] == 200