Skip to content

Commit

Permalink
fix(batch): report multiple failures (#967)
Browse files Browse the repository at this point in the history
  • Loading branch information
heitorlessa authored Jan 20, 2022
1 parent 27e5022 commit 70c35b1
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 36 deletions.
29 changes: 17 additions & 12 deletions aws_lambda_powertools/utilities/batch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def _clean(self):
)

messages = self._get_messages_to_report()
self.batch_response = {"batchItemFailures": [messages]}
self.batch_response = {"batchItemFailures": messages}

def _has_messages_to_report(self) -> bool:
if self.fail_messages:
Expand All @@ -397,7 +397,7 @@ def _has_messages_to_report(self) -> bool:
def _entire_batch_failed(self) -> bool:
return len(self.exceptions) == len(self.records)

def _get_messages_to_report(self) -> Dict[str, str]:
def _get_messages_to_report(self) -> List[Dict[str, str]]:
"""
Format messages to use in batch deletion
"""
Expand All @@ -406,20 +406,25 @@ def _get_messages_to_report(self) -> Dict[str, str]:
# Event Source Data Classes follow python idioms for fields
# while Parser/Pydantic follows the event field names to the latter
def _collect_sqs_failures(self):
if self.model:
return {"itemIdentifier": msg.messageId for msg in self.fail_messages}
return {"itemIdentifier": msg.message_id for msg in self.fail_messages}
failures = []
for msg in self.fail_messages:
msg_id = msg.messageId if self.model else msg.message_id
failures.append({"itemIdentifier": msg_id})
return failures

def _collect_kinesis_failures(self):
if self.model:
# Pydantic model uses int but Lambda poller expects str
return {"itemIdentifier": msg.kinesis.sequenceNumber for msg in self.fail_messages}
return {"itemIdentifier": msg.kinesis.sequence_number for msg in self.fail_messages}
failures = []
for msg in self.fail_messages:
msg_id = msg.kinesis.sequenceNumber if self.model else msg.kinesis.sequence_number
failures.append({"itemIdentifier": msg_id})
return failures

def _collect_dynamodb_failures(self):
if self.model:
return {"itemIdentifier": msg.dynamodb.SequenceNumber for msg in self.fail_messages}
return {"itemIdentifier": msg.dynamodb.sequence_number for msg in self.fail_messages}
failures = []
for msg in self.fail_messages:
msg_id = msg.dynamodb.SequenceNumber if self.model else msg.dynamodb.sequence_number
failures.append({"itemIdentifier": msg_id})
return failures

@overload
def _to_batch_type(self, record: dict, event_type: EventType, model: "BatchTypeModels") -> "BatchTypeModels":
Expand Down
84 changes: 60 additions & 24 deletions tests/functional/test_utilities_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,8 @@ def test_batch_processor_middleware_with_failure(sqs_event_factory, record_handl
# GIVEN
first_record = SQSRecord(sqs_event_factory("fail"))
second_record = SQSRecord(sqs_event_factory("success"))
event = {"Records": [first_record.raw_event, second_record.raw_event]}
third_record = SQSRecord(sqs_event_factory("fail"))
event = {"Records": [first_record.raw_event, second_record.raw_event, third_record.raw_event]}

processor = BatchProcessor(event_type=EventType.SQS)

Expand All @@ -426,7 +427,7 @@ def lambda_handler(event, context):
result = lambda_handler(event, {})

# THEN
assert len(result["batchItemFailures"]) == 1
assert len(result["batchItemFailures"]) == 2


def test_batch_processor_context_success_only(sqs_event_factory, record_handler):
Expand All @@ -453,7 +454,8 @@ def test_batch_processor_context_with_failure(sqs_event_factory, record_handler)
# GIVEN
first_record = SQSRecord(sqs_event_factory("failure"))
second_record = SQSRecord(sqs_event_factory("success"))
records = [first_record.raw_event, second_record.raw_event]
third_record = SQSRecord(sqs_event_factory("fail"))
records = [first_record.raw_event, second_record.raw_event, third_record.raw_event]
processor = BatchProcessor(event_type=EventType.SQS)

# WHEN
Expand All @@ -462,8 +464,10 @@ def test_batch_processor_context_with_failure(sqs_event_factory, record_handler)

# THEN
assert processed_messages[1] == ("success", second_record.body, second_record.raw_event)
assert len(batch.fail_messages) == 1
assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record.message_id}]}
assert len(batch.fail_messages) == 2
assert batch.response() == {
"batchItemFailures": [{"itemIdentifier": first_record.message_id}, {"itemIdentifier": third_record.message_id}]
}


