Skip to content

Commit

Permalink
feat(data_classes): return empty dict or list instead of None
Browse files Browse the repository at this point in the history
This simplifies the code internally and also for users.

Also wrap all headers in CaseInsensitiveDict from requests.

These changes replace the need of utility functions like
get_header_value, get_query_string_value or
get_multi_value_query_string_values, which are removed.
  • Loading branch information
ericbn committed Jun 22, 2024
1 parent 0b29cef commit 722d27f
Show file tree
Hide file tree
Showing 51 changed files with 255 additions and 889 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

## Features

* **data_classes:** return empty dict or list in Event Source Data Classes instead of None, and return case-insensitive dict for headrs. ([#2605](https://github.com/aws-powertools/powertools-lambda-python/issues/2605))
* **event_source:** add CloudFormationCustomResourceEvent data class. ([#4342](https://github.com/aws-powertools/powertools-lambda-python/issues/4342))

## Maintenance
Expand Down
6 changes: 1 addition & 5 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,11 +817,7 @@ def _has_compression_enabled(
bool
True if compression is enabled and the "gzip" encoding is accepted, False otherwise.
"""
encoding: str = event.get_header_value(
name="accept-encoding",
default_value="",
case_sensitive=False,
) # noqa: E501
encoding = event.headers.get("accept-encoding", "")
if "gzip" in encoding:
if response_compression is not None:
return response_compression # e.g., Response(compress=False/True))
Expand Down
2 changes: 1 addition & 1 deletion aws_lambda_powertools/event_handler/appsync.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def handler(event, context: LambdaContext):
class MyCustomModel(AppSyncResolverEvent):
@property
def country_viewer(self) -> str:
def country_viewer(self):
return self.request_headers.get("cloudfront-viewer-country")
Expand Down
5 changes: 1 addition & 4 deletions aws_lambda_powertools/event_handler/middlewares/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@ def __init__(self, header: str):
def handler(self, app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response:
# BEFORE logic
request_id = app.current_event.request_context.request_id
correlation_id = app.current_event.get_header_value(
name=self.header,
default_value=request_id,
)
correlation_id = app.current_event.headers.get(self.header, request_id)
# Call next middleware or route handler ('/todos')
response = next_middleware(app)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import logging
from copy import deepcopy
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple
from typing import Any, Callable, Dict, List, Mapping, MutableMapping, Optional, Sequence, Tuple

from pydantic import BaseModel

Expand Down Expand Up @@ -237,8 +237,8 @@ def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]:
Get the request body from the event, and parse it as JSON.
"""

content_type_value = app.current_event.get_header_value("content-type")
if not content_type_value or content_type_value.strip().startswith("application/json"):
content_type = app.current_event.headers.get("content-type")
if not content_type or content_type.strip().startswith("application/json"):
try:
return app.current_event.json_body
except json.JSONDecodeError as e:
Expand Down Expand Up @@ -410,7 +410,7 @@ def _normalize_multi_query_string_with_param(
return resolved_query_string


def _normalize_multi_header_values_with_param(headers: Dict[str, Any], params: Sequence[ModelField]):
def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any], params: Sequence[ModelField]):
"""
Extract and normalize resolved_headers_field
Expand Down
16 changes: 4 additions & 12 deletions aws_lambda_powertools/event_handler/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from typing import Any, Dict

from aws_lambda_powertools.utilities.data_classes.shared_functions import get_header_value
from typing import Any, Mapping, Optional


class _FrozenDict(dict):
Expand All @@ -18,25 +16,19 @@ def __hash__(self):
return hash(frozenset(self.keys()))


def extract_origin_header(resolver_headers: Dict[str, Any]):
def extract_origin_header(resolved_headers: Mapping[str, Any]) -> Optional[str]:
"""
Extracts the 'origin' or 'Origin' header from the provided resolver headers.
The 'origin' or 'Origin' header can be either a single header or a multi-header.
Args:
resolver_headers (Dict): A dictionary containing the headers.
resolved_headers (Mapping): A dictionary containing the headers.
Returns:
Optional[str]: The value(s) of the origin header or None.
"""
resolved_header = get_header_value(
headers=resolver_headers,
name="origin",
default_value=None,
case_sensitive=False,
)
resolved_header = resolved_headers.get("origin")
if isinstance(resolved_header, list):
return resolved_header[0]

return resolved_header
24 changes: 8 additions & 16 deletions aws_lambda_powertools/utilities/data_classes/alb_event.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, MutableMapping

from requests.structures import CaseInsensitiveDict

from aws_lambda_powertools.shared.headers_serializer import (
BaseHeadersSerializer,
Expand Down Expand Up @@ -37,25 +39,15 @@ def multi_value_query_string_parameters(self) -> Dict[str, List[str]]:

@property
def resolved_query_string_parameters(self) -> Dict[str, List[str]]:
if self.multi_value_query_string_parameters:
return self.multi_value_query_string_parameters

return super().resolved_query_string_parameters
return self.multi_value_query_string_parameters or super().resolved_query_string_parameters

@property
def resolved_headers_field(self) -> Dict[str, Any]:
headers: Dict[str, Any] = {}

if self.multi_value_headers:
headers = self.multi_value_headers
else:
headers = self.headers

return {key.lower(): value for key, value in headers.items()}
def multi_value_headers(self) -> MutableMapping[str, List[str]]:
return CaseInsensitiveDict(self.get("multiValueHeaders"))

@property
def multi_value_headers(self) -> Optional[Dict[str, List[str]]]:
return self.get("multiValueHeaders")
def resolved_headers_field(self) -> MutableMapping[str, Any]:
return self.multi_value_headers or self.headers

def header_serializer(self) -> BaseHeadersSerializer:
# When using the ALB integration, the `multiValueHeaders` feature can be disabled (default) or enabled.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import enum
import re
from typing import Any, Dict, List, Optional, overload
from typing import Any, Dict, List, MutableMapping, Optional

from requests.structures import CaseInsensitiveDict

from aws_lambda_powertools.utilities.data_classes.common import (
BaseRequestContext,
BaseRequestContextV2,
DictWrapper,
)
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
get_header_value,
)


class APIGatewayRouteArn:
Expand Down Expand Up @@ -143,8 +142,8 @@ def http_method(self) -> str:
return self["httpMethod"]

@property
def headers(self) -> Dict[str, str]:
return self["headers"]
def headers(self) -> MutableMapping[str, str]:
return CaseInsensitiveDict(self["headers"])

@property
def query_string_parameters(self) -> Dict[str, str]:
Expand All @@ -162,45 +161,6 @@ def stage_variables(self) -> Dict[str, str]:
def request_context(self) -> BaseRequestContext:
return BaseRequestContext(self._data)

@overload
def get_header_value(
self,
name: str,
default_value: str,
case_sensitive: bool = False,
) -> str: ...

@overload
def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: bool = False,
) -> Optional[str]: ...

def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: bool = False,
) -> Optional[str]:
"""Get header value by name
Parameters
----------
name: str
Header name
default_value: str, optional
Default value if no value was found by name
case_sensitive: bool
Whether to use a case-sensitive look up
Returns
-------
str, optional
Header value
"""
return get_header_value(self.headers, name, default_value, case_sensitive)


class APIGatewayAuthorizerEventV2(DictWrapper):
"""API Gateway Authorizer Event Format 2.0
Expand Down Expand Up @@ -234,14 +194,14 @@ def parsed_arn(self) -> APIGatewayRouteArn:
return parse_api_gateway_arn(self.route_arn)

@property
def identity_source(self) -> Optional[List[str]]:
def identity_source(self) -> List[str]:
"""The identity source for which authorization is requested.

For a REQUEST authorizer, this is optional. The value is a set of one or more mapping expressions of the
specified request parameters. The identity source can be headers, query string parameters, stage variables,
and context parameters.
"""
return self.get("identitySource")
return self.get("identitySource") or []

@property
def route_key(self) -> str:
Expand All @@ -263,9 +223,9 @@ def cookies(self) -> List[str]:
return self["cookies"]

@property
def headers(self) -> Dict[str, str]:
def headers(self) -> MutableMapping[str, str]:
"""Http headers"""
return self["headers"]
return CaseInsensitiveDict(self["headers"])

@property
def query_string_parameters(self) -> Dict[str, str]:
Expand All @@ -276,46 +236,12 @@ def request_context(self) -> BaseRequestContextV2:
return BaseRequestContextV2(self._data)

@property
def path_parameters(self) -> Optional[Dict[str, str]]:
return self.get("pathParameters")
def path_parameters(self) -> Dict[str, str]:
return self.get("pathParameters") or {}

@property
def stage_variables(self) -> Optional[Dict[str, str]]:
return self.get("stageVariables")

@overload
def get_header_value(self, name: str, default_value: str, case_sensitive: bool = False) -> str: ...

@overload
def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: bool = False,
) -> Optional[str]: ...

def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: bool = False,
) -> Optional[str]:
"""Get header value by name
Parameters
----------
name: str
Header name
default_value: str, optional
Default value if no value was found by name
case_sensitive: bool
Whether to use a case-sensitive look up
Returns
-------
str, optional
Header value
"""
return get_header_value(self.headers, name, default_value, case_sensitive)
def stage_variables(self) -> Dict[str, str]:
return self.get("stageVariables") or {}


class APIGatewayAuthorizerResponseV2:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Any, Dict, List, Optional
from functools import cached_property
from typing import Any, Dict, List, MutableMapping, Optional

from requests.structures import CaseInsensitiveDict

from aws_lambda_powertools.shared.headers_serializer import (
BaseHeadersSerializer,
Expand Down Expand Up @@ -112,8 +115,8 @@ def resource(self) -> str:
return self["resource"]

@property
def multi_value_headers(self) -> Dict[str, List[str]]:
return self.get("multiValueHeaders") or {} # key might exist but can be `null`
def multi_value_headers(self) -> MutableMapping[str, List[str]]:
return CaseInsensitiveDict(self.get("multiValueHeaders"))

@property
def multi_value_query_string_parameters(self) -> Dict[str, List[str]]:
Expand All @@ -127,27 +130,20 @@ def resolved_query_string_parameters(self) -> Dict[str, List[str]]:
return super().resolved_query_string_parameters

@property
def resolved_headers_field(self) -> Dict[str, Any]:
headers: Dict[str, Any] = {}

if self.multi_value_headers:
headers = self.multi_value_headers
else:
headers = self.headers

return {key.lower(): value for key, value in headers.items()}
def resolved_headers_field(self) -> MutableMapping[str, Any]:
return self.multi_value_headers or self.headers

@property
def request_context(self) -> APIGatewayEventRequestContext:
return APIGatewayEventRequestContext(self._data)

@property
def path_parameters(self) -> Optional[Dict[str, str]]:
return self.get("pathParameters")
def path_parameters(self) -> Dict[str, str]:
return self.get("pathParameters") or {}

@property
def stage_variables(self) -> Optional[Dict[str, str]]:
return self.get("stageVariables")
def stage_variables(self) -> Dict[str, str]:
return self.get("stageVariables") or {}

def header_serializer(self) -> BaseHeadersSerializer:
return MultiValueHeadersSerializer()
Expand Down Expand Up @@ -289,20 +285,20 @@ def raw_query_string(self) -> str:
return self["rawQueryString"]

@property
def cookies(self) -> Optional[List[str]]:
return self.get("cookies")
def cookies(self) -> List[str]:
return self.get("cookies") or []

@property
def request_context(self) -> RequestContextV2:
return RequestContextV2(self._data)

@property
def path_parameters(self) -> Optional[Dict[str, str]]:
return self.get("pathParameters")
def path_parameters(self) -> Dict[str, str]:
return self.get("pathParameters") or {}

@property
def stage_variables(self) -> Optional[Dict[str, str]]:
return self.get("stageVariables")
def stage_variables(self) -> Dict[str, str]:
return self.get("stageVariables") or {}

@property
def path(self) -> str:
Expand All @@ -319,10 +315,6 @@ def http_method(self) -> str:
def header_serializer(self):
return HttpApiHeadersSerializer()

@property
def resolved_headers_field(self) -> Dict[str, Any]:
if self.headers is not None:
headers = {key.lower(): value.split(",") if "," in value else value for key, value in self.headers.items()}
return headers

return {}
@cached_property
def resolved_headers_field(self) -> MutableMapping[str, Any]:
return CaseInsensitiveDict({k: v.split(",") if "," in v else v for k, v in self.headers.items()})
Loading

0 comments on commit 722d27f

Please sign in to comment.