diff --git a/fastqueue/services.py b/fastqueue/services.py index cb67a3f..d0d2f6b 100644 --- a/fastqueue/services.py +++ b/fastqueue/services.py @@ -175,6 +175,23 @@ def stats(cls, id: str, session: Session) -> QueueStatsSchema: oldest_unacked_message_age_seconds=oldest_unacked_message_age_seconds, ) + @classmethod + def cleanup(cls, id: str, session: Session) -> None: + queue = cls.get(id=id, session=session) + now = datetime.utcnow() + + expired_at_filter = [Message.queue_id == queue.id, Message.expired_at <= now] + session.query(Message).filter(*expired_at_filter).delete() + + if queue.message_max_deliveries is not None: + delivery_attempts_filter = [ + Message.queue_id == queue.id, + Message.delivery_attempts >= queue.message_max_deliveries, + ] + session.query(Message).filter(*delivery_attempts_filter).delete() + + session.commit() + class MessageService: @classmethod diff --git a/tests/test_services.py b/tests/test_services.py index 2f72110..1ee665b 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -172,6 +172,35 @@ def test_queue_service_stats(session, queue): assert result.oldest_unacked_message_age_seconds == 10 +def test_queue_service_cleanup_expired_at(session, queue): + message1 = MessageFactory(queue_id=queue.id, expired_at=datetime.utcnow() - timedelta(seconds=1)) + message2 = MessageFactory(queue_id=queue.id, expired_at=datetime.utcnow() + timedelta(seconds=1)) + session.add(message1) + session.add(message2) + session.commit() + assert session.query(Message).filter_by(queue_id=queue.id).count() == 2 + + assert QueueService.cleanup(id=queue.id, session=session) is None + assert session.query(Message).filter_by(queue_id=queue.id).count() == 1 + assert session.query(Message).filter_by(queue_id=queue.id).first() == message2 + + +def test_queue_service_cleanup_delivery_attempts(session, queue): + queue.message_max_deliveries = 2 + message1 = MessageFactory(queue_id=queue.id, delivery_attempts=1) + message2 = MessageFactory(queue_id=queue.id, delivery_attempts=2) + message3 = MessageFactory(queue_id=queue.id, delivery_attempts=3) + session.add(message1) + session.add(message2) + session.add(message3) + session.commit() + assert session.query(Message).filter_by(queue_id=queue.id).count() == 3 + + assert QueueService.cleanup(id=queue.id, session=session) is None + assert session.query(Message).filter_by(queue_id=queue.id).count() == 1 + assert session.query(Message).filter_by(queue_id=queue.id).first() == message1 + + @pytest.mark.parametrize( "queue_filters,message_attributes,expected", [