Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(idempotent): Include function name in the idempotent key #326

Merged
merged 9 commits into from
Mar 12, 2021
10 changes: 6 additions & 4 deletions aws_lambda_powertools/utilities/idempotency/idempotency.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def handle(self) -> Any:
try:
# We call save_inprogress first as an optimization for the most common case where no idempotent record
# already exists. If it succeeds, there's no need to call get_record.
self.persistence_store.save_inprogress(event=self.event)
self.persistence_store.save_inprogress(event=self.event, context=self.context)
except IdempotencyItemAlreadyExistsError:
# Now we know the item already exists, we can retrieve it
record = self._get_idempotency_record()
Expand All @@ -151,7 +151,7 @@ def _get_idempotency_record(self) -> DataRecord:

"""
try:
event_record = self.persistence_store.get_record(self.event)
event_record = self.persistence_store.get_record(event=self.event, context=self.context)
except IdempotencyItemNotFoundError:
# This code path will only be triggered if the record is removed between save_inprogress and get_record.
logger.debug(
Expand Down Expand Up @@ -219,7 +219,9 @@ def _call_lambda_handler(self) -> Any:
# We need these nested blocks to preserve lambda handler exception in case the persistence store operation
# also raises an exception
try:
self.persistence_store.delete_record(event=self.event, exception=handler_exception)
self.persistence_store.delete_record(
event=self.event, context=self.context, exception=handler_exception
)
except Exception as delete_exception:
raise IdempotencyPersistenceLayerError(
"Failed to delete record from idempotency store"
Expand All @@ -228,7 +230,7 @@ def _call_lambda_handler(self) -> Any:

else:
try:
self.persistence_store.save_success(event=self.event, result=handler_response)
self.persistence_store.save_success(event=self.event, context=self.context, result=handler_response)
except Exception as save_exception:
raise IdempotencyPersistenceLayerError(
"Failed to update record state to success in idempotency store"
Expand Down
40 changes: 25 additions & 15 deletions aws_lambda_powertools/utilities/idempotency/persistence/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
IdempotencyKeyError,
IdempotencyValidationError,
)
from aws_lambda_powertools.utilities.typing import LambdaContext

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -152,34 +153,34 @@ def configure(self, config: IdempotencyConfig) -> None:
self._cache = LRUDict(max_items=config.local_cache_max_items)
self.hash_function = getattr(hashlib, config.hash_function)

def _get_hashed_idempotency_key(self, lambda_event: Dict[str, Any]) -> str:
def _get_hashed_idempotency_key(self, event: Dict[str, Any], context: LambdaContext) -> str:
"""
Extract data from lambda event using event key jmespath, and return a hashed representation

Parameters
----------
lambda_event: Dict[str, Any]
event: Dict[str, Any]
Lambda event
context: LambdaContext
Lambda context

Returns
-------
str
Hashed representation of the data extracted by the jmespath expression

"""
data = lambda_event
data = event

if self.event_key_jmespath:
data = self.event_key_compiled_jmespath.search(
lambda_event, options=jmespath.Options(**self.jmespath_options)
)
data = self.event_key_compiled_jmespath.search(event, options=jmespath.Options(**self.jmespath_options))

if self.is_missing_idempotency_key(data):
if self.raise_on_no_idempotency_key:
raise IdempotencyKeyError("No data found to create a hashed idempotency_key")
warnings.warn(f"No value found for idempotency_key. jmespath: {self.event_key_jmespath}")

return self._generate_hash(data)
return f"{context.function_name}#{self._generate_hash(data)}"

@staticmethod
def is_missing_idempotency_key(data) -> bool:
Expand Down Expand Up @@ -298,21 +299,23 @@ def _delete_from_cache(self, idempotency_key: str):
if idempotency_key in self._cache:
del self._cache[idempotency_key]

def save_success(self, event: Dict[str, Any], result: dict) -> None:
def save_success(self, event: Dict[str, Any], context: LambdaContext, result: dict) -> None:
"""
Save record of function's execution completing successfully

Parameters
----------
event: Dict[str, Any]
Lambda event
context: LambdaContext
Lambda context
result: dict
The response from lambda handler
"""
response_data = json.dumps(result, cls=Encoder)

data_record = DataRecord(
idempotency_key=self._get_hashed_idempotency_key(event),
idempotency_key=self._get_hashed_idempotency_key(event, context),
status=STATUS_CONSTANTS["COMPLETED"],
expiry_timestamp=self._get_expiry_timestamp(),
response_data=response_data,
Expand All @@ -326,17 +329,19 @@ def save_success(self, event: Dict[str, Any], result: dict) -> None:

self._save_to_cache(data_record)

def save_inprogress(self, event: Dict[str, Any]) -> None:
def save_inprogress(self, event: Dict[str, Any], context: LambdaContext) -> None:
"""
Save record of function's execution being in progress

