From 8e32aa0cd78d9a55d271c0af8e4e17cd3e77e7d2 Mon Sep 17 00:00:00 2001 From: Allisson Azevedo Date: Thu, 15 Sep 2022 18:46:53 -0300 Subject: [PATCH] feat: add dead queue support (#18) --- alembic/versions/001_initial.py | 11 ++++++--- fastqueue/models.py | 3 +++ fastqueue/schemas.py | 3 +++ fastqueue/services.py | 42 ++++++++++++++++++++++++++++++--- tests/factories.py | 1 + tests/test_services.py | 39 ++++++++++++++++++++++++++++++ 6 files changed, 93 insertions(+), 6 deletions(-) diff --git a/alembic/versions/001_initial.py b/alembic/versions/001_initial.py index f5b996b..2731357 100644 --- a/alembic/versions/001_initial.py +++ b/alembic/versions/001_initial.py @@ -1,8 +1,8 @@ """Auto generated -Revision ID: b290fd728c17 +Revision ID: fe94d60449f2 Revises: -Create Date: 2022-09-06 21:44:14.941396 +Create Date: 2022-09-15 17:51:40.259256 """ import sqlalchemy as sa @@ -11,7 +11,7 @@ from alembic import op # revision identifiers, used by Alembic. -revision = "b290fd728c17" +revision = "fe94d60449f2" down_revision = None branch_labels = None depends_on = None @@ -29,12 +29,17 @@ def upgrade() -> None: "queues", sa.Column("id", sa.String(length=128), nullable=False), sa.Column("topic_id", sa.String(length=128), nullable=True), + sa.Column("dead_queue_id", sa.String(length=128), nullable=True), sa.Column("ack_deadline_seconds", sa.Integer(), nullable=False), sa.Column("message_retention_seconds", sa.Integer(), nullable=False), sa.Column("message_filters", postgresql.JSONB(astext_type=sa.Text()), nullable=True), sa.Column("message_max_deliveries", sa.Integer(), nullable=True), sa.Column("created_at", sa.DateTime(), nullable=False), sa.Column("updated_at", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["dead_queue_id"], + ["queues.id"], + ), sa.ForeignKeyConstraint( ["topic_id"], ["topics.id"], diff --git a/fastqueue/models.py b/fastqueue/models.py index 861abc8..a236d6a 100644 --- a/fastqueue/models.py +++ b/fastqueue/models.py @@ -21,6 +21,9 @@ class Queue(Base): topic_id = sqlalchemy.Column( sqlalchemy.String(length=128), sqlalchemy.ForeignKey("topics.id"), index=True, nullable=True ) + dead_queue_id = sqlalchemy.Column( + sqlalchemy.String(length=128), sqlalchemy.ForeignKey("queues.id"), nullable=True + ) ack_deadline_seconds = sqlalchemy.Column(sqlalchemy.Integer, nullable=False) message_retention_seconds = sqlalchemy.Column(sqlalchemy.Integer, nullable=False) message_filters = sqlalchemy.Column(postgresql.JSONB, nullable=True) diff --git a/fastqueue/schemas.py b/fastqueue/schemas.py index b509582..647264b 100644 --- a/fastqueue/schemas.py +++ b/fastqueue/schemas.py @@ -32,6 +32,7 @@ class ListTopicSchema(Schema): class CreateQueueSchema(Schema): id: str = Field(..., regex=regex_for_id, max_length=128) topic_id: str | None = Field(None, regex=regex_for_id, max_length=128) + dead_queue_id: str | None = Field(None, regex=regex_for_id, max_length=128) ack_deadline_seconds: int = Field( ..., ge=settings.min_ack_deadline_seconds, le=settings.max_ack_deadline_seconds ) @@ -46,6 +47,7 @@ class CreateQueueSchema(Schema): class UpdateQueueSchema(Schema): topic_id: str | None = Field(None, regex=regex_for_id, max_length=128) + dead_queue_id: str | None = Field(None, regex=regex_for_id, max_length=128) ack_deadline_seconds: int = Field( ..., ge=settings.min_ack_deadline_seconds, le=settings.max_ack_deadline_seconds ) @@ -61,6 +63,7 @@ class UpdateQueueSchema(Schema): class QueueSchema(Schema): id: str topic_id: str | None + dead_queue_id: str | None ack_deadline_seconds: int message_retention_seconds: int message_filters: dict[str, list[str]] | None diff --git a/fastqueue/services.py b/fastqueue/services.py index 243937b..45db407 100644 --- a/fastqueue/services.py +++ b/fastqueue/services.py @@ -96,10 +96,14 @@ def create(cls, data: CreateQueueSchema, session: Session) -> QueueSchema: if data.topic_id is not None: TopicService.get(data.topic_id, session=session) + if data.dead_queue_id is not None: + cls.get(data.dead_queue_id, session=session) + now = datetime.utcnow() queue = Queue( id=data.id, topic_id=data.topic_id, + dead_queue_id=data.dead_queue_id, ack_deadline_seconds=data.ack_deadline_seconds, message_retention_seconds=data.message_retention_seconds, message_filters=data.message_filters, @@ -123,7 +127,11 @@ def update(cls, id: str, data: UpdateQueueSchema, session: Session) -> QueueSche if data.topic_id is not None: TopicService.get(data.topic_id, session=session) + if data.dead_queue_id is not None: + cls.get(data.dead_queue_id, session=session) + queue.topic_id = data.topic_id + queue.dead_queue_id = data.dead_queue_id queue.ack_deadline_seconds = data.ack_deadline_seconds queue.message_retention_seconds = data.message_retention_seconds queue.message_filters = data.message_filters @@ -153,6 +161,7 @@ def list( def delete(cls, id: str, session: Session) -> None: cls.get(id, session=session) session.query(Message).filter_by(queue_id=id).delete() + session.query(Queue).filter_by(dead_queue_id=id).update({"dead_queue_id": None}) session.query(Queue).filter_by(id=id).delete() session.commit() @@ -176,20 +185,47 @@ def stats(cls, id: str, session: Session) -> QueueStatsSchema: ) @classmethod - def cleanup(cls, id: str, session: Session) -> None: - queue = cls.get(id=id, session=session) + def _cleanup_expired_messages(cls, queue: QueueSchema, session: Session) -> None: now = datetime.utcnow() expired_at_filter = [Message.queue_id == queue.id, Message.expired_at <= now] session.query(Message).filter(*expired_at_filter).delete() - if queue.message_max_deliveries is not None: + @classmethod + def _cleanup_delivery_attempts_exceeded_messages(cls, queue: QueueSchema, session: Session) -> None: + if queue.message_max_deliveries is not None and queue.dead_queue_id is None: delivery_attempts_filter = [ Message.queue_id == queue.id, Message.delivery_attempts >= queue.message_max_deliveries, ] session.query(Message).filter(*delivery_attempts_filter).delete() + @classmethod + def _cleanup_move_messages_to_dead_queue(cls, queue: QueueSchema, session: Session) -> None: + if queue.message_max_deliveries is not None and queue.dead_queue_id is not None: + dead_queue = cls.get(id=queue.dead_queue_id, session=session) + delivery_attempts_filter = [ + Message.queue_id == queue.id, + Message.delivery_attempts >= queue.message_max_deliveries, + ] + now = datetime.utcnow() + update_data = { + "queue_id": queue.dead_queue_id, + "delivery_attempts": 0, + "expired_at": now + timedelta(seconds=dead_queue.message_retention_seconds), + "scheduled_at": now, + "updated_at": now, + } + session.query(Message).filter(*delivery_attempts_filter).update(update_data) + + @classmethod + def cleanup(cls, id: str, session: Session) -> None: + queue = cls.get(id=id, session=session) + + cls._cleanup_expired_messages(queue=queue, session=session) + cls._cleanup_delivery_attempts_exceeded_messages(queue=queue, session=session) + cls._cleanup_move_messages_to_dead_queue(queue=queue, session=session) + session.commit() diff --git a/tests/factories.py b/tests/factories.py index a642312..ea157d3 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -25,6 +25,7 @@ class Meta: model = Queue id = factory.Sequence(lambda n: "queue_%s" % n) + dead_queue_id = None ack_deadline_seconds = default_ack_deadline_seconds message_retention_seconds = default_message_retention_seconds message_max_deliveries = None diff --git a/tests/test_services.py b/tests/test_services.py index a78fe3f..98e91e9 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -154,6 +154,19 @@ def test_queue_service_delete(session, queue): assert session.query(Message).filter_by(queue_id=queue.id).count() == 0 +def test_queue_service_delete_dead_queue(session, queue): + dead_queue = QueueFactory() + session.add(dead_queue) + session.commit() + queue.dead_queue_id = dead_queue.id + session.commit() + + assert QueueService.delete(dead_queue.id, session=session) is None + + session.refresh(queue) + assert queue.dead_queue_id is None + + def test_queue_service_delete_not_found(session): with pytest.raises(NotFoundError): QueueService.delete("invalid-queue-name", session=session) @@ -201,6 +214,32 @@ def test_queue_service_cleanup_delivery_attempts(session, queue): assert session.query(Message).filter_by(queue_id=queue.id).first() == message1 +def test_queue_service_cleanup_move_to_dead_queue(session, queue): + dead_queue = QueueFactory() + session.add(dead_queue) + session.commit() + queue.message_max_deliveries = 2 + queue.dead_queue_id = dead_queue.id + message1 = MessageFactory(queue_id=queue.id, delivery_attempts=1) + message2 = MessageFactory(queue_id=queue.id, delivery_attempts=2) + message3 = MessageFactory(queue_id=queue.id, delivery_attempts=3) + session.add(message1) + session.add(message2) + session.add(message3) + session.commit() + assert session.query(Message).filter_by(queue_id=queue.id).count() == 3 + + assert QueueService.cleanup(id=queue.id, session=session) is None + assert session.query(Message).filter_by(queue_id=queue.id).count() == 1 + assert session.query(Message).filter_by(queue_id=queue.id).first() == message1 + + assert session.query(Message).filter_by(queue_id=dead_queue.id).count() == 2 + session.refresh(message2) + session.refresh(message3) + assert message2.delivery_attempts == 0 + assert message3.delivery_attempts == 0 + + @pytest.mark.parametrize( "queue_filters,message_attributes,expected", [