def test_batch_processor_kinesis_context_success_only(kinesis_event_factory, kinesis_record_handler):
Expand Down Expand Up @@ -491,8 +495,9 @@ def test_batch_processor_kinesis_context_with_failure(kinesis_event_factory, kin
# GIVEN
first_record = KinesisStreamRecord(kinesis_event_factory("failure"))
second_record = KinesisStreamRecord(kinesis_event_factory("success"))
third_record = KinesisStreamRecord(kinesis_event_factory("failure"))

records = [first_record.raw_event, second_record.raw_event]
records = [first_record.raw_event, second_record.raw_event, third_record.raw_event]
processor = BatchProcessor(event_type=EventType.KinesisDataStreams)

# WHEN
Expand All @@ -501,15 +506,21 @@ def test_batch_processor_kinesis_context_with_failure(kinesis_event_factory, kin

# THEN
assert processed_messages[1] == ("success", b64_to_str(second_record.kinesis.data), second_record.raw_event)
assert len(batch.fail_messages) == 1
assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record.kinesis.sequence_number}]}
assert len(batch.fail_messages) == 2
assert batch.response() == {
"batchItemFailures": [
{"itemIdentifier": first_record.kinesis.sequence_number},
{"itemIdentifier": third_record.kinesis.sequence_number},
]
}


def test_batch_processor_kinesis_middleware_with_failure(kinesis_event_factory, kinesis_record_handler):
# GIVEN
first_record = KinesisStreamRecord(kinesis_event_factory("failure"))
second_record = KinesisStreamRecord(kinesis_event_factory("success"))
event = {"Records": [first_record.raw_event, second_record.raw_event]}
third_record = KinesisStreamRecord(kinesis_event_factory("failure"))
event = {"Records": [first_record.raw_event, second_record.raw_event, third_record.raw_event]}

processor = BatchProcessor(event_type=EventType.KinesisDataStreams)

Expand All @@ -521,7 +532,7 @@ def lambda_handler(event, context):
result = lambda_handler(event, {})

# THEN
assert len(result["batchItemFailures"]) == 1
assert len(result["batchItemFailures"]) == 2


def test_batch_processor_dynamodb_context_success_only(dynamodb_event_factory, dynamodb_record_handler):
Expand All @@ -548,7 +559,8 @@ def test_batch_processor_dynamodb_context_with_failure(dynamodb_event_factory, d
# GIVEN
first_record = dynamodb_event_factory("failure")
second_record = dynamodb_event_factory("success")
records = [first_record, second_record]
third_record = dynamodb_event_factory("failure")
records = [first_record, second_record, third_record]
processor = BatchProcessor(event_type=EventType.DynamoDBStreams)

# WHEN
Expand All @@ -557,15 +569,21 @@ def test_batch_processor_dynamodb_context_with_failure(dynamodb_event_factory, d

# THEN
assert processed_messages[1] == ("success", second_record["dynamodb"]["NewImage"]["Message"]["S"], second_record)
assert len(batch.fail_messages) == 1
assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["dynamodb"]["SequenceNumber"]}]}
assert len(batch.fail_messages) == 2
assert batch.response() == {
"batchItemFailures": [
{"itemIdentifier": first_record["dynamodb"]["SequenceNumber"]},
{"itemIdentifier": third_record["dynamodb"]["SequenceNumber"]},
]
}


