From 722d27f8022fa249a1e71b25c6092b1f13adae50 Mon Sep 17 00:00:00 2001 From: Eric Nielsen <4120606+ericbn@users.noreply.github.com> Date: Sat, 22 Jun 2024 13:13:55 -0500 Subject: [PATCH] feat(data_classes): return empty dict or list instead of None This simplifies the code internally and also for users. Also wrap all headers in CaseInsensitiveDict from requests. These changes replace the need of utility functions like get_header_value, get_query_string_value or get_multi_value_query_string_values, which are removed. --- CHANGELOG.md | 1 + .../event_handler/api_gateway.py | 6 +- .../event_handler/appsync.py | 2 +- .../event_handler/middlewares/base.py | 5 +- .../middlewares/openapi_validation.py | 8 +- aws_lambda_powertools/event_handler/util.py | 16 +-- .../utilities/data_classes/alb_event.py | 24 ++-- .../api_gateway_authorizer_event.py | 100 ++------------ .../data_classes/api_gateway_proxy_event.py | 50 +++---- .../data_classes/appsync_resolver_event.py | 72 +++------- .../data_classes/bedrock_agent_event.py | 10 +- .../data_classes/cloud_watch_alarm_event.py | 12 +- .../data_classes/cloud_watch_logs_event.py | 4 +- .../data_classes/code_pipeline_job_event.py | 7 +- .../data_classes/cognito_user_pool_event.py | 68 ++++----- .../utilities/data_classes/common.py | 123 ++--------------- .../data_classes/dynamo_db_stream_event.py | 19 ++- .../utilities/data_classes/kafka_event.py | 43 +----- .../data_classes/s3_batch_operation_event.py | 4 +- .../utilities/data_classes/s3_object_event.py | 50 +------ .../utilities/data_classes/ses_event.py | 16 +-- .../data_classes/shared_functions.py | 129 ------------------ .../utilities/data_classes/vpc_lattice.py | 110 ++------------- docs/core/event_handler/api_gateway.md | 4 +- docs/utilities/data_classes.md | 6 +- .../src/custom_models.py | 10 +- .../src/accessing_request_details.py | 4 +- .../src/accessing_request_details_headers.py | 2 +- .../src/exception_handling.py | 2 +- .../src/middleware_extending_middlewares.py | 5 +- .../middleware_global_middlewares_module.py | 2 +- .../src/split_route_module.py | 8 +- .../src/split_route_prefix_module.py | 8 +- .../event_handler/test_api_middlewares.py | 5 +- tests/unit/data_classes/test_alb_event.py | 2 +- .../test_api_gateway_authorizer_event.py | 14 +- .../test_api_gateway_proxy_event.py | 8 +- .../test_appsync_resolver_event.py | 19 ++- .../test_cloud_watch_alarm_event.py | 1 + .../test_cloud_watch_logs_event.py | 4 +- .../test_code_pipeline_job_event.py | 4 +- .../test_cognito_user_pool_event.py | 29 ++-- .../test_dynamo_db_stream_event.py | 6 +- tests/unit/data_classes/test_kafka_event.py | 4 +- .../data_classes/test_lambda_function_url.py | 12 +- .../test_s3_batch_operation_event.py | 2 +- .../unit/data_classes/test_s3_object_event.py | 2 +- tests/unit/data_classes/test_ses_event.py | 8 +- .../data_classes/test_vpc_lattice_event.py | 4 +- .../data_classes/test_vpc_lattice_eventv2.py | 4 +- tests/unit/test_data_classes.py | 86 +----------- 51 files changed, 255 insertions(+), 889 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 859e0cbe7f1..d1d3013db0b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ ## Features +* **data_classes:** return empty dict or list in Event Source Data Classes instead of None, and return case-insensitive dict for headrs. ([#2605](https://github.com/aws-powertools/powertools-lambda-python/issues/2605)) * **event_source:** add CloudFormationCustomResourceEvent data class. ([#4342](https://github.com/aws-powertools/powertools-lambda-python/issues/4342)) ## Maintenance diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 4fa9eb3eb97..a20154b4bbf 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -817,11 +817,7 @@ def _has_compression_enabled( bool True if compression is enabled and the "gzip" encoding is accepted, False otherwise. """ - encoding: str = event.get_header_value( - name="accept-encoding", - default_value="", - case_sensitive=False, - ) # noqa: E501 + encoding = event.headers.get("accept-encoding", "") if "gzip" in encoding: if response_compression is not None: return response_compression # e.g., Response(compress=False/True)) diff --git a/aws_lambda_powertools/event_handler/appsync.py b/aws_lambda_powertools/event_handler/appsync.py index fba5681ef6a..d1f690a7212 100644 --- a/aws_lambda_powertools/event_handler/appsync.py +++ b/aws_lambda_powertools/event_handler/appsync.py @@ -126,7 +126,7 @@ def handler(event, context: LambdaContext): class MyCustomModel(AppSyncResolverEvent): @property - def country_viewer(self) -> str: + def country_viewer(self): return self.request_headers.get("cloudfront-viewer-country") diff --git a/aws_lambda_powertools/event_handler/middlewares/base.py b/aws_lambda_powertools/event_handler/middlewares/base.py index fb4bf37cc74..342b033ec1f 100644 --- a/aws_lambda_powertools/event_handler/middlewares/base.py +++ b/aws_lambda_powertools/event_handler/middlewares/base.py @@ -47,10 +47,7 @@ def __init__(self, header: str): def handler(self, app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response: # BEFORE logic request_id = app.current_event.request_context.request_id - correlation_id = app.current_event.get_header_value( - name=self.header, - default_value=request_id, - ) + correlation_id = app.current_event.headers.get(self.header, request_id) # Call next middleware or route handler ('/todos') response = next_middleware(app) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 2eafb0d67bb..12b70987f8a 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -2,7 +2,7 @@ import json import logging from copy import deepcopy -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, List, Mapping, MutableMapping, Optional, Sequence, Tuple from pydantic import BaseModel @@ -237,8 +237,8 @@ def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]: Get the request body from the event, and parse it as JSON. """ - content_type_value = app.current_event.get_header_value("content-type") - if not content_type_value or content_type_value.strip().startswith("application/json"): + content_type = app.current_event.headers.get("content-type") + if not content_type or content_type.strip().startswith("application/json"): try: return app.current_event.json_body except json.JSONDecodeError as e: @@ -410,7 +410,7 @@ def _normalize_multi_query_string_with_param( return resolved_query_string -def _normalize_multi_header_values_with_param(headers: Dict[str, Any], params: Sequence[ModelField]): +def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any], params: Sequence[ModelField]): """ Extract and normalize resolved_headers_field diff --git a/aws_lambda_powertools/event_handler/util.py b/aws_lambda_powertools/event_handler/util.py index 6f2caf10858..9981e392f82 100644 --- a/aws_lambda_powertools/event_handler/util.py +++ b/aws_lambda_powertools/event_handler/util.py @@ -1,6 +1,4 @@ -from typing import Any, Dict - -from aws_lambda_powertools.utilities.data_classes.shared_functions import get_header_value +from typing import Any, Mapping, Optional class _FrozenDict(dict): @@ -18,25 +16,19 @@ def __hash__(self): return hash(frozenset(self.keys())) -def extract_origin_header(resolver_headers: Dict[str, Any]): +def extract_origin_header(resolved_headers: Mapping[str, Any]) -> Optional[str]: """ Extracts the 'origin' or 'Origin' header from the provided resolver headers. The 'origin' or 'Origin' header can be either a single header or a multi-header. Args: - resolver_headers (Dict): A dictionary containing the headers. + resolved_headers (Mapping): A dictionary containing the headers. Returns: Optional[str]: The value(s) of the origin header or None. """ - resolved_header = get_header_value( - headers=resolver_headers, - name="origin", - default_value=None, - case_sensitive=False, - ) + resolved_header = resolved_headers.get("origin") if isinstance(resolved_header, list): return resolved_header[0] - return resolved_header diff --git a/aws_lambda_powertools/utilities/data_classes/alb_event.py b/aws_lambda_powertools/utilities/data_classes/alb_event.py index a3fbb24f270..1c4d53040b7 100644 --- a/aws_lambda_powertools/utilities/data_classes/alb_event.py +++ b/aws_lambda_powertools/utilities/data_classes/alb_event.py @@ -1,4 +1,6 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, MutableMapping + +from requests.structures import CaseInsensitiveDict from aws_lambda_powertools.shared.headers_serializer import ( BaseHeadersSerializer, @@ -37,25 +39,15 @@ def multi_value_query_string_parameters(self) -> Dict[str, List[str]]: @property def resolved_query_string_parameters(self) -> Dict[str, List[str]]: - if self.multi_value_query_string_parameters: - return self.multi_value_query_string_parameters - - return super().resolved_query_string_parameters + return self.multi_value_query_string_parameters or super().resolved_query_string_parameters @property - def resolved_headers_field(self) -> Dict[str, Any]: - headers: Dict[str, Any] = {} - - if self.multi_value_headers: - headers = self.multi_value_headers - else: - headers = self.headers - - return {key.lower(): value for key, value in headers.items()} + def multi_value_headers(self) -> MutableMapping[str, List[str]]: + return CaseInsensitiveDict(self.get("multiValueHeaders")) @property - def multi_value_headers(self) -> Optional[Dict[str, List[str]]]: - return self.get("multiValueHeaders") + def resolved_headers_field(self) -> MutableMapping[str, Any]: + return self.multi_value_headers or self.headers def header_serializer(self) -> BaseHeadersSerializer: # When using the ALB integration, the `multiValueHeaders` feature can be disabled (default) or enabled. diff --git a/aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py b/aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py index b87c8ddaf20..8840f856630 100644 --- a/aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py +++ b/aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py @@ -1,15 +1,14 @@ import enum import re -from typing import Any, Dict, List, Optional, overload +from typing import Any, Dict, List, MutableMapping, Optional + +from requests.structures import CaseInsensitiveDict from aws_lambda_powertools.utilities.data_classes.common import ( BaseRequestContext, BaseRequestContextV2, DictWrapper, ) -from aws_lambda_powertools.utilities.data_classes.shared_functions import ( - get_header_value, -) class APIGatewayRouteArn: @@ -143,8 +142,8 @@ def http_method(self) -> str: return self["httpMethod"] @property - def headers(self) -> Dict[str, str]: - return self["headers"] + def headers(self) -> MutableMapping[str, str]: + return CaseInsensitiveDict(self["headers"]) @property def query_string_parameters(self) -> Dict[str, str]: @@ -162,45 +161,6 @@ def stage_variables(self) -> Dict[str, str]: def request_context(self) -> BaseRequestContext: return BaseRequestContext(self._data) - @overload - def get_header_value( - self, - name: str, - default_value: str, - case_sensitive: bool = False, - ) -> str: ... - - @overload - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: ... - - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: - """Get header value by name - - Parameters - ---------- - name: str - Header name - default_value: str, optional - Default value if no value was found by name - case_sensitive: bool - Whether to use a case-sensitive look up - Returns - ------- - str, optional - Header value - """ - return get_header_value(self.headers, name, default_value, case_sensitive) - class APIGatewayAuthorizerEventV2(DictWrapper): """API Gateway Authorizer Event Format 2.0 @@ -234,14 +194,14 @@ def parsed_arn(self) -> APIGatewayRouteArn: return parse_api_gateway_arn(self.route_arn) @property - def identity_source(self) -> Optional[List[str]]: + def identity_source(self) -> List[str]: """The identity source for which authorization is requested. For a REQUEST authorizer, this is optional. The value is a set of one or more mapping expressions of the specified request parameters. The identity source can be headers, query string parameters, stage variables, and context parameters. """ - return self.get("identitySource") + return self.get("identitySource") or [] @property def route_key(self) -> str: @@ -263,9 +223,9 @@ def cookies(self) -> List[str]: return self["cookies"] @property - def headers(self) -> Dict[str, str]: + def headers(self) -> MutableMapping[str, str]: """Http headers""" - return self["headers"] + return CaseInsensitiveDict(self["headers"]) @property def query_string_parameters(self) -> Dict[str, str]: @@ -276,46 +236,12 @@ def request_context(self) -> BaseRequestContextV2: return BaseRequestContextV2(self._data) @property - def path_parameters(self) -> Optional[Dict[str, str]]: - return self.get("pathParameters") + def path_parameters(self) -> Dict[str, str]: + return self.get("pathParameters") or {} @property - def stage_variables(self) -> Optional[Dict[str, str]]: - return self.get("stageVariables") - - @overload - def get_header_value(self, name: str, default_value: str, case_sensitive: bool = False) -> str: ... - - @overload - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: ... - - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: - """Get header value by name - - Parameters - ---------- - name: str - Header name - default_value: str, optional - Default value if no value was found by name - case_sensitive: bool - Whether to use a case-sensitive look up - Returns - ------- - str, optional - Header value - """ - return get_header_value(self.headers, name, default_value, case_sensitive) + def stage_variables(self) -> Dict[str, str]: + return self.get("stageVariables") or {} class APIGatewayAuthorizerResponseV2: diff --git a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py index 48d3c96c84c..c6787c6f53c 100644 --- a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py +++ b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py @@ -1,4 +1,7 @@ -from typing import Any, Dict, List, Optional +from functools import cached_property +from typing import Any, Dict, List, MutableMapping, Optional + +from requests.structures import CaseInsensitiveDict from aws_lambda_powertools.shared.headers_serializer import ( BaseHeadersSerializer, @@ -112,8 +115,8 @@ def resource(self) -> str: return self["resource"] @property - def multi_value_headers(self) -> Dict[str, List[str]]: - return self.get("multiValueHeaders") or {} # key might exist but can be `null` + def multi_value_headers(self) -> MutableMapping[str, List[str]]: + return CaseInsensitiveDict(self.get("multiValueHeaders")) @property def multi_value_query_string_parameters(self) -> Dict[str, List[str]]: @@ -127,27 +130,20 @@ def resolved_query_string_parameters(self) -> Dict[str, List[str]]: return super().resolved_query_string_parameters @property - def resolved_headers_field(self) -> Dict[str, Any]: - headers: Dict[str, Any] = {} - - if self.multi_value_headers: - headers = self.multi_value_headers - else: - headers = self.headers - - return {key.lower(): value for key, value in headers.items()} + def resolved_headers_field(self) -> MutableMapping[str, Any]: + return self.multi_value_headers or self.headers @property def request_context(self) -> APIGatewayEventRequestContext: return APIGatewayEventRequestContext(self._data) @property - def path_parameters(self) -> Optional[Dict[str, str]]: - return self.get("pathParameters") + def path_parameters(self) -> Dict[str, str]: + return self.get("pathParameters") or {} @property - def stage_variables(self) -> Optional[Dict[str, str]]: - return self.get("stageVariables") + def stage_variables(self) -> Dict[str, str]: + return self.get("stageVariables") or {} def header_serializer(self) -> BaseHeadersSerializer: return MultiValueHeadersSerializer() @@ -289,20 +285,20 @@ def raw_query_string(self) -> str: return self["rawQueryString"] @property - def cookies(self) -> Optional[List[str]]: - return self.get("cookies") + def cookies(self) -> List[str]: + return self.get("cookies") or [] @property def request_context(self) -> RequestContextV2: return RequestContextV2(self._data) @property - def path_parameters(self) -> Optional[Dict[str, str]]: - return self.get("pathParameters") + def path_parameters(self) -> Dict[str, str]: + return self.get("pathParameters") or {} @property - def stage_variables(self) -> Optional[Dict[str, str]]: - return self.get("stageVariables") + def stage_variables(self) -> Dict[str, str]: + return self.get("stageVariables") or {} @property def path(self) -> str: @@ -319,10 +315,6 @@ def http_method(self) -> str: def header_serializer(self): return HttpApiHeadersSerializer() - @property - def resolved_headers_field(self) -> Dict[str, Any]: - if self.headers is not None: - headers = {key.lower(): value.split(",") if "," in value else value for key, value in self.headers.items()} - return headers - - return {} + @cached_property + def resolved_headers_field(self) -> MutableMapping[str, Any]: + return CaseInsensitiveDict({k: v.split(",") if "," in v else v for k, v in self.headers.items()}) diff --git a/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py b/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py index f58308377ff..c8e0949f4b1 100644 --- a/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py +++ b/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py @@ -1,9 +1,8 @@ -from typing import Any, Dict, List, Optional, Union, overload +from typing import Any, Dict, List, MutableMapping, Optional, Union + +from requests.structures import CaseInsensitiveDict from aws_lambda_powertools.utilities.data_classes.common import DictWrapper -from aws_lambda_powertools.utilities.data_classes.shared_functions import ( - get_header_value, -) def get_identity_object(identity: Optional[dict]) -> Any: @@ -118,15 +117,15 @@ def parent_type_name(self) -> str: return self["parentTypeName"] @property - def variables(self) -> Optional[Dict[str, str]]: + def variables(self) -> Dict[str, str]: """A map which holds all variables that are passed into the GraphQL request.""" - return self.get("variables") + return self.get("variables") or {} @property - def selection_set_list(self) -> Optional[List[str]]: + def selection_set_list(self) -> List[str]: """A list representation of the fields in the GraphQL selection set. Fields that are aliased will only be referenced by the alias name, not the field name.""" - return self.get("selectionSetList") + return self.get("selectionSetList") or [] @property def selection_set_graphql(self) -> Optional[str]: @@ -184,22 +183,22 @@ def identity(self) -> Union[None, AppSyncIdentityIAM, AppSyncIdentityCognito]: return get_identity_object(self.get("identity")) @property - def source(self) -> Optional[Dict[str, Any]]: + def source(self) -> Dict[str, Any]: """A map that contains the resolution of the parent field.""" - return self.get("source") + return self.get("source") or {} @property - def request_headers(self) -> Dict[str, str]: + def request_headers(self) -> MutableMapping[str, str]: """Request headers""" - return self["request"]["headers"] + return CaseInsensitiveDict(self["request"]["headers"]) @property - def prev_result(self) -> Optional[Dict[str, Any]]: + def prev_result(self) -> Dict[str, Any]: """It represents the result of whatever previous operation was executed in a pipeline resolver.""" prev = self.get("prev") if not prev: - return None - return prev.get("result") + return {} + return prev.get("result") or {} @property def info(self) -> AppSyncResolverEventInfo: @@ -207,48 +206,9 @@ def info(self) -> AppSyncResolverEventInfo: return self._info @property - def stash(self) -> Optional[dict]: + def stash(self) -> dict: """The stash is a map that is made available inside each resolver and function mapping template. The same stash instance lives through a single resolver execution. This means that you can use the stash to pass arbitrary data across request and response mapping templates, and across functions in a pipeline resolver.""" - return self.get("stash") - - @overload - def get_header_value( - self, - name: str, - default_value: str, - case_sensitive: bool = False, - ) -> str: ... - - @overload - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: ... - - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: - """Get header value by name - - Parameters - ---------- - name: str - Header name - default_value: str, optional - Default value if no value was found by name - case_sensitive: bool - Whether to use a case-sensitive look up - Returns - ------- - str, optional - Header value - """ - return get_header_value(self.request_headers, name, default_value, case_sensitive) + return self.get("stash") or {} diff --git a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py index 4c404c73111..98f77ef5df0 100644 --- a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py +++ b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py @@ -80,8 +80,9 @@ def http_method(self) -> str: return self["httpMethod"] @property - def parameters(self) -> Optional[List[BedrockAgentProperty]]: - return [BedrockAgentProperty(x) for x in self["parameters"]] if self.get("parameters") else None + def parameters(self) -> List[BedrockAgentProperty]: + parameters = self.get("parameters") or [] + return [BedrockAgentProperty(x) for x in parameters] @property def request_body(self) -> Optional[BedrockAgentRequestBody]: @@ -105,10 +106,11 @@ def path(self) -> str: return self["apiPath"] @property - def query_string_parameters(self) -> Optional[Dict[str, str]]: + def query_string_parameters(self) -> Dict[str, str]: # In Bedrock Agent events, query string parameters are passed as undifferentiated parameters, # together with the other parameters. So we just return all parameters here. - return {x["name"]: x["value"] for x in self["parameters"]} if self.get("parameters") else None + parameters = self.get("parameters") or [] + return {x["name"]: x["value"] for x in parameters} @property def resolved_headers_field(self) -> Dict[str, Any]: diff --git a/aws_lambda_powertools/utilities/data_classes/cloud_watch_alarm_event.py b/aws_lambda_powertools/utilities/data_classes/cloud_watch_alarm_event.py index d085228cb37..78106b576e0 100644 --- a/aws_lambda_powertools/utilities/data_classes/cloud_watch_alarm_event.py +++ b/aws_lambda_powertools/utilities/data_classes/cloud_watch_alarm_event.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import cached_property -from typing import Any, Dict, List, Literal, Optional +from typing import Any, List, Literal, Optional from aws_lambda_powertools.utilities.data_classes.common import DictWrapper @@ -117,11 +117,11 @@ def unit(self) -> Optional[str]: return self.get("unit", None) @property - def metric(self) -> Optional[Dict]: + def metric(self) -> dict: """ Metric details """ - return self.get("metric", {}) + return self.get("metric") or {} class CloudWatchAlarmData(DictWrapper): @@ -191,12 +191,12 @@ def alarm_actions_suppressor_extension_period(self) -> Optional[str]: return self.get("actionsSuppressorExtensionPeriod", None) @property - def metrics(self) -> Optional[List[CloudWatchAlarmMetric]]: + def metrics(self) -> List[CloudWatchAlarmMetric]: """ The metrics evaluated for the Alarm. """ - metrics = self.get("metrics") - return [CloudWatchAlarmMetric(i) for i in metrics] if metrics else None + metrics = self.get("metrics") or [] + return [CloudWatchAlarmMetric(i) for i in metrics] class CloudWatchAlarmEvent(DictWrapper): diff --git a/aws_lambda_powertools/utilities/data_classes/cloud_watch_logs_event.py b/aws_lambda_powertools/utilities/data_classes/cloud_watch_logs_event.py index 7775dd67333..7a5fe7cec76 100644 --- a/aws_lambda_powertools/utilities/data_classes/cloud_watch_logs_event.py +++ b/aws_lambda_powertools/utilities/data_classes/cloud_watch_logs_event.py @@ -23,9 +23,9 @@ def message(self) -> str: return self["message"] @property - def extracted_fields(self) -> Optional[Dict[str, str]]: + def extracted_fields(self) -> Dict[str, str]: """Get the `extractedFields` property""" - return self.get("extractedFields") + return self.get("extractedFields") or {} class CloudWatchLogsDecodedData(DictWrapper): diff --git a/aws_lambda_powertools/utilities/data_classes/code_pipeline_job_event.py b/aws_lambda_powertools/utilities/data_classes/code_pipeline_job_event.py index cc7a75cc05e..1cc409c6988 100644 --- a/aws_lambda_powertools/utilities/data_classes/code_pipeline_job_event.py +++ b/aws_lambda_powertools/utilities/data_classes/code_pipeline_job_event.py @@ -19,12 +19,11 @@ def user_parameters(self) -> Optional[str]: return self.get("UserParameters", None) @cached_property - def decoded_user_parameters(self) -> Optional[Dict[str, Any]]: + def decoded_user_parameters(self) -> Dict[str, Any]: """Json Decoded user parameters""" if self.user_parameters is not None: return self._json_deserializer(self.user_parameters) - - return None + return {} class CodePipelineActionConfiguration(DictWrapper): @@ -177,7 +176,7 @@ def user_parameters(self) -> Optional[str]: return self.data.action_configuration.configuration.user_parameters @property - def decoded_user_parameters(self) -> Optional[Dict[str, Any]]: + def decoded_user_parameters(self) -> Dict[str, Any]: """Json Decoded action configuration user parameters""" return self.data.action_configuration.configuration.decoded_user_parameters diff --git a/aws_lambda_powertools/utilities/data_classes/cognito_user_pool_event.py b/aws_lambda_powertools/utilities/data_classes/cognito_user_pool_event.py index a97bf26a16f..86cf3b0601d 100644 --- a/aws_lambda_powertools/utilities/data_classes/cognito_user_pool_event.py +++ b/aws_lambda_powertools/utilities/data_classes/cognito_user_pool_event.py @@ -61,15 +61,15 @@ def user_attributes(self) -> Dict[str, str]: return self["request"]["userAttributes"] @property - def validation_data(self) -> Optional[Dict[str, str]]: + def validation_data(self) -> Dict[str, str]: """One or more name-value pairs containing the validation data in the request to register a user.""" - return self["request"].get("validationData") + return self["request"].get("validationData") or {} @property - def client_metadata(self) -> Optional[Dict[str, str]]: + def client_metadata(self) -> Dict[str, str]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the pre sign-up trigger.""" - return self["request"].get("clientMetadata") + return self["request"].get("clientMetadata") or {} class PreSignUpTriggerEventResponse(DictWrapper): @@ -133,10 +133,10 @@ def user_attributes(self) -> Dict[str, str]: return self["request"]["userAttributes"] @property - def client_metadata(self) -> Optional[Dict[str, str]]: + def client_metadata(self) -> Dict[str, str]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the post confirmation trigger.""" - return self["request"].get("clientMetadata") + return self["request"].get("clientMetadata") or {} class PostConfirmationTriggerEvent(BaseTriggerEvent): @@ -165,15 +165,15 @@ def password(self) -> str: return self["request"]["password"] @property - def validation_data(self) -> Optional[Dict[str, str]]: + def validation_data(self) -> Dict[str, str]: """One or more name-value pairs containing the validation data in the request to register a user.""" - return self["request"].get("validationData") + return self["request"].get("validationData") or {} @property - def client_metadata(self) -> Optional[Dict[str, str]]: + def client_metadata(self) -> Dict[str, str]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the pre sign-up trigger.""" - return self["request"].get("clientMetadata") + return self["request"].get("clientMetadata") or {} class UserMigrationTriggerEventResponse(DictWrapper): @@ -213,8 +213,8 @@ def message_action(self, value: str): self["response"]["messageAction"] = value @property - def desired_delivery_mediums(self) -> Optional[List[str]]: - return self["response"].get("desiredDeliveryMediums") + def desired_delivery_mediums(self) -> List[str]: + return self["response"].get("desiredDeliveryMediums") or [] @desired_delivery_mediums.setter def desired_delivery_mediums(self, value: List[str]): @@ -281,10 +281,10 @@ def user_attributes(self) -> Dict[str, str]: return self["request"]["userAttributes"] @property - def client_metadata(self) -> Optional[Dict[str, str]]: + def client_metadata(self) -> Dict[str, str]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the pre sign-up trigger.""" - return self["request"].get("clientMetadata") + return self["request"].get("clientMetadata") or {} class CustomMessageTriggerEventResponse(DictWrapper): @@ -361,9 +361,9 @@ def user_attributes(self) -> Dict[str, str]: return self["request"]["userAttributes"] @property - def validation_data(self) -> Optional[Dict[str, str]]: + def validation_data(self) -> Dict[str, str]: """One or more key-value pairs containing the validation data in the user's sign-in request.""" - return self["request"].get("validationData") + return self["request"].get("validationData") or {} class PreAuthenticationTriggerEvent(BaseTriggerEvent): @@ -402,10 +402,10 @@ def user_attributes(self) -> Dict[str, str]: return self["request"]["userAttributes"] @property - def client_metadata(self) -> Optional[Dict[str, str]]: + def client_metadata(self) -> Dict[str, str]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the post authentication trigger.""" - return self["request"].get("clientMetadata") + return self["request"].get("clientMetadata") or {} class PostAuthenticationTriggerEvent(BaseTriggerEvent): @@ -433,14 +433,14 @@ def request(self) -> PostAuthenticationTriggerEventRequest: class GroupOverrideDetails(DictWrapper): @property - def groups_to_override(self) -> Optional[List[str]]: + def groups_to_override(self) -> List[str]: """A list of the group names that are associated with the user that the identity token is issued for.""" - return self.get("groupsToOverride") + return self.get("groupsToOverride") or [] @property - def iam_roles_to_override(self) -> Optional[List[str]]: + def iam_roles_to_override(self) -> List[str]: """A list of the current IAM roles associated with these groups.""" - return self.get("iamRolesToOverride") + return self.get("iamRolesToOverride") or [] @property def preferred_role(self) -> Optional[str]: @@ -460,16 +460,16 @@ def user_attributes(self) -> Dict[str, str]: return self["request"]["userAttributes"] @property - def client_metadata(self) -> Optional[Dict[str, str]]: + def client_metadata(self) -> Dict[str, str]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the pre token generation trigger.""" - return self["request"].get("clientMetadata") + return self["request"].get("clientMetadata") or {} class ClaimsOverrideDetails(DictWrapper): @property - def claims_to_add_or_override(self) -> Optional[Dict[str, str]]: - return self.get("claimsToAddOrOverride") + def claims_to_add_or_override(self) -> Dict[str, str]: + return self.get("claimsToAddOrOverride") or {} @claims_to_add_or_override.setter def claims_to_add_or_override(self, value: Dict[str, str]): @@ -478,8 +478,8 @@ def claims_to_add_or_override(self, value: Dict[str, str]): self._data["claimsToAddOrOverride"] = value @property - def claims_to_suppress(self) -> Optional[List[str]]: - return self.get("claimsToSuppress") + def claims_to_suppress(self) -> List[str]: + return self.get("claimsToSuppress") or [] @claims_to_suppress.setter def claims_to_suppress(self, value: List[str]): @@ -599,10 +599,10 @@ def session(self) -> List[ChallengeResult]: return [ChallengeResult(result) for result in self["request"]["session"]] @property - def client_metadata(self) -> Optional[Dict[str, str]]: + def client_metadata(self) -> Dict[str, str]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the defined auth challenge trigger.""" - return self["request"].get("clientMetadata") + return self["request"].get("clientMetadata") or {} class DefineAuthChallengeTriggerEventResponse(DictWrapper): @@ -685,10 +685,10 @@ def session(self) -> List[ChallengeResult]: return [ChallengeResult(result) for result in self["request"]["session"]] @property - def client_metadata(self) -> Optional[Dict[str, str]]: + def client_metadata(self) -> Dict[str, str]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the creation auth challenge trigger.""" - return self["request"].get("clientMetadata") + return self["request"].get("clientMetadata") or {} class CreateAuthChallengeTriggerEventResponse(DictWrapper): @@ -773,10 +773,10 @@ def challenge_answer(self) -> Any: return self["request"]["challengeAnswer"] @property - def client_metadata(self) -> Optional[Dict[str, str]]: + def client_metadata(self) -> Dict[str, str]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the "Verify Auth Challenge" trigger.""" - return self["request"].get("clientMetadata") + return self["request"].get("clientMetadata") or {} @property def user_not_found(self) -> Optional[bool]: diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index 76726ca5129..b2666450eb5 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -1,15 +1,11 @@ import base64 import json -from collections.abc import Mapping from functools import cached_property -from typing import Any, Callable, Dict, Iterator, List, Optional, overload +from typing import Any, Callable, Dict, Iterator, List, Mapping, MutableMapping, Optional + +from requests.structures import CaseInsensitiveDict from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer -from aws_lambda_powertools.utilities.data_classes.shared_functions import ( - get_header_value, - get_multi_value_query_string_values, - get_query_string_value, -) class DictWrapper(Mapping): @@ -97,18 +93,18 @@ def raw_event(self) -> Dict[str, Any]: class BaseProxyEvent(DictWrapper): @property - def headers(self) -> Dict[str, str]: - return self.get("headers") or {} + def headers(self) -> MutableMapping[str, str]: + return CaseInsensitiveDict(self.get("headers")) @property - def query_string_parameters(self) -> Optional[Dict[str, str]]: - return self.get("queryStringParameters") + def query_string_parameters(self) -> Dict[str, str]: + return self.get("queryStringParameters") or {} @property def multi_value_query_string_parameters(self) -> Dict[str, List[str]]: return self.get("multiValueQueryStringParameters") or {} - @property + @cached_property def resolved_query_string_parameters(self) -> Dict[str, List[str]]: """ This property determines the appropriate query string parameter to be used @@ -117,14 +113,10 @@ def resolved_query_string_parameters(self) -> Dict[str, List[str]]: This is necessary because different resolvers use different formats to encode multi query string parameters. """ - if self.query_string_parameters is not None: - query_string = {key: value.split(",") for key, value in self.query_string_parameters.items()} - return query_string - - return {} + return {k: v.split(",") for k, v in self.query_string_parameters.items()} @property - def resolved_headers_field(self) -> Dict[str, Any]: + def resolved_headers_field(self) -> MutableMapping[str, str]: """ This property determines the appropriate header to be used as a trusted source for validating OpenAPI. @@ -172,101 +164,6 @@ def http_method(self) -> str: """The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT.""" return self["httpMethod"] - @overload - def get_query_string_value(self, name: str, default_value: str) -> str: ... - - @overload - def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: ... - - def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: - """Get query string value by name - - Parameters - ---------- - name: str - Query string parameter name - default_value: str, optional - Default value if no value was found by name - Returns - ------- - str, optional - Query string parameter value - """ - return get_query_string_value( - query_string_parameters=self.query_string_parameters, - name=name, - default_value=default_value, - ) - - def get_multi_value_query_string_values( - self, - name: str, - default_values: Optional[List[str]] = None, - ) -> List[str]: - """Get multi-value query string parameter values by name - - Parameters - ---------- - name: str - Multi-Value query string parameter name - default_values: List[str], optional - Default values is no values are found by name - Returns - ------- - List[str], optional - List of query string values - - """ - return get_multi_value_query_string_values( - multi_value_query_string_parameters=self.multi_value_query_string_parameters, - name=name, - default_values=default_values, - ) - - @overload - def get_header_value( - self, - name: str, - default_value: str, - case_sensitive: bool = False, - ) -> str: ... - - @overload - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: ... - - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: - """Get header value by name - - Parameters - ---------- - name: str - Header name - default_value: str, optional - Default value if no value was found by name - case_sensitive: bool - Whether to use a case-sensitive look up. By default we make a case-insensitive lookup. - Returns - ------- - str, optional - Header value - """ - return get_header_value( - headers=self.headers, - name=name, - default_value=default_value, - case_sensitive=case_sensitive, - ) - def header_serializer(self) -> BaseHeadersSerializer: raise NotImplementedError() diff --git a/aws_lambda_powertools/utilities/data_classes/dynamo_db_stream_event.py b/aws_lambda_powertools/utilities/data_classes/dynamo_db_stream_event.py index d0d1bd7ab41..e2eaa8572a5 100644 --- a/aws_lambda_powertools/utilities/data_classes/dynamo_db_stream_event.py +++ b/aws_lambda_powertools/utilities/data_classes/dynamo_db_stream_event.py @@ -27,7 +27,7 @@ def __init__(self, data: Dict[str, Any]): super().__init__(data) self._deserializer = TypeDeserializer() - def _deserialize_dynamodb_dict(self, key: str) -> Optional[Dict[str, Any]]: + def _deserialize_dynamodb_dict(self, key: str) -> Dict[str, Any]: """Deserialize DynamoDB records available in `Keys`, `NewImage`, and `OldImage` Parameters @@ -37,13 +37,10 @@ def _deserialize_dynamodb_dict(self, key: str) -> Optional[Dict[str, Any]]: Returns ------- - Optional[Dict[str, Any]] + Dict[str, Any] Deserialized records in Python native types """ - dynamodb_dict = self._data.get(key) - if dynamodb_dict is None: - return None - + dynamodb_dict = self._data.get(key) or {} return {k: self._deserializer.deserialize(v) for k, v in dynamodb_dict.items()} @property @@ -53,17 +50,17 @@ def approximate_creation_date_time(self) -> Optional[int]: return None if item is None else int(item) @property - def keys(self) -> Optional[Dict[str, Any]]: # type: ignore[override] + def keys(self) -> Dict[str, Any]: # type: ignore[override] """The primary key attribute(s) for the DynamoDB item that was modified.""" return self._deserialize_dynamodb_dict("Keys") @property - def new_image(self) -> Optional[Dict[str, Any]]: + def new_image(self) -> Dict[str, Any]: """The item in the DynamoDB table as it appeared after it was modified.""" return self._deserialize_dynamodb_dict("NewImage") @property - def old_image(self) -> Optional[Dict[str, Any]]: + def old_image(self) -> Dict[str, Any]: """The item in the DynamoDB table as it appeared before it was modified.""" return self._deserialize_dynamodb_dict("OldImage") @@ -132,9 +129,9 @@ def event_version(self) -> Optional[str]: return self.get("eventVersion") @property - def user_identity(self) -> Optional[dict]: + def user_identity(self) -> dict: """Contains details about the type of identity that made the request""" - return self.get("userIdentity") + return self.get("userIdentity") or {} class DynamoDBStreamEvent(DictWrapper): diff --git a/aws_lambda_powertools/utilities/data_classes/kafka_event.py b/aws_lambda_powertools/utilities/data_classes/kafka_event.py index f20c5254730..32d20095063 100644 --- a/aws_lambda_powertools/utilities/data_classes/kafka_event.py +++ b/aws_lambda_powertools/utilities/data_classes/kafka_event.py @@ -1,11 +1,10 @@ import base64 from functools import cached_property -from typing import Any, Dict, Iterator, List, Optional, overload +from typing import Any, Dict, Iterator, List, MutableMapping, Optional + +from requests.structures import CaseInsensitiveDict from aws_lambda_powertools.utilities.data_classes.common import DictWrapper -from aws_lambda_powertools.utilities.data_classes.shared_functions import ( - get_header_value, -) class KafkaEventRecord(DictWrapper): @@ -64,40 +63,10 @@ def headers(self) -> List[Dict[str, List[int]]]: """The raw Kafka record headers.""" return self["headers"] - @property - def decoded_headers(self) -> Dict[str, bytes]: + @cached_property + def decoded_headers(self) -> MutableMapping[str, bytes]: """Decodes the headers as a single dictionary.""" - return {k: bytes(v) for chunk in self.headers for k, v in chunk.items()} - - @overload - def get_header_value( - self, - name: str, - default_value: str, - case_sensitive: bool = True, - ) -> str: ... - - @overload - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = True, - ) -> Optional[str]: ... - - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = True, - ) -> Optional[str]: - """Get a decoded header value by name.""" - return get_header_value( - headers=self.decoded_headers, - name=name, - default_value=default_value, - case_sensitive=case_sensitive, - ) + return CaseInsensitiveDict({k: bytes(v) for chunk in self.headers for k, v in chunk.items()}) class KafkaEvent(DictWrapper): diff --git a/aws_lambda_powertools/utilities/data_classes/s3_batch_operation_event.py b/aws_lambda_powertools/utilities/data_classes/s3_batch_operation_event.py index 9c742e0c553..5419f6f8088 100644 --- a/aws_lambda_powertools/utilities/data_classes/s3_batch_operation_event.py +++ b/aws_lambda_powertools/utilities/data_classes/s3_batch_operation_event.py @@ -147,9 +147,9 @@ def get_id(self) -> str: return self["id"] @property - def user_arguments(self) -> Optional[Dict[str, str]]: + def user_arguments(self) -> Dict[str, str]: """Get user arguments provided for this job (only for invocation schema 2.0)""" - return self.get("userArguments") + return self.get("userArguments") or {} class S3BatchOperationTask(DictWrapper): diff --git a/aws_lambda_powertools/utilities/data_classes/s3_object_event.py b/aws_lambda_powertools/utilities/data_classes/s3_object_event.py index dc79b72766f..98fca7dcaf0 100644 --- a/aws_lambda_powertools/utilities/data_classes/s3_object_event.py +++ b/aws_lambda_powertools/utilities/data_classes/s3_object_event.py @@ -1,9 +1,8 @@ -from typing import Dict, Optional, overload +from typing import MutableMapping, Optional + +from requests.structures import CaseInsensitiveDict from aws_lambda_powertools.utilities.data_classes.common import DictWrapper -from aws_lambda_powertools.utilities.data_classes.shared_functions import ( - get_header_value, -) class S3ObjectContext(DictWrapper): @@ -65,52 +64,13 @@ def url(self) -> str: return self["url"] @property - def headers(self) -> Dict[str, str]: + def headers(self) -> MutableMapping[str, str]: """A map of string to strings containing the HTTP headers and their values from the original call, excluding any authorization-related headers. If the same header appears multiple times, their values are combined into a comma-delimited list. The case of the original headers is retained in this map.""" - return self["headers"] - - @overload - def get_header_value( - self, - name: str, - default_value: str, - case_sensitive: bool = False, - ) -> str: ... - - @overload - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: ... - - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: - """Get header value by name - - Parameters - ---------- - name: str - Header name - default_value: str, optional - Default value if no value was found by name - case_sensitive: bool - Whether to use a case-sensitive look up - Returns - ------- - str, optional - Header value - """ - return get_header_value(self.headers, name, default_value, case_sensitive) + return CaseInsensitiveDict(self["headers"]) class S3ObjectSessionIssuer(DictWrapper): diff --git a/aws_lambda_powertools/utilities/data_classes/ses_event.py b/aws_lambda_powertools/utilities/data_classes/ses_event.py index 2ebc02e22a0..5adcf7149ee 100644 --- a/aws_lambda_powertools/utilities/data_classes/ses_event.py +++ b/aws_lambda_powertools/utilities/data_classes/ses_event.py @@ -46,24 +46,24 @@ def subject(self) -> str: return str(self["subject"]) @property - def cc(self) -> Optional[List[str]]: + def cc(self) -> List[str]: """The values in the CC header of the email.""" - return self.get("cc") + return self.get("cc") or [] @property - def bcc(self) -> Optional[List[str]]: + def bcc(self) -> List[str]: """The values in the BCC header of the email.""" - return self.get("bcc") + return self.get("bcc") or [] @property - def sender(self) -> Optional[List[str]]: + def sender(self) -> List[str]: """The values in the Sender header of the email.""" - return self.get("sender") + return self.get("sender") or [] @property - def reply_to(self) -> Optional[List[str]]: + def reply_to(self) -> List[str]: """The values in the replyTo header of the email.""" - return self.get("replyTo") + return self.get("replyTo") or [] class SESMail(DictWrapper): diff --git a/aws_lambda_powertools/utilities/data_classes/shared_functions.py b/aws_lambda_powertools/utilities/data_classes/shared_functions.py index 0e88a5dac93..4f3451714a1 100644 --- a/aws_lambda_powertools/utilities/data_classes/shared_functions.py +++ b/aws_lambda_powertools/utilities/data_classes/shared_functions.py @@ -1,7 +1,4 @@ -from __future__ import annotations - import base64 -from typing import Any, Dict, overload def base64_decode(value: str) -> str: @@ -19,129 +16,3 @@ def base64_decode(value: str) -> str: The decoded string value. """ return base64.b64decode(value).decode("UTF-8") - - -@overload -def get_header_value( - headers: dict[str, Any], - name: str, - default_value: str, - case_sensitive: bool = False, -) -> str: ... - - -@overload -def get_header_value( - headers: dict[str, Any], - name: str, - default_value: str | None = None, - case_sensitive: bool = False, -) -> str | None: ... - - -def get_header_value( - headers: dict[str, Any], - name: str, - default_value: str | None = None, - case_sensitive: bool = False, -) -> str | None: - """ - Get the value of a header by its name. - - Parameters - ---------- - headers: Dict[str, str] - The dictionary of headers. - name: str - The name of the header to retrieve. - default_value: str, optional - The default value to return if the header is not found. Default is None. - case_sensitive: bool, optional - Indicates whether the header name should be case-sensitive. Default is False. - - Returns - ------- - str, optional - The value of the header if found, otherwise the default value or None. - """ - # If headers is NoneType, return default value - if not headers: - return default_value - - if case_sensitive: - return headers.get(name, default_value) - name_lower = name.lower() - - return next( - # Iterate over the dict and do a case-insensitive key comparison - (value for key, value in headers.items() if key.lower() == name_lower), - # Default value is returned if no matches was found - default_value, - ) - - -@overload -def get_query_string_value( - query_string_parameters: Dict[str, str] | None, - name: str, - default_value: str, -) -> str: ... - - -@overload -def get_query_string_value( - query_string_parameters: Dict[str, str] | None, - name: str, - default_value: str | None = None, -) -> str | None: ... - - -def get_query_string_value( - query_string_parameters: Dict[str, str] | None, - name: str, - default_value: str | None = None, -) -> str | None: - """ - Retrieves the value of a query string parameter specified by the given name. - - Parameters - ---------- - name: str - The name of the query string parameter to retrieve. - default_value: str, optional - The default value to return if the parameter is not found. Defaults to None. - - Returns - ------- - str. optional - The value of the query string parameter if found, or the default value if not found. - """ - params = query_string_parameters - return default_value if params is None else params.get(name, default_value) - - -def get_multi_value_query_string_values( - multi_value_query_string_parameters: Dict[str, list[str]] | None, - name: str, - default_values: list[str] | None = None, -) -> list[str]: - """ - Retrieves the values of a multi-value string parameters specified by the given name. - - Parameters - ---------- - name: str - The name of the query string parameter to retrieve. - default_value: list[str], optional - The default value to return if the parameter is not found. Defaults to None. - - Returns - ------- - List[str]. optional - The values of the query string parameter if found, or the default values if not found. - """ - - default = default_values or [] - params = multi_value_query_string_parameters or {} - - return params.get(name) or default diff --git a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py index c28977c56ba..4f45f7fcf3a 100644 --- a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py +++ b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py @@ -1,16 +1,14 @@ from functools import cached_property -from typing import Any, Dict, Optional, overload +from typing import Any, Dict, MutableMapping, Optional + +from requests.structures import CaseInsensitiveDict from aws_lambda_powertools.shared.headers_serializer import ( BaseHeadersSerializer, HttpApiHeadersSerializer, ) from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent, DictWrapper -from aws_lambda_powertools.utilities.data_classes.shared_functions import ( - base64_decode, - get_header_value, - get_query_string_value, -) +from aws_lambda_powertools.utilities.data_classes.shared_functions import base64_decode class VPCLatticeEventBase(BaseProxyEvent): @@ -25,9 +23,9 @@ def json_body(self) -> Any: return self._json_deserializer(self.decoded_body) @property - def headers(self) -> Dict[str, str]: + def headers(self) -> MutableMapping[str, str]: """The VPC Lattice event headers.""" - return self["headers"] + return CaseInsensitiveDict(self["headers"]) @property def decoded_body(self) -> str: @@ -47,76 +45,6 @@ def http_method(self) -> str: """The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT.""" return self["method"] - @overload - def get_query_string_value(self, name: str, default_value: str) -> str: ... - - @overload - def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: ... - - def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: - """Get query string value by name - - Parameters - ---------- - name: str - Query string parameter name - default_value: str, optional - Default value if no value was found by name - Returns - ------- - str, optional - Query string parameter value - """ - return get_query_string_value( - query_string_parameters=self.query_string_parameters, - name=name, - default_value=default_value, - ) - - @overload - def get_header_value( - self, - name: str, - default_value: str, - case_sensitive: bool = False, - ) -> str: ... - - @overload - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: ... - - def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: bool = False, - ) -> Optional[str]: - """Get header value by name - - Parameters - ---------- - name: str - Header name - default_value: str, optional - Default value if no value was found by name - case_sensitive: bool - Whether to use a case-sensitive look up - Returns - ------- - str, optional - Header value - """ - return get_header_value( - headers=self.headers, - name=name, - default_value=default_value, - case_sensitive=case_sensitive, - ) - def header_serializer(self) -> BaseHeadersSerializer: # When using the VPC Lattice integration, we have multiple HTTP Headers. return HttpApiHeadersSerializer() @@ -144,13 +72,9 @@ def query_string_parameters(self) -> Dict[str, str]: """The request query string parameters.""" return self["query_string_parameters"] - @property - def resolved_headers_field(self) -> Dict[str, Any]: - if self.headers is not None: - headers = {key.lower(): value.split(",") if "," in value else value for key, value in self.headers.items()} - return headers - - return {} + @cached_property + def resolved_headers_field(self) -> MutableMapping[str, Any]: + return CaseInsensitiveDict({k: v.split(",") if "," in v else v for k, v in self.headers.items()}) class vpcLatticeEventV2Identity(DictWrapper): @@ -259,21 +183,11 @@ def request_context(self) -> vpcLatticeEventV2RequestContext: return vpcLatticeEventV2RequestContext(self["requestContext"]) @property - def query_string_parameters(self) -> Optional[Dict[str, str]]: + def query_string_parameters(self) -> Dict[str, str]: """The request query string parameters. For VPC Lattice V2, the queryStringParameters will contain a Dict[str, List[str]] so to keep compatibility with existing utilities, we merge all the values with a comma. """ - params = self.get("queryStringParameters") - if params: - return {key: ",".join(value) for key, value in params.items()} - else: - return None - - @property - def resolved_headers_field(self) -> Dict[str, str]: - if self.headers is not None: - return {key.lower(): value for key, value in self.headers.items()} - - return {} + params = self.get("queryStringParameters") or {} + return {k: ",".join(v) for k, v in params.items()} diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index aa667f5f169..3c182b30e4e 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -466,7 +466,7 @@ That is why you see `app.resolve(event, context)` in every example. This allows #### Query strings and payload -Within `app.current_event` property, you can access all available query strings as a dictionary via `query_string_parameters`, or a specific one via `get_query_string_value` method. +Within `app.current_event` property, you can access all available query strings as a dictionary via `query_string_parameters`. You can access the raw payload via `body` property, or if it's a JSON string you can quickly deserialize it via `json_body` property - like the earlier example in the [HTTP Methods](#http-methods) section. @@ -476,7 +476,7 @@ You can access the raw payload via `body` property, or if it's a JSON string you #### Headers -Similarly to [Query strings](#query-strings-and-payload), you can access headers as dictionary via `app.current_event.headers`, or by name via `get_header_value`. If you prefer a case-insensitive lookup of the header value, the `app.current_event.get_header_value` function automatically handles it. +Similarly to [Query strings](#query-strings-and-payload), you can access headers as dictionary via `app.current_event.headers`. Specifically for headers, it's a case-insensitive dictionary, so all lookups are case-insensitive. ```python hl_lines="19" title="Accessing HTTP Headers" --8<-- "examples/event_handler_rest/src/accessing_request_details_headers.py" diff --git a/docs/utilities/data_classes.md b/docs/utilities/data_classes.md index 0b43f36933e..b481fe7b3a7 100644 --- a/docs/utilities/data_classes.md +++ b/docs/utilities/data_classes.md @@ -175,7 +175,7 @@ Use **`APIGatewayAuthorizerRequestEvent`** for type `REQUEST` and **`APIGatewayA @event_source(data_class=APIGatewayAuthorizerRequestEvent) def handler(event: APIGatewayAuthorizerRequestEvent, context): - user = get_user_by_token(event.get_header_value("Authorization")) + user = get_user_by_token(event.headers["Authorization"]) if user is None: # No user was found @@ -263,7 +263,7 @@ See also [this blog post](https://aws.amazon.com/blogs/compute/introducing-iam-a @event_source(data_class=APIGatewayAuthorizerEventV2) def handler(event: APIGatewayAuthorizerEventV2, context): - user = get_user_by_token(event.get_header_value("x-token")) + user = get_user_by_token(event.headers["x-token"]) if user is None: # No user was found, so we return not authorized @@ -397,7 +397,7 @@ In this example, we also use the new Logger `correlation_id` and built-in `corre event: AppSyncResolverEvent = AppSyncResolverEvent(event) # Case insensitive look up of request headers - x_forwarded_for = event.get_header_value("x-forwarded-for") + x_forwarded_for = event.headers.get("x-forwarded-for") # Support for AppSyncIdentityCognito or AppSyncIdentityIAM identity types assert isinstance(event.identity, AppSyncIdentityCognito) diff --git a/examples/event_handler_graphql/src/custom_models.py b/examples/event_handler_graphql/src/custom_models.py index 61e03318d14..63dd1ea8c05 100644 --- a/examples/event_handler_graphql/src/custom_models.py +++ b/examples/event_handler_graphql/src/custom_models.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from aws_lambda_powertools import Logger, Tracer from aws_lambda_powertools.event_handler import AppSyncResolver @@ -25,12 +25,12 @@ class Location(TypedDict, total=False): class MyCustomModel(AppSyncResolverEvent): @property - def country_viewer(self) -> str: - return self.get_header_value(name="cloudfront-viewer-country", default_value="", case_sensitive=False) + def country_viewer(self) -> Optional[str]: + return self.request_headers.get("cloudfront-viewer-country") @property - def api_key(self) -> str: - return self.get_header_value(name="x-api-key", default_value="", case_sensitive=False) + def api_key(self) -> Optional[str]: + return self.request_headers.get("x-api-key") @app.resolver(type_name="Query", field_name="listLocations") diff --git a/examples/event_handler_rest/src/accessing_request_details.py b/examples/event_handler_rest/src/accessing_request_details.py index 037b76daa66..e9a5d924017 100644 --- a/examples/event_handler_rest/src/accessing_request_details.py +++ b/examples/event_handler_rest/src/accessing_request_details.py @@ -16,12 +16,12 @@ @app.get("/todos") @tracer.capture_method def get_todos(): - todo_id: str = app.current_event.get_query_string_value(name="id", default_value="") + todo_id: str = app.current_event.query_string_parameters["id"] # alternatively _: Optional[str] = app.current_event.query_string_parameters.get("id") # or multi-value query string parameters; ?category="red"&?category="blue" - _: List[str] = app.current_event.get_multi_value_query_string_values(name="category") + _: List[str] = app.current_event.multi_value_query_string_parameters["category"] # Payload _: Optional[str] = app.current_event.body # raw str | None diff --git a/examples/event_handler_rest/src/accessing_request_details_headers.py b/examples/event_handler_rest/src/accessing_request_details_headers.py index f6bfb88c869..de5df2fed0b 100644 --- a/examples/event_handler_rest/src/accessing_request_details_headers.py +++ b/examples/event_handler_rest/src/accessing_request_details_headers.py @@ -16,7 +16,7 @@ def get_todos(): endpoint = "https://jsonplaceholder.typicode.com/todos" - api_key: str = app.current_event.get_header_value(name="X-Api-Key", case_sensitive=True, default_value="") + api_key = app.current_event.headers["X-Api-Key"] todos: Response = requests.get(endpoint, headers={"X-Api-Key": api_key}) todos.raise_for_status() diff --git a/examples/event_handler_rest/src/exception_handling.py b/examples/event_handler_rest/src/exception_handling.py index ea325bd6dc1..24c14bb868d 100644 --- a/examples/event_handler_rest/src/exception_handling.py +++ b/examples/event_handler_rest/src/exception_handling.py @@ -31,7 +31,7 @@ def handle_invalid_limit_qs(ex: ValueError): # receives exception raised def get_todos(): # educational purpose only: we should receive a `ValueError` # if a query string value for `limit` cannot be coerced to int - max_results: int = int(app.current_event.get_query_string_value(name="limit", default_value=0)) + max_results = int(app.current_event.query_string_parameters.get("limit", 0)) todos: requests.Response = requests.get(f"https://jsonplaceholder.typicode.com/todos?limit={max_results}") todos.raise_for_status() diff --git a/examples/event_handler_rest/src/middleware_extending_middlewares.py b/examples/event_handler_rest/src/middleware_extending_middlewares.py index e492caacf47..ad448c03d30 100644 --- a/examples/event_handler_rest/src/middleware_extending_middlewares.py +++ b/examples/event_handler_rest/src/middleware_extending_middlewares.py @@ -22,10 +22,7 @@ def __init__(self, header: str): # (1)! def handler(self, app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response: # (2)! request_id = app.current_event.request_context.request_id - correlation_id = app.current_event.get_header_value( - name=self.header, - default_value=request_id, - ) + correlation_id = app.current_event.headers.get(self.header, request_id) response = next_middleware(app) # (3)! response.headers[self.header] = correlation_id diff --git a/examples/event_handler_rest/src/middleware_global_middlewares_module.py b/examples/event_handler_rest/src/middleware_global_middlewares_module.py index 2b06bc31c71..96745a28448 100644 --- a/examples/event_handler_rest/src/middleware_global_middlewares_module.py +++ b/examples/event_handler_rest/src/middleware_global_middlewares_module.py @@ -34,7 +34,7 @@ def inject_correlation_id(app: APIGatewayRestResolver, next_middleware: NextMidd def enforce_correlation_id(app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response: # If missing mandatory header raise an error - if not app.current_event.get_header_value("x-correlation-id", case_sensitive=False): + if not app.current_event.headers.get("x-correlation-id"): return Response(status_code=400, body="Correlation ID header is now mandatory.") # (1)! # Get the response from the next middleware and return it diff --git a/examples/event_handler_rest/src/split_route_module.py b/examples/event_handler_rest/src/split_route_module.py index b6a91b3fb3b..b67d5d0568b 100644 --- a/examples/event_handler_rest/src/split_route_module.py +++ b/examples/event_handler_rest/src/split_route_module.py @@ -13,7 +13,7 @@ @router.get("/todos") @tracer.capture_method def get_todos(): - api_key: str = router.current_event.get_header_value(name="X-Api-Key", case_sensitive=True, default_value="") + api_key = router.current_event.headers["X-Api-Key"] todos: Response = requests.get(endpoint, headers={"X-Api-Key": api_key}) todos.raise_for_status() @@ -25,11 +25,7 @@ def get_todos(): @router.get("/todos/") @tracer.capture_method def get_todo_by_id(todo_id: str): # value come as str - api_key: str = router.current_event.get_header_value( - name="X-Api-Key", - case_sensitive=True, - default_value="", - ) # noqa: E501 + api_key = router.current_event.headers["X-Api-Key"] todos: Response = requests.get(f"{endpoint}/{todo_id}", headers={"X-Api-Key": api_key}) todos.raise_for_status() diff --git a/examples/event_handler_rest/src/split_route_prefix_module.py b/examples/event_handler_rest/src/split_route_prefix_module.py index aa17e0cd347..c112a772c6e 100644 --- a/examples/event_handler_rest/src/split_route_prefix_module.py +++ b/examples/event_handler_rest/src/split_route_prefix_module.py @@ -13,7 +13,7 @@ @router.get("/") @tracer.capture_method def get_todos(): - api_key: str = router.current_event.get_header_value(name="X-Api-Key", case_sensitive=True, default_value="") + api_key = router.current_event.headers["X-Api-Key"] todos: Response = requests.get(endpoint, headers={"X-Api-Key": api_key}) todos.raise_for_status() @@ -25,11 +25,7 @@ def get_todos(): @router.get("/") @tracer.capture_method def get_todo_by_id(todo_id: str): # value come as str - api_key: str = router.current_event.get_header_value( - name="X-Api-Key", - case_sensitive=True, - default_value="", - ) # sentinel typing # noqa: E501 + api_key = router.current_event.headers["X-Api-Key"] todos: Response = requests.get(f"{endpoint}/{todo_id}", headers={"X-Api-Key": api_key}) todos.raise_for_status() diff --git a/tests/functional/event_handler/test_api_middlewares.py b/tests/functional/event_handler/test_api_middlewares.py index 58bec259072..ed5c3ecb21b 100644 --- a/tests/functional/event_handler/test_api_middlewares.py +++ b/tests/functional/event_handler/test_api_middlewares.py @@ -484,10 +484,7 @@ def __init__(self, header: str): def handler(self, app: ApiGatewayResolver, get_response: NextMiddleware, **kwargs) -> Response: request_id = app.current_event.request_context.request_id # type: ignore[attr-defined] # using REST event in a base Resolver # noqa: E501 - correlation_id = app.current_event.get_header_value( - name=self.header, - default_value=request_id, - ) # noqa: E501 + correlation_id = app.current_event.headers.get(self.header, request_id) response = get_response(app, **kwargs) response.headers[self.header] = correlation_id diff --git a/tests/unit/data_classes/test_alb_event.py b/tests/unit/data_classes/test_alb_event.py index 47048ab9407..6945dc67c36 100644 --- a/tests/unit/data_classes/test_alb_event.py +++ b/tests/unit/data_classes/test_alb_event.py @@ -14,6 +14,6 @@ def test_alb_event(): assert parsed_event.multi_value_query_string_parameters == raw_event.get("multiValueQueryStringParameters", {}) - assert parsed_event.multi_value_headers == raw_event.get("multiValueHeaders") + assert parsed_event.multi_value_headers == (raw_event.get("multiValueHeaders") or {}) assert parsed_event.body == raw_event["body"] assert parsed_event.is_base64_encoded == raw_event["isBase64Encoded"] diff --git a/tests/unit/data_classes/test_api_gateway_authorizer_event.py b/tests/unit/data_classes/test_api_gateway_authorizer_event.py index 2c5f170d924..4ae44643474 100644 --- a/tests/unit/data_classes/test_api_gateway_authorizer_event.py +++ b/tests/unit/data_classes/test_api_gateway_authorizer_event.py @@ -52,16 +52,16 @@ def test_api_gateway_authorizer_v2(): assert parsed_event.path_parameters == raw_event["pathParameters"] assert parsed_event.stage_variables == raw_event["stageVariables"] - assert parsed_event.get_header_value("Authorization") == "value" - assert parsed_event.get_header_value("authorization") == "value" - assert parsed_event.get_header_value("missing") is None + assert parsed_event.headers["Authorization"] == "value" + assert parsed_event.headers["authorization"] == "value" + assert parsed_event.headers.get("missing") is None # Check for optionals event_optionals = APIGatewayAuthorizerEventV2({"requestContext": {}}) - assert event_optionals.identity_source is None + assert event_optionals.identity_source == [] assert event_optionals.request_context.authentication is None - assert event_optionals.path_parameters is None - assert event_optionals.stage_variables is None + assert event_optionals.path_parameters == {} + assert event_optionals.stage_variables == {} def test_api_gateway_authorizer_token_event(): @@ -90,7 +90,7 @@ def test_api_gateway_authorizer_request_event(): assert parsed_event.path == raw_event["path"] assert parsed_event.http_method == raw_event["httpMethod"] assert parsed_event.headers == raw_event["headers"] - assert parsed_event.get_header_value("accept") == "*/*" + assert parsed_event.headers["accept"] == "*/*" assert parsed_event.query_string_parameters == raw_event["queryStringParameters"] assert parsed_event.path_parameters == raw_event["pathParameters"] assert parsed_event.stage_variables == raw_event["stageVariables"] diff --git a/tests/unit/data_classes/test_api_gateway_proxy_event.py b/tests/unit/data_classes/test_api_gateway_proxy_event.py index d86e4b5e19b..42925ee9c9f 100644 --- a/tests/unit/data_classes/test_api_gateway_proxy_event.py +++ b/tests/unit/data_classes/test_api_gateway_proxy_event.py @@ -54,8 +54,8 @@ def test_default_api_gateway_proxy_event(): assert identity.user_arn == identity_raw["userArn"] assert identity.client_cert.subject_dn == "www.example.com" - assert parsed_event.path_parameters == raw_event["pathParameters"] - assert parsed_event.stage_variables == raw_event["stageVariables"] + assert parsed_event.path_parameters == (raw_event["pathParameters"] or {}) + assert parsed_event.stage_variables == (raw_event["stageVariables"] or {}) assert parsed_event.body == raw_event["body"] assert parsed_event.is_base64_encoded == raw_event["isBase64Encoded"] @@ -121,8 +121,8 @@ def test_api_gateway_proxy_event(): assert identity.user_arn == identity_raw["userArn"] assert identity.client_cert.subject_dn == "www.example.com" - assert parsed_event.path_parameters == raw_event["pathParameters"] - assert parsed_event.stage_variables == raw_event["stageVariables"] + assert parsed_event.path_parameters == (raw_event["pathParameters"] or {}) + assert parsed_event.stage_variables == (raw_event["stageVariables"] or {}) assert parsed_event.body == raw_event["body"] assert parsed_event.is_base64_encoded == raw_event["isBase64Encoded"] diff --git a/tests/unit/data_classes/test_appsync_resolver_event.py b/tests/unit/data_classes/test_appsync_resolver_event.py index a1a010c251a..d235c235fbe 100644 --- a/tests/unit/data_classes/test_appsync_resolver_event.py +++ b/tests/unit/data_classes/test_appsync_resolver_event.py @@ -17,19 +17,19 @@ def test_appsync_resolver_event(): assert parsed_event.arguments.get("name") == raw_event["arguments"]["name"] assert parsed_event.identity.claims.get("token_use") == raw_event["identity"]["claims"]["token_use"] assert parsed_event.source.get("name") == raw_event["source"]["name"] - assert parsed_event.get_header_value("X-amzn-trace-id") == "Root=1-60488877-0b0c4e6727ab2a1c545babd0" - assert parsed_event.get_header_value("X-amzn-trace-id", case_sensitive=True) is None - assert parsed_event.get_header_value("missing", default_value="Foo") == "Foo" + assert parsed_event.request_headers["X-amzn-trace-id"] == "Root=1-60488877-0b0c4e6727ab2a1c545babd0" + assert parsed_event.request_headers["x-amzn-trace-id"] == "Root=1-60488877-0b0c4e6727ab2a1c545babd0" + assert parsed_event.request_headers.get("missing", "Foo") == "Foo" assert parsed_event.prev_result == {} - assert parsed_event.stash is None + assert parsed_event.stash == {} info = parsed_event.info assert info is not None assert isinstance(info, AppSyncResolverEventInfo) assert info.field_name == raw_event["fieldName"] assert info.parent_type_name == raw_event["typeName"] - assert info.variables is None - assert info.selection_set_list is None + assert info.variables == {} + assert info.selection_set_list == [] assert info.selection_set_graphql is None assert isinstance(parsed_event.identity, AppSyncIdentityCognito) @@ -80,17 +80,16 @@ def test_appsync_resolver_direct(): raw_event = load_event("appSyncDirectResolver.json") parsed_event = AppSyncResolverEvent(raw_event) - assert parsed_event.source is None + assert parsed_event.source == {} assert parsed_event.arguments.get("id") == raw_event["arguments"]["id"] assert parsed_event.stash == {} - assert parsed_event.prev_result is None + assert parsed_event.prev_result == {} assert isinstance(parsed_event.identity, AppSyncIdentityCognito) info = parsed_event.info info_raw = raw_event["info"] assert info is not None assert isinstance(info, AppSyncResolverEventInfo) - assert info.selection_set_list is not None assert info.selection_set_list == info["selectionSetList"] assert info.selection_set_graphql == info_raw["selectionSetGraphQL"] assert info.parent_type_name == info_raw["parentTypeName"] @@ -112,7 +111,7 @@ def test_appsync_resolver_event_info(): event = AppSyncResolverEvent(event) - assert event.source is None + assert event.source == {} assert event.identity is None assert event.info is not None assert isinstance(event.info, AppSyncResolverEventInfo) diff --git a/tests/unit/data_classes/test_cloud_watch_alarm_event.py b/tests/unit/data_classes/test_cloud_watch_alarm_event.py index 56933a1505d..df72a7ff1e1 100644 --- a/tests/unit/data_classes/test_cloud_watch_alarm_event.py +++ b/tests/unit/data_classes/test_cloud_watch_alarm_event.py @@ -102,3 +102,4 @@ def test_cloud_watch_alarm_event_composite_metric(): parsed_event.alarm_data.configuration.alarm_actions_suppressor == raw_event["alarmData"]["configuration"]["actionsSuppressor"] ) + assert isinstance(parsed_event.alarm_data.configuration.metrics, List) diff --git a/tests/unit/data_classes/test_cloud_watch_logs_event.py b/tests/unit/data_classes/test_cloud_watch_logs_event.py index c65c55d6334..10a3a499dd0 100644 --- a/tests/unit/data_classes/test_cloud_watch_logs_event.py +++ b/tests/unit/data_classes/test_cloud_watch_logs_event.py @@ -24,7 +24,7 @@ def test_cloud_watch_trigger_event(): assert log_event.get_id == "eventId1" assert log_event.timestamp == 1440442987000 assert log_event.message == "[ERROR] First test message" - assert log_event.extracted_fields is None + assert log_event.extracted_fields == {} event2 = CloudWatchLogsEvent(load_event("cloudWatchLogEvent.json")) assert parsed_event.raw_event == event2.raw_event @@ -52,7 +52,7 @@ def test_cloud_watch_trigger_event_with_policy_level(): assert log_event.get_id == "eventId1" assert log_event.timestamp == 1440442987000 assert log_event.message == "[ERROR] First test message" - assert log_event.extracted_fields is None + assert log_event.extracted_fields == {} event2 = CloudWatchLogsEvent(load_event("cloudWatchLogEventWithPolicyLevel.json")) assert parsed_event.raw_event == event2.raw_event diff --git a/tests/unit/data_classes/test_code_pipeline_job_event.py b/tests/unit/data_classes/test_code_pipeline_job_event.py index a1689ede2f1..75e68b44396 100644 --- a/tests/unit/data_classes/test_code_pipeline_job_event.py +++ b/tests/unit/data_classes/test_code_pipeline_job_event.py @@ -93,8 +93,8 @@ def test_code_pipeline_event_missing_user_parameters(): configuration = parsed_event.data.action_configuration.configuration decoded_params = configuration.decoded_user_parameters assert decoded_params == parsed_event.decoded_user_parameters - assert decoded_params is None - assert configuration.decoded_user_parameters is None + assert decoded_params == {} + assert configuration.decoded_user_parameters == {} def test_code_pipeline_event_non_json_user_parameters(): diff --git a/tests/unit/data_classes/test_cognito_user_pool_event.py b/tests/unit/data_classes/test_cognito_user_pool_event.py index 2321f23c16e..9c4285fd18a 100644 --- a/tests/unit/data_classes/test_cognito_user_pool_event.py +++ b/tests/unit/data_classes/test_cognito_user_pool_event.py @@ -32,8 +32,8 @@ def test_cognito_pre_signup_trigger_event(): # Verify properties user_attributes = parsed_event.request.user_attributes assert user_attributes.get("email") == raw_event["request"]["userAttributes"]["email"] - assert parsed_event.request.validation_data is None - assert parsed_event.request.client_metadata is None + assert parsed_event.request.validation_data == {} + assert parsed_event.request.client_metadata == {} # Verify setters parsed_event.response.auto_confirm_user = True @@ -53,7 +53,7 @@ def test_cognito_post_confirmation_trigger_event(): user_attributes = parsed_event.request.user_attributes assert user_attributes.get("email") == raw_event["request"]["userAttributes"]["email"] - assert parsed_event.request.client_metadata is None + assert parsed_event.request.client_metadata == {} def test_cognito_user_migration_trigger_event(): @@ -63,8 +63,8 @@ def test_cognito_user_migration_trigger_event(): assert parsed_event.trigger_source == raw_event["triggerSource"] assert compare_digest(parsed_event.request.password, raw_event["request"]["password"]) - assert parsed_event.request.validation_data is None - assert parsed_event.request.client_metadata is None + assert parsed_event.request.validation_data == {} + assert parsed_event.request.client_metadata == {} parsed_event.response.user_attributes = {"username": "username"} assert parsed_event.response.user_attributes == raw_event["response"]["userAttributes"] @@ -72,7 +72,7 @@ def test_cognito_user_migration_trigger_event(): assert parsed_event.response.final_user_status is None assert parsed_event.response.message_action is None assert parsed_event.response.force_alias_creation is None - assert parsed_event.response.desired_delivery_mediums is None + assert parsed_event.response.desired_delivery_mediums == [] parsed_event.response.final_user_status = "CONFIRMED" assert parsed_event.response.final_user_status == "CONFIRMED" @@ -93,7 +93,7 @@ def test_cognito_custom_message_trigger_event(): assert parsed_event.request.code_parameter == raw_event["request"]["codeParameter"] assert parsed_event.request.username_parameter == raw_event["request"]["usernameParameter"] assert parsed_event.request.user_attributes.get("phone_number_verified") is False - assert parsed_event.request.client_metadata is None + assert parsed_event.request.client_metadata == {} parsed_event.response.sms_message = "sms" assert parsed_event.response.sms_message == parsed_event["response"]["smsMessage"] @@ -113,7 +113,7 @@ def test_cognito_pre_authentication_trigger_event(): parsed_event["request"]["userNotFound"] = True assert parsed_event.request.user_not_found is True assert parsed_event.request.user_attributes.get("email") == raw_event["request"]["userAttributes"]["email"] - assert parsed_event.request.validation_data is None + assert parsed_event.request.validation_data == {} def test_cognito_post_authentication_trigger_event(): @@ -124,7 +124,7 @@ def test_cognito_post_authentication_trigger_event(): assert parsed_event.request.new_device_used is True assert parsed_event.request.user_attributes.get("email") == raw_event["request"]["userAttributes"]["email"] - assert parsed_event.request.client_metadata is None + assert parsed_event.request.client_metadata == {} def test_cognito_pre_token_generation_trigger_event(): @@ -138,7 +138,7 @@ def test_cognito_pre_token_generation_trigger_event(): assert group_configuration.iam_roles_to_override == [] assert group_configuration.preferred_role is None assert parsed_event.request.user_attributes.get("email") == raw_event["request"]["userAttributes"]["email"] - assert parsed_event.request.client_metadata is None + assert parsed_event.request.client_metadata == {} parsed_event["request"]["groupConfiguration"]["preferredRole"] = "temp" group_configuration = parsed_event.request.group_configuration @@ -148,8 +148,8 @@ def test_cognito_pre_token_generation_trigger_event(): claims_override_details = parsed_event.response.claims_override_details assert parsed_event["response"]["claimsOverrideDetails"] == {} - assert claims_override_details.claims_to_add_or_override is None - assert claims_override_details.claims_to_suppress is None + assert claims_override_details.claims_to_add_or_override == {} + assert claims_override_details.claims_to_suppress == [] assert claims_override_details.group_configuration is None claims_override_details.group_configuration = {} @@ -208,7 +208,7 @@ def test_cognito_define_auth_challenge_trigger_event(): assert session[0].challenge_result is True assert session[0].challenge_metadata is None assert session[1].challenge_metadata == raw_event["request"]["session"][1]["challengeMetadata"] - assert parsed_event.request.client_metadata is None + assert parsed_event.request.client_metadata == {} # Verify setters parsed_event.response.challenge_name = "CUSTOM_CHALLENGE" @@ -236,7 +236,7 @@ def test_create_auth_challenge_trigger_event(): assert len(session) == 1 assert session[0].challenge_name == raw_event["request"]["session"][0]["challengeName"] assert session[0].challenge_metadata == raw_event["request"]["session"][0]["challengeMetadata"] - assert parsed_event.request.client_metadata is None + assert parsed_event.request.client_metadata == {} # Verify setters parsed_event.response.public_challenge_parameters = {"test": "value"} @@ -263,7 +263,6 @@ def test_verify_auth_challenge_response_trigger_event(): == raw_event["request"]["privateChallengeParameters"]["answer"] ) assert parsed_event.request.challenge_answer == raw_event["request"]["challengeAnswer"] - assert parsed_event.request.client_metadata is not None assert parsed_event.request.client_metadata.get("foo") == raw_event["request"]["clientMetadata"]["foo"] assert parsed_event.request.user_not_found is True diff --git a/tests/unit/data_classes/test_dynamo_db_stream_event.py b/tests/unit/data_classes/test_dynamo_db_stream_event.py index f7672abd69b..9632563423a 100644 --- a/tests/unit/data_classes/test_dynamo_db_stream_event.py +++ b/tests/unit/data_classes/test_dynamo_db_stream_event.py @@ -30,7 +30,7 @@ def test_dynamodb_stream_trigger_event(): assert record.event_source == record_raw["eventSource"] assert record.event_source_arn == record_raw["eventSourceARN"] assert record.event_version == record_raw["eventVersion"] - assert record.user_identity is None + assert record.user_identity == {} dynamodb = record.dynamodb assert dynamodb is not None assert dynamodb.approximate_creation_date_time == record_raw["dynamodb"]["ApproximateCreationDateTime"] @@ -38,7 +38,7 @@ def test_dynamodb_stream_trigger_event(): assert keys is not None assert keys["Id"] == decimal_context.create_decimal(101) assert dynamodb.new_image.get("Message") == record_raw["dynamodb"]["NewImage"]["Message"]["S"] - assert dynamodb.old_image is None + assert dynamodb.old_image == {} assert dynamodb.sequence_number == record_raw["dynamodb"]["SequenceNumber"] assert dynamodb.size_bytes == record_raw["dynamodb"]["SizeBytes"] assert dynamodb.stream_view_type == StreamViewType.NEW_AND_OLD_IMAGES @@ -94,7 +94,7 @@ def test_dynamodb_stream_record_deserialization(): def test_dynamodb_stream_record_keys_with_no_keys(): record = StreamRecord({}) - assert record.keys is None + assert record.keys == {} def test_dynamodb_stream_record_keys_overrides_dict_wrapper_keys(): diff --git a/tests/unit/data_classes/test_kafka_event.py b/tests/unit/data_classes/test_kafka_event.py index f97fa8e0a0e..fc36171da77 100644 --- a/tests/unit/data_classes/test_kafka_event.py +++ b/tests/unit/data_classes/test_kafka_event.py @@ -31,7 +31,7 @@ def test_kafka_msk_event(): assert record.value == raw_record["value"] assert record.json_value == {"key": "value"} assert record.decoded_headers == {"headerKey": b"headerValue"} - assert record.get_header_value("HeaderKey", case_sensitive=False) == b"headerValue" + assert record.decoded_headers["HeaderKey"] == b"headerValue" assert parsed_event.record == records[0] @@ -62,7 +62,7 @@ def test_kafka_self_managed_event(): assert record.value == raw_record["value"] assert record.json_value == {"key": "value"} assert record.decoded_headers == {"headerKey": b"headerValue"} - assert record.get_header_value("HeaderKey", case_sensitive=False) == b"headerValue" + assert record.decoded_headers["HeaderKey"] == b"headerValue" assert parsed_event.record == records[0] diff --git a/tests/unit/data_classes/test_lambda_function_url.py b/tests/unit/data_classes/test_lambda_function_url.py index f8ce71b1543..ca8e3d78c59 100644 --- a/tests/unit/data_classes/test_lambda_function_url.py +++ b/tests/unit/data_classes/test_lambda_function_url.py @@ -13,17 +13,17 @@ def test_lambda_function_url_event(): assert parsed_event.path == raw_event["rawPath"] assert parsed_event.raw_query_string == raw_event["rawQueryString"] - assert parsed_event.cookies is None + assert parsed_event.cookies == [] headers = parsed_event.headers assert len(headers) == 20 - assert parsed_event.query_string_parameters is None + assert parsed_event.query_string_parameters == {} assert parsed_event.is_base64_encoded is False assert parsed_event.body is None - assert parsed_event.path_parameters is None - assert parsed_event.stage_variables is None + assert parsed_event.path_parameters == {} + assert parsed_event.stage_variables == {} assert parsed_event.http_method == raw_event["requestContext"]["http"]["method"] request_context = parsed_event.request_context @@ -75,8 +75,8 @@ def test_lambda_function_url_event_iam(): assert parsed_event.is_base64_encoded is False assert parsed_event.body == raw_event["body"] assert parsed_event.decoded_body == raw_event["body"] - assert parsed_event.path_parameters is None - assert parsed_event.stage_variables is None + assert parsed_event.path_parameters == {} + assert parsed_event.stage_variables == {} assert parsed_event.http_method == raw_event["requestContext"]["http"]["method"] request_context = parsed_event.request_context diff --git a/tests/unit/data_classes/test_s3_batch_operation_event.py b/tests/unit/data_classes/test_s3_batch_operation_event.py index ca0d4ae635c..44dc65df07d 100644 --- a/tests/unit/data_classes/test_s3_batch_operation_event.py +++ b/tests/unit/data_classes/test_s3_batch_operation_event.py @@ -19,7 +19,7 @@ def test_s3_batch_operation_schema_v1(): job = parsed_event.job assert job.get_id == raw_event["job"]["id"] - assert job.user_arguments is None + assert job.user_arguments == {} assert parsed_event.invocation_schema_version == raw_event["invocationSchemaVersion"] assert parsed_event.invocation_id == raw_event["invocationId"] diff --git a/tests/unit/data_classes/test_s3_object_event.py b/tests/unit/data_classes/test_s3_object_event.py index 47583d9e544..09d0f14e5f6 100644 --- a/tests/unit/data_classes/test_s3_object_event.py +++ b/tests/unit/data_classes/test_s3_object_event.py @@ -23,7 +23,7 @@ def test_s3_object_event_iam(): user_request = parsed_event.user_request assert user_request.url == raw_event["userRequest"]["url"] assert user_request.headers == raw_event["userRequest"]["headers"] - assert user_request.get_header_value("Accept-Encoding") == "identity" + assert user_request.headers["Accept-Encoding"] == "identity" assert parsed_event.user_identity is not None user_identity = parsed_event.user_identity assert user_identity.get_type == raw_event["userIdentity"]["type"] diff --git a/tests/unit/data_classes/test_ses_event.py b/tests/unit/data_classes/test_ses_event.py index 636cf4cccac..e81c546fb1e 100644 --- a/tests/unit/data_classes/test_ses_event.py +++ b/tests/unit/data_classes/test_ses_event.py @@ -29,10 +29,10 @@ def test_ses_trigger_event(): assert common_headers.to == [expected_address] assert common_headers.message_id == common_headers_raw["messageId"] assert common_headers.subject == common_headers_raw["subject"] - assert common_headers.cc is None - assert common_headers.bcc is None - assert common_headers.sender is None - assert common_headers.reply_to is None + assert common_headers.cc == [] + assert common_headers.bcc == [] + assert common_headers.sender == [] + assert common_headers.reply_to == [] receipt = record.ses.receipt raw_receipt = raw_event["Records"][0]["ses"]["receipt"] assert receipt.timestamp == raw_receipt["timestamp"] diff --git a/tests/unit/data_classes/test_vpc_lattice_event.py b/tests/unit/data_classes/test_vpc_lattice_event.py index ab00c51521f..9f5ad742557 100644 --- a/tests/unit/data_classes/test_vpc_lattice_event.py +++ b/tests/unit/data_classes/test_vpc_lattice_event.py @@ -7,8 +7,8 @@ def test_vpc_lattice_event(): parsed_event = VPCLatticeEvent(raw_event) assert parsed_event.raw_path == raw_event["raw_path"] - assert parsed_event.get_query_string_value("order-id") == "1" - assert parsed_event.get_header_value("user_agent") == "curl/7.64.1" + assert parsed_event.query_string_parameters["order-id"] == "1" + assert parsed_event.headers["user_agent"] == "curl/7.64.1" assert parsed_event.decoded_body == '{"test": "event"}' assert parsed_event.json_body == {"test": "event"} assert parsed_event.method == raw_event["method"] diff --git a/tests/unit/data_classes/test_vpc_lattice_eventv2.py b/tests/unit/data_classes/test_vpc_lattice_eventv2.py index 3726831445f..87a9a69be38 100644 --- a/tests/unit/data_classes/test_vpc_lattice_eventv2.py +++ b/tests/unit/data_classes/test_vpc_lattice_eventv2.py @@ -7,8 +7,8 @@ def test_vpc_lattice_v2_event(): parsed_event = VPCLatticeEventV2(raw_event) assert parsed_event.path == raw_event["path"] - assert parsed_event.get_query_string_value("order-id") == "1" - assert parsed_event.get_header_value("user_agent") == "curl/7.64.1" + assert parsed_event.query_string_parameters["order-id"] == "1" + assert parsed_event.headers["user_agent"] == "curl/7.64.1" assert parsed_event.decoded_body == '{"message": "Hello from Lambda!"}' assert parsed_event.json_body == {"message": "Hello from Lambda!"} assert parsed_event.method == raw_event["method"] diff --git a/tests/unit/test_data_classes.py b/tests/unit/test_data_classes.py index 393bcdf250e..63947eade11 100644 --- a/tests/unit/test_data_classes.py +++ b/tests/unit/test_data_classes.py @@ -240,90 +240,6 @@ def data_property(self) -> str: assert str(event_source) == "{'data_property': '[SENSITIVE]', 'raw_event': '[SENSITIVE]'}" -def test_base_proxy_event_get_query_string_value(): - default_value = "default" - set_value = "value" - - event = BaseProxyEvent({}) - value = event.get_query_string_value("test", default_value) - assert value == default_value - - event._data["queryStringParameters"] = {"test": set_value} - value = event.get_query_string_value("test", default_value) - assert value == set_value - - value = event.get_query_string_value("unknown", default_value) - assert value == default_value - - value = event.get_query_string_value("unknown") - assert value is None - - -def test_base_proxy_event_get_multi_value_query_string_values(): - default_values = ["default_1", "default_2"] - set_values = ["value_1", "value_2"] - - event = BaseProxyEvent({}) - values = event.get_multi_value_query_string_values("test", default_values) - assert values == default_values - - event._data["multiValueQueryStringParameters"] = {"test": set_values} - values = event.get_multi_value_query_string_values("test", default_values) - assert values == set_values - - values = event.get_multi_value_query_string_values("unknown", default_values) - assert values == default_values - - values = event.get_multi_value_query_string_values("unknown") - assert values == [] - - -def test_base_proxy_event_get_header_value(): - default_value = "default" - set_value = "value" - - event = BaseProxyEvent({"headers": {}}) - value = event.get_header_value("test", default_value) - assert value == default_value - - event._data["headers"] = {"test": set_value} - value = event.get_header_value("test", default_value) - assert value == set_value - - # Verify that the default look is case insensitive - value = event.get_header_value("Test") - assert value == set_value - - value = event.get_header_value("unknown", default_value) - assert value == default_value - - value = event.get_header_value("unknown") - assert value is None - - -def test_base_proxy_event_get_header_value_case_insensitive(): - default_value = "default" - set_value = "value" - - event = BaseProxyEvent({"headers": {}}) - - event._data["headers"] = {"Test": set_value} - value = event.get_header_value("test", case_sensitive=True) - assert value is None - - value = event.get_header_value("test", default_value=default_value, case_sensitive=True) - assert value == default_value - - value = event.get_header_value("Test", case_sensitive=True) - assert value == set_value - - value = event.get_header_value("unknown", default_value, case_sensitive=True) - assert value == default_value - - value = event.get_header_value("unknown", case_sensitive=True) - assert value is None - - def test_base_proxy_event_json_body(): data = {"message": "Foo"} event = BaseProxyEvent({"body": json.dumps(data)}) @@ -408,7 +324,7 @@ def test_reflected_types(): def lambda_handler(event: APIGatewayProxyEventV2, _): # THEN we except the event to be of the pass in data class type assert isinstance(event, APIGatewayProxyEventV2) - assert event.get_header_value("x-foo") == "Foo" + assert event.headers["x-foo"] == "Foo" # WHEN calling the lambda handler lambda_handler({"headers": {"X-Foo": "Foo"}}, None)