Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(data_classes): return empty dict or list instead of None #4606

Merged
6 changes: 1 addition & 5 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,11 +817,7 @@ def _has_compression_enabled(
bool
True if compression is enabled and the "gzip" encoding is accepted, False otherwise.
"""
encoding: str = event.get_header_value(
name="accept-encoding",
default_value="",
case_sensitive=False,
) # noqa: E501
encoding = event.headers.get("accept-encoding", "")
if "gzip" in encoding:
if response_compression is not None:
return response_compression # e.g., Response(compress=False/True))
Expand Down
2 changes: 1 addition & 1 deletion aws_lambda_powertools/event_handler/appsync.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def handler(event, context: LambdaContext):
class MyCustomModel(AppSyncResolverEvent):
@property
def country_viewer(self) -> str:
return self.request_headers.get("cloudfront-viewer-country")
return self.request_headers.get("cloudfront-viewer-country", "")


@app.resolver(field_name="listLocations")
Expand Down
5 changes: 1 addition & 4 deletions aws_lambda_powertools/event_handler/middlewares/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@ def __init__(self, header: str):
def handler(self, app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response:
# BEFORE logic
request_id = app.current_event.request_context.request_id
correlation_id = app.current_event.get_header_value(
name=self.header,
default_value=request_id,
)
correlation_id = app.current_event.headers.get(self.header, request_id)

# Call next middleware or route handler ('/todos')
response = next_middleware(app)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import logging
from copy import deepcopy
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple
from typing import Any, Callable, Dict, List, Mapping, MutableMapping, Optional, Sequence, Tuple

from pydantic import BaseModel

Expand Down Expand Up @@ -237,8 +237,8 @@ def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]:
Get the request body from the event, and parse it as JSON.
"""

content_type_value = app.current_event.get_header_value("content-type")
if not content_type_value or content_type_value.strip().startswith("application/json"):
content_type = app.current_event.headers.get("content-type")
if not content_type or content_type.strip().startswith("application/json"):
try:
return app.current_event.json_body
except json.JSONDecodeError as e:
Expand Down Expand Up @@ -410,7 +410,7 @@ def _normalize_multi_query_string_with_param(
return resolved_query_string


def _normalize_multi_header_values_with_param(headers: Dict[str, Any], params: Sequence[ModelField]):
def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any], params: Sequence[ModelField]):
leandrodamascena marked this conversation as resolved.
Show resolved Hide resolved
"""
Extract and normalize resolved_headers_field

Expand Down
16 changes: 4 additions & 12 deletions aws_lambda_powertools/event_handler/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from typing import Any, Dict

from aws_lambda_powertools.utilities.data_classes.shared_functions import get_header_value
from typing import Any, Mapping, Optional


class _FrozenDict(dict):
Expand All @@ -18,25 +16,19 @@ def __hash__(self):
return hash(frozenset(self.keys()))


def extract_origin_header(resolver_headers: Dict[str, Any]):
def extract_origin_header(resolved_headers: Mapping[str, Any]) -> Optional[str]:
"""
Extracts the 'origin' or 'Origin' header from the provided resolver headers.

The 'origin' or 'Origin' header can be either a single header or a multi-header.

Args:
resolver_headers (Dict): A dictionary containing the headers.
resolved_headers (Mapping): A dictionary containing the headers.