Parameters
----------
event: Dict[str, Any]
Lambda event
context: LambdaContext
Lambda context
"""
data_record = DataRecord(
idempotency_key=self._get_hashed_idempotency_key(event),
idempotency_key=self._get_hashed_idempotency_key(event, context),
status=STATUS_CONSTANTS["INPROGRESS"],
expiry_timestamp=self._get_expiry_timestamp(),
payload_hash=self._get_hashed_payload(event),
Expand All @@ -349,18 +354,20 @@ def save_inprogress(self, event: Dict[str, Any]) -> None:

self._put_record(data_record)

def delete_record(self, event: Dict[str, Any], exception: Exception):
def delete_record(self, event: Dict[str, Any], context: LambdaContext, exception: Exception):
"""
Delete record from the persistence store

Parameters
----------
event: Dict[str, Any]
Lambda event
context: LambdaContext
Lambda context
exception
The exception raised by the lambda handler
"""
data_record = DataRecord(idempotency_key=self._get_hashed_idempotency_key(event))
data_record = DataRecord(idempotency_key=self._get_hashed_idempotency_key(event, context))

logger.debug(
f"Lambda raised an exception ({type(exception).__name__}). Clearing in progress record in persistence "
Expand All @@ -370,14 +377,17 @@ def delete_record(self, event: Dict[str, Any], exception: Exception):

self._delete_from_cache(data_record.idempotency_key)

def get_record(self, event: Dict[str, Any]) -> DataRecord:
def get_record(self, event: Dict[str, Any], context: LambdaContext) -> DataRecord:
"""
Calculate idempotency key for lambda_event, then retrieve item from persistence store using idempotency key
and return it as a DataRecord instance.and return it as a DataRecord instance.

Parameters
----------
event: Dict[str, Any]
Lambda event
context: LambdaContext
Lambda context

Returns
-------
Expand All @@ -392,7 +402,7 @@ def get_record(self, event: Dict[str, Any]) -> DataRecord:
Event payload doesn't match the stored record for the given idempotency key
"""

idempotency_key = self._get_hashed_idempotency_key(event)
idempotency_key = self._get_hashed_idempotency_key(event, context)

cached_record = self._retrieve_from_cache(idempotency_key=idempotency_key)
if cached_record:
Expand Down
19 changes: 16 additions & 3 deletions tests/functional/idempotency/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import hashlib
import json
import os
from collections import namedtuple
from decimal import Decimal
from unittest import mock

Expand Down Expand Up @@ -34,6 +35,18 @@ def lambda_apigw_event():
return event


@pytest.fixture
def lambda_context():
lambda_context = {
"function_name": "test-func",
"memory_limit_in_mb": 128,
"invoked_function_arn": "arn:aws:lambda:eu-west-1:809313241234:function:test-func",
"aws_request_id": "52fdfc07-2182-154f-163f-5f0f9a621d72",
}

return namedtuple("LambdaContext", lambda_context.keys())(*lambda_context.values())


@pytest.fixture
def timestamp_future():
return str(int((datetime.datetime.now() + datetime.timedelta(seconds=3600)).timestamp()))
Expand Down Expand Up @@ -132,18 +145,18 @@ def expected_params_put_item_with_validation(hashed_idempotency_key, hashed_vali


@pytest.fixture
def hashed_idempotency_key(lambda_apigw_event, default_jmespath):
def hashed_idempotency_key(lambda_apigw_event, default_jmespath, lambda_context):
compiled_jmespath = jmespath.compile(default_jmespath)
data = compiled_jmespath.search(lambda_apigw_event)
return hashlib.md5(json.dumps(data).encode()).hexdigest()
return "test-func#" + hashlib.md5(json.dumps(data).encode()).hexdigest()


@pytest.fixture
def hashed_idempotency_key_with_envelope(lambda_apigw_event):
event = unwrap_event_from_envelope(
data=lambda_apigw_event, envelope=envelopes.API_GATEWAY_HTTP, jmespath_options={}
)
return hashlib.md5(json.dumps(event).encode()).hexdigest()
return "test-func#" + hashlib.md5(json.dumps(event).encode()).hexdigest()


@pytest.fixture
Expand Down
Loading