def test_batch_processor_dynamodb_middleware_with_failure(dynamodb_event_factory, dynamodb_record_handler):
# GIVEN
first_record = dynamodb_event_factory("failure")
second_record = dynamodb_event_factory("success")
event = {"Records": [first_record, second_record]}
third_record = dynamodb_event_factory("failure")
event = {"Records": [first_record, second_record, third_record]}

processor = BatchProcessor(event_type=EventType.DynamoDBStreams)

Expand All @@ -577,7 +595,7 @@ def lambda_handler(event, context):
result = lambda_handler(event, {})

# THEN
assert len(result["batchItemFailures"]) == 1
assert len(result["batchItemFailures"]) == 2


def test_batch_processor_context_model(sqs_event_factory, order_event_factory):
Expand Down Expand Up @@ -639,17 +657,23 @@ def record_handler(record: OrderSqs):
order_event = order_event_factory({"type": "success"})
order_event_fail = order_event_factory({"type": "fail"})
first_record = sqs_event_factory(order_event_fail)
third_record = sqs_event_factory(order_event_fail)
second_record = sqs_event_factory(order_event)
records = [first_record, second_record]
records = [first_record, second_record, third_record]

# WHEN
processor = BatchProcessor(event_type=EventType.SQS, model=OrderSqs)
with processor(records, record_handler) as batch:
batch.process()

# THEN
assert len(batch.fail_messages) == 1
assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["messageId"]}]}
assert len(batch.fail_messages) == 2
assert batch.response() == {
"batchItemFailures": [
{"itemIdentifier": first_record["messageId"]},
{"itemIdentifier": third_record["messageId"]},
]
}


def test_batch_processor_dynamodb_context_model(dynamodb_event_factory, order_event_factory):
Expand Down Expand Up @@ -726,16 +750,22 @@ def record_handler(record: OrderDynamoDBRecord):
order_event_fail = order_event_factory({"type": "fail"})
first_record = dynamodb_event_factory(order_event_fail)
second_record = dynamodb_event_factory(order_event)
records = [first_record, second_record]
third_record = dynamodb_event_factory(order_event_fail)
records = [first_record, second_record, third_record]

# WHEN
processor = BatchProcessor(event_type=EventType.DynamoDBStreams, model=OrderDynamoDBRecord)
with processor(records, record_handler) as batch:
batch.process()

# THEN
assert len(batch.fail_messages) == 1
assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["dynamodb"]["SequenceNumber"]}]}
assert len(batch.fail_messages) == 2
assert batch.response() == {
"batchItemFailures": [
{"itemIdentifier": first_record["dynamodb"]["SequenceNumber"]},
{"itemIdentifier": third_record["dynamodb"]["SequenceNumber"]},
]
}


def test_batch_processor_kinesis_context_parser_model(kinesis_event_factory, order_event_factory):
Expand Down Expand Up @@ -807,16 +837,22 @@ def record_handler(record: OrderKinesisRecord):

first_record = kinesis_event_factory(order_event_fail)
second_record = kinesis_event_factory(order_event)
records = [first_record, second_record]
third_record = kinesis_event_factory(order_event_fail)
records = [first_record, second_record, third_record]

# WHEN
processor = BatchProcessor(event_type=EventType.KinesisDataStreams, model=OrderKinesisRecord)
with processor(records, record_handler) as batch:
batch.process()

# THEN
assert len(batch.fail_messages) == 1
assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["kinesis"]["sequenceNumber"]}]}
assert len(batch.fail_messages) == 2
assert batch.response() == {
"batchItemFailures": [
{"itemIdentifier": first_record["kinesis"]["sequenceNumber"]},
{"itemIdentifier": third_record["kinesis"]["sequenceNumber"]},
]
}


def test_batch_processor_error_when_entire_batch_fails(sqs_event_factory, record_handler):
Expand Down

0 comments on commit 70c35b1

Please sign in to comment.