Returns:
Optional[str]: The value(s) of the origin header or None.
"""
resolved_header = get_header_value(
headers=resolver_headers,
name="origin",
default_value=None,
case_sensitive=False,
)
resolved_header = resolved_headers.get("origin")
if isinstance(resolved_header, list):
return resolved_header[0]

return resolved_header
23 changes: 7 additions & 16 deletions aws_lambda_powertools/utilities/data_classes/alb_event.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List

from aws_lambda_powertools.shared.headers_serializer import (
BaseHeadersSerializer,
Expand All @@ -7,6 +7,7 @@
)
from aws_lambda_powertools.utilities.data_classes.common import (
BaseProxyEvent,
CaseInsensitiveDict,
DictWrapper,
)

Expand Down Expand Up @@ -37,25 +38,15 @@ def multi_value_query_string_parameters(self) -> Dict[str, List[str]]:

@property
def resolved_query_string_parameters(self) -> Dict[str, List[str]]:
if self.multi_value_query_string_parameters:
return self.multi_value_query_string_parameters

return super().resolved_query_string_parameters
return self.multi_value_query_string_parameters or super().resolved_query_string_parameters

ericbn marked this conversation as resolved.
Show resolved Hide resolved
@property
def resolved_headers_field(self) -> Dict[str, Any]:
headers: Dict[str, Any] = {}

if self.multi_value_headers:
headers = self.multi_value_headers
else:
headers = self.headers

return {key.lower(): value for key, value in headers.items()}
def multi_value_headers(self) -> Dict[str, List[str]]:
return CaseInsensitiveDict(self.get("multiValueHeaders"))

leandrodamascena marked this conversation as resolved.
Show resolved Hide resolved
@property
def multi_value_headers(self) -> Optional[Dict[str, List[str]]]:
return self.get("multiValueHeaders")
def resolved_headers_field(self) -> Dict[str, Any]:
return self.multi_value_headers or self.headers

ericbn marked this conversation as resolved.
Show resolved Hide resolved
def header_serializer(self) -> BaseHeadersSerializer:
# When using the ALB integration, the `multiValueHeaders` feature can be disabled (default) or enabled.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import enum
import re
from typing import Any, Dict, List, Optional, overload
from typing import Any, Dict, List, Optional

from aws_lambda_powertools.utilities.data_classes.common import (
BaseRequestContext,
BaseRequestContextV2,
CaseInsensitiveDict,
DictWrapper,
)
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
get_header_value,
)


class APIGatewayRouteArn:
Expand Down Expand Up @@ -144,7 +142,7 @@ def http_method(self) -> str:

@property
def headers(self) -> Dict[str, str]:
return self["headers"]
return CaseInsensitiveDict(self["headers"])

@property
def query_string_parameters(self) -> Dict[str, str]:
Expand All @@ -162,45 +160,6 @@ def stage_variables(self) -> Dict[str, str]:
def request_context(self) -> BaseRequestContext:
return BaseRequestContext(self._data)

@overload
def get_header_value(
self,
name: str,
default_value: str,
case_sensitive: bool = False,
) -> str: ...

@overload
def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: bool = False,
) -> Optional[str]: ...

def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: bool = False,
) -> Optional[str]:
"""Get header value by name

Parameters
----------
name: str
Header name
default_value: str, optional
Default value if no value was found by name
case_sensitive: bool
Whether to use a case-sensitive look up
Returns
-------
str, optional
Header value
"""
return get_header_value(self.headers, name, default_value, case_sensitive)


class APIGatewayAuthorizerEventV2(DictWrapper):
"""API Gateway Authorizer Event Format 2.0
Expand Down Expand Up @@ -234,14 +193,14 @@ def parsed_arn(self) -> APIGatewayRouteArn:
return parse_api_gateway_arn(self.route_arn)

@property
def identity_source(self) -> Optional[List[str]]:
def identity_source(self) -> List[str]:
"""The identity source for which authorization is requested.

For a REQUEST authorizer, this is optional. The value is a set of one or more mapping expressions of the
specified request parameters. The identity source can be headers, query string parameters, stage variables,
and context parameters.
"""
return self.get("identitySource")
return self.get("identitySource") or []

@property
def route_key(self) -> str:
Expand All @@ -265,7 +224,7 @@ def cookies(self) -> List[str]:
@property
def headers(self) -> Dict[str, str]:
"""Http headers"""
return self["headers"]
return CaseInsensitiveDict(self["headers"])

@property
def query_string_parameters(self) -> Dict[str, str]:
Expand All @@ -276,46 +235,12 @@ def request_context(self) -> BaseRequestContextV2:
return BaseRequestContextV2(self._data)

