diff --git a/aws_lambda_powertools/utilities/batch/base.py b/aws_lambda_powertools/utilities/batch/base.py index d8fdc2d85f2..21b59328ef0 100644 --- a/aws_lambda_powertools/utilities/batch/base.py +++ b/aws_lambda_powertools/utilities/batch/base.py @@ -385,7 +385,7 @@ def _clean(self): ) messages = self._get_messages_to_report() - self.batch_response = {"batchItemFailures": [messages]} + self.batch_response = {"batchItemFailures": messages} def _has_messages_to_report(self) -> bool: if self.fail_messages: @@ -397,7 +397,7 @@ def _has_messages_to_report(self) -> bool: def _entire_batch_failed(self) -> bool: return len(self.exceptions) == len(self.records) - def _get_messages_to_report(self) -> Dict[str, str]: + def _get_messages_to_report(self) -> List[Dict[str, str]]: """ Format messages to use in batch deletion """ @@ -406,20 +406,25 @@ def _get_messages_to_report(self) -> Dict[str, str]: # Event Source Data Classes follow python idioms for fields # while Parser/Pydantic follows the event field names to the latter def _collect_sqs_failures(self): - if self.model: - return {"itemIdentifier": msg.messageId for msg in self.fail_messages} - return {"itemIdentifier": msg.message_id for msg in self.fail_messages} + failures = [] + for msg in self.fail_messages: + msg_id = msg.messageId if self.model else msg.message_id + failures.append({"itemIdentifier": msg_id}) + return failures def _collect_kinesis_failures(self): - if self.model: - # Pydantic model uses int but Lambda poller expects str - return {"itemIdentifier": msg.kinesis.sequenceNumber for msg in self.fail_messages} - return {"itemIdentifier": msg.kinesis.sequence_number for msg in self.fail_messages} + failures = [] + for msg in self.fail_messages: + msg_id = msg.kinesis.sequenceNumber if self.model else msg.kinesis.sequence_number + failures.append({"itemIdentifier": msg_id}) + return failures def _collect_dynamodb_failures(self): - if self.model: - return {"itemIdentifier": msg.dynamodb.SequenceNumber for msg in self.fail_messages} - return {"itemIdentifier": msg.dynamodb.sequence_number for msg in self.fail_messages} + failures = [] + for msg in self.fail_messages: + msg_id = msg.dynamodb.SequenceNumber if self.model else msg.dynamodb.sequence_number + failures.append({"itemIdentifier": msg_id}) + return failures @overload def _to_batch_type(self, record: dict, event_type: EventType, model: "BatchTypeModels") -> "BatchTypeModels": diff --git a/tests/functional/test_utilities_batch.py b/tests/functional/test_utilities_batch.py index 3728af3111d..d32a044279b 100644 --- a/tests/functional/test_utilities_batch.py +++ b/tests/functional/test_utilities_batch.py @@ -414,7 +414,8 @@ def test_batch_processor_middleware_with_failure(sqs_event_factory, record_handl # GIVEN first_record = SQSRecord(sqs_event_factory("fail")) second_record = SQSRecord(sqs_event_factory("success")) - event = {"Records": [first_record.raw_event, second_record.raw_event]} + third_record = SQSRecord(sqs_event_factory("fail")) + event = {"Records": [first_record.raw_event, second_record.raw_event, third_record.raw_event]} processor = BatchProcessor(event_type=EventType.SQS) @@ -426,7 +427,7 @@ def lambda_handler(event, context): result = lambda_handler(event, {}) # THEN - assert len(result["batchItemFailures"]) == 1 + assert len(result["batchItemFailures"]) == 2 def test_batch_processor_context_success_only(sqs_event_factory, record_handler): @@ -453,7 +454,8 @@ def test_batch_processor_context_with_failure(sqs_event_factory, record_handler) # GIVEN first_record = SQSRecord(sqs_event_factory("failure")) second_record = SQSRecord(sqs_event_factory("success")) - records = [first_record.raw_event, second_record.raw_event] + third_record = SQSRecord(sqs_event_factory("fail")) + records = [first_record.raw_event, second_record.raw_event, third_record.raw_event] processor = BatchProcessor(event_type=EventType.SQS) # WHEN @@ -462,8 +464,10 @@ def test_batch_processor_context_with_failure(sqs_event_factory, record_handler) # THEN assert processed_messages[1] == ("success", second_record.body, second_record.raw_event) - assert len(batch.fail_messages) == 1 - assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record.message_id}]} + assert len(batch.fail_messages) == 2 + assert batch.response() == { + "batchItemFailures": [{"itemIdentifier": first_record.message_id}, {"itemIdentifier": third_record.message_id}] + } def test_batch_processor_kinesis_context_success_only(kinesis_event_factory, kinesis_record_handler): @@ -491,8 +495,9 @@ def test_batch_processor_kinesis_context_with_failure(kinesis_event_factory, kin # GIVEN first_record = KinesisStreamRecord(kinesis_event_factory("failure")) second_record = KinesisStreamRecord(kinesis_event_factory("success")) + third_record = KinesisStreamRecord(kinesis_event_factory("failure")) - records = [first_record.raw_event, second_record.raw_event] + records = [first_record.raw_event, second_record.raw_event, third_record.raw_event] processor = BatchProcessor(event_type=EventType.KinesisDataStreams) # WHEN @@ -501,15 +506,21 @@ def test_batch_processor_kinesis_context_with_failure(kinesis_event_factory, kin # THEN assert processed_messages[1] == ("success", b64_to_str(second_record.kinesis.data), second_record.raw_event) - assert len(batch.fail_messages) == 1 - assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record.kinesis.sequence_number}]} + assert len(batch.fail_messages) == 2 + assert batch.response() == { + "batchItemFailures": [ + {"itemIdentifier": first_record.kinesis.sequence_number}, + {"itemIdentifier": third_record.kinesis.sequence_number}, + ] + } def test_batch_processor_kinesis_middleware_with_failure(kinesis_event_factory, kinesis_record_handler): # GIVEN first_record = KinesisStreamRecord(kinesis_event_factory("failure")) second_record = KinesisStreamRecord(kinesis_event_factory("success")) - event = {"Records": [first_record.raw_event, second_record.raw_event]} + third_record = KinesisStreamRecord(kinesis_event_factory("failure")) + event = {"Records": [first_record.raw_event, second_record.raw_event, third_record.raw_event]} processor = BatchProcessor(event_type=EventType.KinesisDataStreams) @@ -521,7 +532,7 @@ def lambda_handler(event, context): result = lambda_handler(event, {}) # THEN - assert len(result["batchItemFailures"]) == 1 + assert len(result["batchItemFailures"]) == 2 def test_batch_processor_dynamodb_context_success_only(dynamodb_event_factory, dynamodb_record_handler): @@ -548,7 +559,8 @@ def test_batch_processor_dynamodb_context_with_failure(dynamodb_event_factory, d # GIVEN first_record = dynamodb_event_factory("failure") second_record = dynamodb_event_factory("success") - records = [first_record, second_record] + third_record = dynamodb_event_factory("failure") + records = [first_record, second_record, third_record] processor = BatchProcessor(event_type=EventType.DynamoDBStreams) # WHEN @@ -557,15 +569,21 @@ def test_batch_processor_dynamodb_context_with_failure(dynamodb_event_factory, d # THEN assert processed_messages[1] == ("success", second_record["dynamodb"]["NewImage"]["Message"]["S"], second_record) - assert len(batch.fail_messages) == 1 - assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["dynamodb"]["SequenceNumber"]}]} + assert len(batch.fail_messages) == 2 + assert batch.response() == { + "batchItemFailures": [ + {"itemIdentifier": first_record["dynamodb"]["SequenceNumber"]}, + {"itemIdentifier": third_record["dynamodb"]["SequenceNumber"]}, + ] + } def test_batch_processor_dynamodb_middleware_with_failure(dynamodb_event_factory, dynamodb_record_handler): # GIVEN first_record = dynamodb_event_factory("failure") second_record = dynamodb_event_factory("success") - event = {"Records": [first_record, second_record]} + third_record = dynamodb_event_factory("failure") + event = {"Records": [first_record, second_record, third_record]} processor = BatchProcessor(event_type=EventType.DynamoDBStreams) @@ -577,7 +595,7 @@ def lambda_handler(event, context): result = lambda_handler(event, {}) # THEN - assert len(result["batchItemFailures"]) == 1 + assert len(result["batchItemFailures"]) == 2 def test_batch_processor_context_model(sqs_event_factory, order_event_factory): @@ -639,8 +657,9 @@ def record_handler(record: OrderSqs): order_event = order_event_factory({"type": "success"}) order_event_fail = order_event_factory({"type": "fail"}) first_record = sqs_event_factory(order_event_fail) + third_record = sqs_event_factory(order_event_fail) second_record = sqs_event_factory(order_event) - records = [first_record, second_record] + records = [first_record, second_record, third_record] # WHEN processor = BatchProcessor(event_type=EventType.SQS, model=OrderSqs) @@ -648,8 +667,13 @@ def record_handler(record: OrderSqs): batch.process() # THEN - assert len(batch.fail_messages) == 1 - assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["messageId"]}]} + assert len(batch.fail_messages) == 2 + assert batch.response() == { + "batchItemFailures": [ + {"itemIdentifier": first_record["messageId"]}, + {"itemIdentifier": third_record["messageId"]}, + ] + } def test_batch_processor_dynamodb_context_model(dynamodb_event_factory, order_event_factory): @@ -726,7 +750,8 @@ def record_handler(record: OrderDynamoDBRecord): order_event_fail = order_event_factory({"type": "fail"}) first_record = dynamodb_event_factory(order_event_fail) second_record = dynamodb_event_factory(order_event) - records = [first_record, second_record] + third_record = dynamodb_event_factory(order_event_fail) + records = [first_record, second_record, third_record] # WHEN processor = BatchProcessor(event_type=EventType.DynamoDBStreams, model=OrderDynamoDBRecord) @@ -734,8 +759,13 @@ def record_handler(record: OrderDynamoDBRecord): batch.process() # THEN - assert len(batch.fail_messages) == 1 - assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["dynamodb"]["SequenceNumber"]}]} + assert len(batch.fail_messages) == 2 + assert batch.response() == { + "batchItemFailures": [ + {"itemIdentifier": first_record["dynamodb"]["SequenceNumber"]}, + {"itemIdentifier": third_record["dynamodb"]["SequenceNumber"]}, + ] + } def test_batch_processor_kinesis_context_parser_model(kinesis_event_factory, order_event_factory): @@ -807,7 +837,8 @@ def record_handler(record: OrderKinesisRecord): first_record = kinesis_event_factory(order_event_fail) second_record = kinesis_event_factory(order_event) - records = [first_record, second_record] + third_record = kinesis_event_factory(order_event_fail) + records = [first_record, second_record, third_record] # WHEN processor = BatchProcessor(event_type=EventType.KinesisDataStreams, model=OrderKinesisRecord) @@ -815,8 +846,13 @@ def record_handler(record: OrderKinesisRecord): batch.process() # THEN - assert len(batch.fail_messages) == 1 - assert batch.response() == {"batchItemFailures": [{"itemIdentifier": first_record["kinesis"]["sequenceNumber"]}]} + assert len(batch.fail_messages) == 2 + assert batch.response() == { + "batchItemFailures": [ + {"itemIdentifier": first_record["kinesis"]["sequenceNumber"]}, + {"itemIdentifier": third_record["kinesis"]["sequenceNumber"]}, + ] + } def test_batch_processor_error_when_entire_batch_fails(sqs_event_factory, record_handler):