@property
def path_parameters(self) -> Optional[Dict[str, str]]:
return self.get("pathParameters")
def path_parameters(self) -> Dict[str, str]:
return self.get("pathParameters") or {}

@property
def stage_variables(self) -> Optional[Dict[str, str]]:
return self.get("stageVariables")

@overload
def get_header_value(self, name: str, default_value: str, case_sensitive: bool = False) -> str: ...

@overload
def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: bool = False,
) -> Optional[str]: ...

def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: bool = False,
) -> Optional[str]:
"""Get header value by name

Parameters
----------
name: str
Header name
default_value: str, optional
Default value if no value was found by name
case_sensitive: bool
Whether to use a case-sensitive look up
Returns
-------
str, optional
Header value
"""
return get_header_value(self.headers, name, default_value, case_sensitive)
def stage_variables(self) -> Dict[str, str]:
return self.get("stageVariables") or {}


class APIGatewayAuthorizerResponseV2:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import cached_property
from typing import Any, Dict, List, Optional

from aws_lambda_powertools.shared.headers_serializer import (
Expand All @@ -9,6 +10,7 @@
BaseProxyEvent,
BaseRequestContext,
BaseRequestContextV2,
CaseInsensitiveDict,
DictWrapper,
)

Expand Down Expand Up @@ -113,7 +115,7 @@ def resource(self) -> str:

@property
def multi_value_headers(self) -> Dict[str, List[str]]:
return self.get("multiValueHeaders") or {} # key might exist but can be `null`
return CaseInsensitiveDict(self.get("multiValueHeaders"))

@property
def multi_value_query_string_parameters(self) -> Dict[str, List[str]]:
Expand All @@ -128,26 +130,19 @@ def resolved_query_string_parameters(self) -> Dict[str, List[str]]:

@property
def resolved_headers_field(self) -> Dict[str, Any]:
headers: Dict[str, Any] = {}

if self.multi_value_headers:
headers = self.multi_value_headers
else:
headers = self.headers

return {key.lower(): value for key, value in headers.items()}
return self.multi_value_headers or self.headers

@property
def request_context(self) -> APIGatewayEventRequestContext:
return APIGatewayEventRequestContext(self._data)

@property
def path_parameters(self) -> Optional[Dict[str, str]]:
return self.get("pathParameters")
def path_parameters(self) -> Dict[str, str]:
return self.get("pathParameters") or {}

@property
def stage_variables(self) -> Optional[Dict[str, str]]:
return self.get("stageVariables")
def stage_variables(self) -> Dict[str, str]:
return self.get("stageVariables") or {}

def header_serializer(self) -> BaseHeadersSerializer:
return MultiValueHeadersSerializer()
Expand Down Expand Up @@ -289,20 +284,20 @@ def raw_query_string(self) -> str:
return self["rawQueryString"]

@property
def cookies(self) -> Optional[List[str]]:
return self.get("cookies")
def cookies(self) -> List[str]:
return self.get("cookies") or []

@property
def request_context(self) -> RequestContextV2:
return RequestContextV2(self._data)

@property
def path_parameters(self) -> Optional[Dict[str, str]]:
return self.get("pathParameters")
def path_parameters(self) -> Dict[str, str]:
return self.get("pathParameters") or {}

@property
def stage_variables(self) -> Optional[Dict[str, str]]:
return self.get("stageVariables")
def stage_variables(self) -> Dict[str, str]:
return self.get("stageVariables") or {}

@property
def path(self) -> str:
Expand All @@ -319,10 +314,6 @@ def http_method(self) -> str:
def header_serializer(self):
return HttpApiHeadersSerializer()

@property
@cached_property
def resolved_headers_field(self) -> Dict[str, Any]:
if self.headers is not None:
headers = {key.lower(): value.split(",") if "," in value else value for key, value in self.headers.items()}
return headers

return {}
return CaseInsensitiveDict((k, v.split(",") if "," in v else v) for k, v in self.headers.items())
Loading
Loading