From 6f63e75afcf3701c115d4d5c40fdede72182e6ae Mon Sep 17 00:00:00 2001 From: Jules Robichaud-Gagnon Date: Sun, 15 Oct 2023 21:54:16 -0400 Subject: [PATCH 1/4] Add task_routing_key and task_properties to modify_context_before_task_publish --- django_structlog/celery/receivers.py | 14 ++++++++++++-- test_app/tests/celery/test_receivers.py | 24 ++++++++++++++++++++++-- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/django_structlog/celery/receivers.py b/django_structlog/celery/receivers.py index 017a7941..32cc2e32 100644 --- a/django_structlog/celery/receivers.py +++ b/django_structlog/celery/receivers.py @@ -6,7 +6,14 @@ logger = structlog.getLogger(__name__) -def receiver_before_task_publish(sender=None, headers=None, body=None, **kwargs): +def receiver_before_task_publish( + sender=None, + headers=None, + body=None, + properties=None, + routing_key=None, + **kwargs, +): import celery if celery.current_app.conf.task_protocol < 2: @@ -17,7 +24,10 @@ def receiver_before_task_publish(sender=None, headers=None, body=None, **kwargs) context["parent_task_id"] = context.pop("task_id") signals.modify_context_before_task_publish.send( - sender=receiver_before_task_publish, context=context + sender=receiver_before_task_publish, + context=context, + task_routing_key=routing_key, + task_properties=properties, ) headers["__django_structlog__"] = context diff --git a/test_app/tests/celery/test_receivers.py b/test_app/tests/celery/test_receivers.py index 3d8dfb79..c18cb6dd 100644 --- a/test_app/tests/celery/test_receivers.py +++ b/test_app/tests/celery/test_receivers.py @@ -102,9 +102,16 @@ def test_signal_modify_context_before_task_publish_celery_protocol_v2(self): expected_uuid = "00000000-0000-0000-0000-000000000000" user_id = "1234" expected_parent_task_uuid = "11111111-1111-1111-1111-111111111111" + routing_key = "foo" + properties = {"correlation_id": "22222222-2222-2222-2222-222222222222"} + + received_properties = None + received_routing_key = None @receiver(signals.modify_context_before_task_publish) - def receiver_modify_context_before_task_publish(sender, signal, context): + def receiver_modify_context_before_task_publish( + sender, signal, context, task_properties, task_routing_key, **kwargs + ): keys_to_keep = {"request_id", "parent_task_id"} new_dict = { key_to_keep: context[key_to_keep] @@ -113,6 +120,10 @@ def receiver_modify_context_before_task_publish(sender, signal, context): } context.clear() context.update(new_dict) + nonlocal received_properties + received_properties = task_properties + nonlocal received_routing_key + received_routing_key = task_routing_key headers = {} structlog.contextvars.bind_contextvars( @@ -120,7 +131,11 @@ def receiver_modify_context_before_task_publish(sender, signal, context): user_id=user_id, task_id=expected_parent_task_uuid, ) - receivers.receiver_before_task_publish(headers=headers) + receivers.receiver_before_task_publish( + headers=headers, + routing_key=routing_key, + properties=properties, + ) self.assertDictEqual( { @@ -132,6 +147,11 @@ def receiver_modify_context_before_task_publish(sender, signal, context): headers, "Only `request_id` and `parent_task_id` are preserved", ) + self.assertDictEqual( + {"correlation_id": "22222222-2222-2222-2222-222222222222"}, + received_properties, + ) + self.assertEqual("foo", received_routing_key) def test_receiver_after_task_publish(self): expected_task_id = "00000000-0000-0000-0000-000000000000" From d7afbcd5c5cadff483dcc56cd63ca9bfff548bfa Mon Sep 17 00:00:00 2001 From: Jules Robichaud-Gagnon Date: Tue, 17 Oct 2023 20:03:48 -0400 Subject: [PATCH 2/4] Refactor receivers to use classes instead in order to use members --- django_structlog/apps.py | 10 +- django_structlog/celery/receivers.py | 232 +++++++++++++----------- django_structlog/celery/steps.py | 27 +-- django_structlog/commands.py | 59 +++--- test_app/tests/celery/test_receivers.py | 130 ++++++++----- test_app/tests/celery/test_steps.py | 46 +---- test_app/tests/test_apps.py | 59 +++++- 7 files changed, 311 insertions(+), 252 deletions(-) diff --git a/django_structlog/apps.py b/django_structlog/apps.py index 50f29584..726205a6 100644 --- a/django_structlog/apps.py +++ b/django_structlog/apps.py @@ -8,11 +8,13 @@ class DjangoStructLogConfig(AppConfig): def ready(self): if app_settings.CELERY_ENABLED: - from .celery.receivers import connect_celery_signals + from .celery.receivers import CeleryReceiver - connect_celery_signals() + self._celery_receiver = CeleryReceiver() + self._celery_receiver.connect_signals() if app_settings.COMMAND_LOGGING_ENABLED: - from .commands import init_command_signals + from .commands import DjangoCommandReceiver - init_command_signals() + self._django_command_receiver = DjangoCommandReceiver() + self._django_command_receiver.connect_signals() diff --git a/django_structlog/celery/receivers.py b/django_structlog/celery/receivers.py index 32cc2e32..706177b8 100644 --- a/django_structlog/celery/receivers.py +++ b/django_structlog/celery/receivers.py @@ -6,117 +6,143 @@ logger = structlog.getLogger(__name__) -def receiver_before_task_publish( - sender=None, - headers=None, - body=None, - properties=None, - routing_key=None, - **kwargs, -): - import celery - - if celery.current_app.conf.task_protocol < 2: - return - - context = structlog.contextvars.get_merged_contextvars(logger) - if "task_id" in context: - context["parent_task_id"] = context.pop("task_id") - - signals.modify_context_before_task_publish.send( - sender=receiver_before_task_publish, - context=context, - task_routing_key=routing_key, - task_properties=properties, - ) - - headers["__django_structlog__"] = context - - -def receiver_after_task_publish(sender=None, headers=None, body=None, **kwargs): - logger.info( - "task_enqueued", - child_task_id=headers.get("id") if headers else body.get("id"), - child_task_name=headers.get("task") if headers else body.get("task"), - ) - - -def receiver_task_pre_run(task_id, task, *args, **kwargs): - structlog.contextvars.clear_contextvars() - structlog.contextvars.bind_contextvars(task_id=task_id) - metadata = getattr(task.request, "__django_structlog__", {}) - structlog.contextvars.bind_contextvars(**metadata) - signals.bind_extra_task_metadata.send( - sender=receiver_task_pre_run, task=task, logger=logger - ) - logger.info("task_started", task=task.name) - - -def receiver_task_retry(request=None, reason=None, einfo=None, **kwargs): - logger.warning("task_retrying", reason=reason) - - -def receiver_task_success(result=None, **kwargs): - signals.pre_task_succeeded.send( - sender=receiver_task_success, logger=logger, result=result - ) - logger.info("task_succeeded") - - -def receiver_task_failure( - task_id=None, - exception=None, - traceback=None, - einfo=None, - sender=None, - *args, - **kwargs, -): - throws = getattr(sender, "throws", ()) - if isinstance(exception, throws): - logger.info( - "task_failed", - error=str(exception), - ) - else: - logger.exception( - "task_failed", - error=str(exception), - exception=exception, +class CeleryReceiver: + def receiver_before_task_publish( + self, + sender=None, + headers=None, + body=None, + properties=None, + routing_key=None, + **kwargs, + ): + import celery + + if celery.current_app.conf.task_protocol < 2: + return + + context = structlog.contextvars.get_merged_contextvars(logger) + if "task_id" in context: + context["parent_task_id"] = context.pop("task_id") + + signals.modify_context_before_task_publish.send( + sender=self.receiver_before_task_publish, + context=context, + task_routing_key=routing_key, + task_properties=properties, ) + headers["__django_structlog__"] = context -def receiver_task_revoked( - request=None, terminated=None, signum=None, expired=None, **kwargs -): - metadata = getattr(request, "__django_structlog__", {}).copy() - metadata["task_id"] = request.id - metadata["task"] = request.task + def receiver_after_task_publish( + self, sender=None, headers=None, body=None, **kwargs + ): + logger.info( + "task_enqueued", + child_task_id=headers.get("id") if headers else body.get("id"), + child_task_name=headers.get("task") if headers else body.get("task"), + ) - logger.warning( - "task_revoked", - terminated=terminated, - signum=signum.value if signum is not None else None, - signame=signum.name if signum is not None else None, - expired=expired, - **metadata, - ) + def receiver_task_prerun(self, task_id, task, *args, **kwargs): + structlog.contextvars.clear_contextvars() + structlog.contextvars.bind_contextvars(task_id=task_id) + metadata = getattr(task.request, "__django_structlog__", {}) + structlog.contextvars.bind_contextvars(**metadata) + signals.bind_extra_task_metadata.send( + sender=self.receiver_task_prerun, task=task, logger=logger + ) + logger.info("task_started", task=task.name) + def receiver_task_retry(self, request=None, reason=None, einfo=None, **kwargs): + logger.warning("task_retrying", reason=reason) -def receiver_task_unknown(message=None, exc=None, name=None, id=None, **kwargs): - logger.error( - "task_not_found", - task=name, - task_id=id, - ) + def receiver_task_success(self, result=None, **kwargs): + signals.pre_task_succeeded.send( + sender=self.receiver_task_success, logger=logger, result=result + ) + logger.info("task_succeeded") + + def receiver_task_failure( + self, + task_id=None, + exception=None, + traceback=None, + einfo=None, + sender=None, + *args, + **kwargs, + ): + throws = getattr(sender, "throws", ()) + if isinstance(exception, throws): + logger.info( + "task_failed", + error=str(exception), + ) + else: + logger.exception( + "task_failed", + error=str(exception), + exception=exception, + ) + + def receiver_task_revoked( + self, request=None, terminated=None, signum=None, expired=None, **kwargs + ): + metadata = getattr(request, "__django_structlog__", {}).copy() + metadata["task_id"] = request.id + metadata["task"] = request.task + + logger.warning( + "task_revoked", + terminated=terminated, + signum=signum.value if signum is not None else None, + signame=signum.name if signum is not None else None, + expired=expired, + **metadata, + ) + def receiver_task_unknown( + self, message=None, exc=None, name=None, id=None, **kwargs + ): + logger.error( + "task_not_found", + task=name, + task_id=id, + ) -def receiver_task_rejected(message=None, exc=None, **kwargs): - logger.exception("task_rejected", task_id=message.properties.get("correlation_id")) + def receiver_task_rejected(self, message=None, exc=None, **kwargs): + logger.exception( + "task_rejected", task_id=message.properties.get("correlation_id") + ) + def connect_signals(self): + from celery.signals import ( + before_task_publish, + after_task_publish, + ) -def connect_celery_signals(): - from celery.signals import before_task_publish, after_task_publish + before_task_publish.connect(self.receiver_before_task_publish) + after_task_publish.connect(self.receiver_after_task_publish) + + def connect_worker_signals(self): + from celery.signals import ( + before_task_publish, + after_task_publish, + task_prerun, + task_retry, + task_success, + task_failure, + task_revoked, + task_unknown, + task_rejected, + ) - before_task_publish.connect(receiver_before_task_publish) - after_task_publish.connect(receiver_after_task_publish) + before_task_publish.connect(self.receiver_before_task_publish) + after_task_publish.connect(self.receiver_after_task_publish) + task_prerun.connect(self.receiver_task_prerun) + task_retry.connect(self.receiver_task_retry) + task_success.connect(self.receiver_task_success) + task_failure.connect(self.receiver_task_failure) + task_revoked.connect(self.receiver_task_revoked) + task_unknown.connect(self.receiver_task_unknown) + task_rejected.connect(self.receiver_task_rejected) diff --git a/django_structlog/celery/steps.py b/django_structlog/celery/steps.py index 85914b51..9deb4a20 100644 --- a/django_structlog/celery/steps.py +++ b/django_structlog/celery/steps.py @@ -1,6 +1,6 @@ from celery import bootsteps -from . import receivers +from .receivers import CeleryReceiver class DjangoStructLogInitStep(bootsteps.Step): @@ -16,26 +16,5 @@ class DjangoStructLogInitStep(bootsteps.Step): def __init__(self, parent, **kwargs): super().__init__(parent, **kwargs) - import celery - from celery.signals import ( - before_task_publish, - after_task_publish, - task_prerun, - task_retry, - task_success, - task_failure, - task_revoked, - ) - - before_task_publish.connect(receivers.receiver_before_task_publish) - after_task_publish.connect(receivers.receiver_after_task_publish) - task_prerun.connect(receivers.receiver_task_pre_run) - task_retry.connect(receivers.receiver_task_retry) - task_success.connect(receivers.receiver_task_success) - task_failure.connect(receivers.receiver_task_failure) - task_revoked.connect(receivers.receiver_task_revoked) - if celery.VERSION > (4,): - from celery.signals import task_unknown, task_rejected - - task_unknown.connect(receivers.receiver_task_unknown) - task_rejected.connect(receivers.receiver_task_rejected) + self.receiver = CeleryReceiver() + self.receiver.connect_worker_signals() diff --git a/django_structlog/commands.py b/django_structlog/commands.py index 4a3096eb..d1d1e966 100644 --- a/django_structlog/commands.py +++ b/django_structlog/commands.py @@ -2,39 +2,40 @@ import uuid logger = structlog.getLogger(__name__) -stack = [] -def pre_receiver(sender, *args, **kwargs): - command_id = str(uuid.uuid4()) - if len(stack): - parent_command_id, _ = stack[-1] - tokens = structlog.contextvars.bind_contextvars( - parent_command_id=parent_command_id, command_id=command_id +class DjangoCommandReceiver: + def __init__(self): + self.stack = [] + + def pre_receiver(self, sender, *args, **kwargs): + command_id = str(uuid.uuid4()) + if len(self.stack): + parent_command_id, _ = self.stack[-1] + tokens = structlog.contextvars.bind_contextvars( + parent_command_id=parent_command_id, command_id=command_id + ) + else: + tokens = structlog.contextvars.bind_contextvars(command_id=command_id) + self.stack.append((command_id, tokens)) + + logger.info( + "command_started", + command_name=sender.__module__.replace(".management.commands", ""), ) - else: - tokens = structlog.contextvars.bind_contextvars(command_id=command_id) - stack.append((command_id, tokens)) - logger.info( - "command_started", - command_name=sender.__module__.replace(".management.commands", ""), - ) + def post_receiver(self, sender, outcome, *args, **kwargs): + logger.info("command_finished") + if len(self.stack): # pragma: no branch + command_id, tokens = self.stack.pop() + structlog.contextvars.reset_contextvars(**tokens) -def post_receiver(sender, outcome, *args, **kwargs): - logger.info("command_finished") + def connect_signals(self): + try: + from django_extensions.management.signals import pre_command, post_command + except ModuleNotFoundError: # pragma: no cover + return - if len(stack): - command_id, tokens = stack.pop() - structlog.contextvars.reset_contextvars(**tokens) - - -def init_command_signals(): - try: - from django_extensions.management.signals import pre_command, post_command - except ModuleNotFoundError: # pragma: no cover - return - - pre_command.connect(pre_receiver) - post_command.connect(post_receiver) + pre_command.connect(self.pre_receiver) + post_command.connect(self.post_receiver) diff --git a/test_app/tests/celery/test_receivers.py b/test_app/tests/celery/test_receivers.py index c18cb6dd..69b528b1 100644 --- a/test_app/tests/celery/test_receivers.py +++ b/test_app/tests/celery/test_receivers.py @@ -5,7 +5,7 @@ import structlog from celery import shared_task from django.contrib.auth.models import AnonymousUser -from django.dispatch import receiver +from django.dispatch import receiver as django_receiver from django.test import TestCase, RequestFactory from django_structlog.celery import receivers, signals @@ -29,20 +29,13 @@ def test_defer_task(self): def test_task(value): # pragma: no cover pass - from celery.signals import before_task_publish, after_task_publish - - before_task_publish.connect(receivers.receiver_before_task_publish) - after_task_publish.connect(receivers.receiver_after_task_publish) - try: - structlog.contextvars.bind_contextvars(request_id=expected_uuid) - with self.assertLogs( - logging.getLogger("django_structlog.celery.receivers"), logging.INFO - ) as log_results: - test_task.delay("foo") - finally: - before_task_publish.disconnect(receivers.receiver_before_task_publish) - after_task_publish.disconnect(receivers.receiver_after_task_publish) - + receiver = receivers.CeleryReceiver() + receiver.connect_signals() + structlog.contextvars.bind_contextvars(request_id=expected_uuid) + with self.assertLogs( + logging.getLogger("django_structlog.celery.receivers"), logging.INFO + ) as log_results: + test_task.delay("foo") self.assertEqual(1, len(log_results.records)) record = log_results.records[0] self.assertEqual("task_enqueued", record.msg["event"]) @@ -61,7 +54,8 @@ def test_receiver_before_task_publish_celery_protocol_v2(self): user_id=expected_user_id, task_id=expected_parent_task_uuid, ) - receivers.receiver_before_task_publish(headers=headers) + receiver = receivers.CeleryReceiver() + receiver.receiver_before_task_publish(headers=headers) self.assertDictEqual( { @@ -90,8 +84,9 @@ def test_receiver_before_task_publish_celery_protocol_v1(self): mock_conf = MagicMock() mock_current_app.conf = mock_conf mock_conf.task_protocol = 1 + receiver = receivers.CeleryReceiver() with patch("celery.current_app", mock_current_app): - receivers.receiver_before_task_publish(headers=headers) + receiver.receiver_before_task_publish(headers=headers) self.assertDictEqual( {}, @@ -108,7 +103,7 @@ def test_signal_modify_context_before_task_publish_celery_protocol_v2(self): received_properties = None received_routing_key = None - @receiver(signals.modify_context_before_task_publish) + @django_receiver(signals.modify_context_before_task_publish) def receiver_modify_context_before_task_publish( sender, signal, context, task_properties, task_routing_key, **kwargs ): @@ -131,7 +126,8 @@ def receiver_modify_context_before_task_publish( user_id=user_id, task_id=expected_parent_task_uuid, ) - receivers.receiver_before_task_publish( + receiver = receivers.CeleryReceiver() + receiver.receiver_before_task_publish( headers=headers, routing_key=routing_key, properties=properties, @@ -157,11 +153,11 @@ def test_receiver_after_task_publish(self): expected_task_id = "00000000-0000-0000-0000-000000000000" expected_task_name = "Foo" headers = {"id": expected_task_id, "task": expected_task_name} - + receiver = receivers.CeleryReceiver() with self.assertLogs( logging.getLogger("django_structlog.celery.receivers"), logging.INFO ) as log_results: - receivers.receiver_after_task_publish(headers=headers) + receiver.receiver_after_task_publish(headers=headers) self.assertEqual(1, len(log_results.records)) record = log_results.records[0] @@ -177,10 +173,11 @@ def test_receiver_after_task_publish_celery_3(self): expected_task_name = "Foo" body = {"id": expected_task_id, "task": expected_task_name} + receiver = receivers.CeleryReceiver() with self.assertLogs( logging.getLogger("django_structlog.celery.receivers"), logging.INFO ) as log_results: - receivers.receiver_after_task_publish(body=body) + receiver.receiver_after_task_publish(body=body) self.assertEqual(1, len(log_results.records)) record = log_results.records[0] @@ -206,11 +203,11 @@ def test_receiver_task_pre_run(self): context = structlog.contextvars.get_merged_contextvars(self.logger) self.assertDictEqual({"foo": "bar"}, context) - + receiver = receivers.CeleryReceiver() with self.assertLogs( logging.getLogger("django_structlog.celery.receivers"), logging.INFO ) as log_results: - receivers.receiver_task_pre_run(task_id, task) + receiver.receiver_task_prerun(task_id, task) context = structlog.contextvars.get_merged_contextvars(self.logger) self.assertDictEqual( @@ -230,7 +227,7 @@ def test_receiver_task_pre_run(self): self.assertEqual("task_name", record.msg["task"]) def test_signal_bind_extra_task_metadata(self): - @receiver(signals.bind_extra_task_metadata) + @django_receiver(signals.bind_extra_task_metadata) def receiver_bind_extra_request_metadata( sender, signal, task=None, logger=None ): @@ -247,8 +244,8 @@ def receiver_bind_extra_request_metadata( context = structlog.contextvars.get_merged_contextvars(self.logger) self.assertDictEqual({"foo": "bar"}, context) - - receivers.receiver_task_pre_run(task_id, task) + receiver = receivers.CeleryReceiver() + receiver.receiver_task_prerun(task_id, task) context = structlog.contextvars.get_merged_contextvars(self.logger) self.assertEqual(context["correlation_id"], expected_correlation_uuid) @@ -257,10 +254,11 @@ def receiver_bind_extra_request_metadata( def test_receiver_task_retry(self): expected_reason = "foo" + receiver = receivers.CeleryReceiver() with self.assertLogs( logging.getLogger("django_structlog.celery.receivers"), logging.WARNING ) as log_results: - receivers.receiver_task_retry(reason=expected_reason) + receiver.receiver_task_retry(reason=expected_reason) self.assertEqual(1, len(log_results.records)) record = log_results.records[0] @@ -272,16 +270,17 @@ def test_receiver_task_retry(self): def test_receiver_task_success(self): expected_result = "foo" - @receiver(signals.pre_task_succeeded) + @django_receiver(signals.pre_task_succeeded) def receiver_pre_task_succeeded( sender, signal, task=None, logger=None, result=None ): structlog.contextvars.bind_contextvars(result=result) + receiver = receivers.CeleryReceiver() with self.assertLogs( logging.getLogger("django_structlog.celery.receivers"), logging.INFO ) as log_results: - receivers.receiver_task_success(result=expected_result) + receiver.receiver_task_success(result=expected_result) self.assertEqual(1, len(log_results.records)) record = log_results.records[0] @@ -292,11 +291,11 @@ def receiver_pre_task_succeeded( def test_receiver_task_failure(self): expected_exception = "foo" - + receiver = receivers.CeleryReceiver() with self.assertLogs( logging.getLogger("django_structlog.celery.receivers"), logging.ERROR ) as log_results: - receivers.receiver_task_failure(exception=Exception("foo")) + receiver.receiver_task_failure(exception=Exception("foo")) self.assertEqual(1, len(log_results.records)) record = log_results.records[0] @@ -311,11 +310,11 @@ def test_receiver_task_failure_with_throws(self): mock_sender = Mock() mock_sender.throws = (Exception,) - + receiver = receivers.CeleryReceiver() with self.assertLogs( logging.getLogger("django_structlog.celery.receivers"), logging.INFO ) as log_results: - receivers.receiver_task_failure( + receiver.receiver_task_failure( exception=Exception("foo"), sender=mock_sender ) @@ -339,10 +338,12 @@ def test_receiver_task_revoked(self): } request.task = expected_task_name request.id = task_id + + receiver = receivers.CeleryReceiver() with self.assertLogs( logging.getLogger("django_structlog.celery.receivers"), logging.WARNING ) as log_results: - receivers.receiver_task_revoked( + receiver.receiver_task_revoked( request=request, terminated=False, signum=None, expired=False ) @@ -379,10 +380,12 @@ def test_receiver_task_revoked_terminated(self): } request.task = expected_task_name request.id = task_id + + receiver = receivers.CeleryReceiver() with self.assertLogs( logging.getLogger("django_structlog.celery.receivers"), logging.WARNING ) as log_results: - receivers.receiver_task_revoked( + receiver.receiver_task_revoked( request=request, terminated=True, signum=SIGTERM, expired=False ) @@ -411,10 +414,11 @@ def test_receiver_task_unknown(self): task_id = "11111111-1111-1111-1111-111111111111" expected_task_name = "task_name" + receiver = receivers.CeleryReceiver() with self.assertLogs( logging.getLogger("django_structlog.celery.receivers"), logging.ERROR ) as log_results: - receivers.receiver_task_unknown(id=task_id, name=expected_task_name) + receiver.receiver_task_unknown(id=task_id, name=expected_task_name) self.assertEqual(1, len(log_results.records)) record = log_results.records[0] @@ -430,10 +434,11 @@ def test_receiver_task_rejected(self): message = Mock(name="message") message.properties = dict(correlation_id=task_id) + receiver = receivers.CeleryReceiver() with self.assertLogs( logging.getLogger("django_structlog.celery.receivers"), logging.ERROR ) as log_results: - receivers.receiver_task_rejected(message=message) + receiver.receiver_task_rejected(message=message) self.assertEqual(1, len(log_results.records)) record = log_results.records[0] @@ -443,22 +448,57 @@ def test_receiver_task_rejected(self): self.assertEqual(task_id, record.msg["task_id"]) +class TestConnectCeleryTaskSignals(TestCase): + def test_call(self): + from celery.signals import ( + before_task_publish, + after_task_publish, + task_prerun, + task_retry, + task_success, + task_failure, + task_revoked, + task_unknown, + task_rejected, + ) + + from django_structlog.celery.receivers import CeleryReceiver + + receiver = CeleryReceiver() + with patch( + "celery.utils.dispatch.signal.Signal.connect", autospec=True + ) as mock_connect: + receiver.connect_worker_signals() + + mock_connect.assert_has_calls( + [ + call(before_task_publish, receiver.receiver_before_task_publish), + call(after_task_publish, receiver.receiver_after_task_publish), + call(task_prerun, receiver.receiver_task_prerun), + call(task_retry, receiver.receiver_task_retry), + call(task_success, receiver.receiver_task_success), + call(task_failure, receiver.receiver_task_failure), + call(task_revoked, receiver.receiver_task_revoked), + call(task_unknown, receiver.receiver_task_unknown), + call(task_rejected, receiver.receiver_task_rejected), + ] + ) + + class TestConnectCelerySignals(TestCase): def test_call(self): from celery.signals import before_task_publish, after_task_publish - from django_structlog.celery.receivers import ( - receiver_before_task_publish, - receiver_after_task_publish, - ) + from django_structlog.celery.receivers import CeleryReceiver + receiver = CeleryReceiver() with patch( "celery.utils.dispatch.signal.Signal.connect", autospec=True ) as mock_connect: - receivers.connect_celery_signals() + receiver.connect_signals() mock_connect.assert_has_calls( [ - call(before_task_publish, receiver_before_task_publish), - call(after_task_publish, receiver_after_task_publish), + call(before_task_publish, receiver.receiver_before_task_publish), + call(after_task_publish, receiver.receiver_after_task_publish), ] ) diff --git a/test_app/tests/celery/test_steps.py b/test_app/tests/celery/test_steps.py index 5cd36254..d323987a 100644 --- a/test_app/tests/celery/test_steps.py +++ b/test_app/tests/celery/test_steps.py @@ -1,4 +1,4 @@ -from unittest.mock import patch, call +from unittest.mock import patch from django.test import TestCase @@ -7,44 +7,12 @@ class TestDjangoStructLogInitStep(TestCase): def test_call(self): - from celery.signals import ( - before_task_publish, - after_task_publish, - task_prerun, - task_retry, - task_success, - task_failure, - task_revoked, - task_unknown, - task_rejected, - ) - from django_structlog.celery.receivers import ( - receiver_before_task_publish, - receiver_after_task_publish, - receiver_task_pre_run, - receiver_task_retry, - receiver_task_success, - receiver_task_failure, - receiver_task_revoked, - receiver_task_unknown, - receiver_task_rejected, - ) - with patch( - "celery.utils.dispatch.signal.Signal.connect", autospec=True + "django_structlog.celery.receivers.CeleryReceiver.connect_worker_signals", + autospec=True, ) as mock_connect: - steps.DjangoStructLogInitStep(None) + step = steps.DjangoStructLogInitStep(None) + + mock_connect.assert_called_once() - mock_connect.assert_has_calls( - [ - call(before_task_publish, receiver_before_task_publish), - call(after_task_publish, receiver_after_task_publish), - call(task_prerun, receiver_task_pre_run), - call(task_retry, receiver_task_retry), - call(task_success, receiver_task_success), - call(task_failure, receiver_task_failure), - call(task_revoked, receiver_task_revoked), - call(task_unknown, receiver_task_unknown), - call(task_rejected, receiver_task_rejected), - ] - ) + self.assertIsNotNone(step.receiver) diff --git a/test_app/tests/test_apps.py b/test_app/tests/test_apps.py index 03c8e891..5efbe8bc 100644 --- a/test_app/tests/test_apps.py +++ b/test_app/tests/test_apps.py @@ -1,8 +1,9 @@ -from unittest.mock import patch +from unittest.mock import patch, create_autospec from django.test import TestCase -from django_structlog import apps +from django_structlog import apps, commands +from django_structlog.celery import receivers class TestAppConfig(TestCase): @@ -10,20 +11,62 @@ def test_celery_enabled(self): app = apps.DjangoStructLogConfig( "django_structlog", __import__("django_structlog") ) + mock_receiver = create_autospec(spec=receivers.CeleryReceiver) with patch( - "django_structlog.celery.receivers.connect_celery_signals" - ) as mock_connect_celery_signals: + "django_structlog.celery.receivers.CeleryReceiver", + return_value=mock_receiver, + ): with self.settings(DJANGO_STRUCTLOG_CELERY_ENABLED=True): app.ready() - mock_connect_celery_signals.assert_called_once() + mock_receiver.connect_signals.assert_called_once() + + self.assertTrue(hasattr(app, "_celery_receiver")) + self.assertIsNotNone(app._celery_receiver) def test_celery_disabled(self): app = apps.DjangoStructLogConfig( "django_structlog", __import__("django_structlog") ) + + mock_receiver = create_autospec(spec=receivers.CeleryReceiver) with patch( - "django_structlog.celery.receivers.connect_celery_signals" - ) as mock_connect_celery_signals: + "django_structlog.celery.receivers.CeleryReceiver", + return_value=mock_receiver, + ): with self.settings(DJANGO_STRUCTLOG_CELERY_ENABLED=False): app.ready() - mock_connect_celery_signals.assert_not_called() + mock_receiver.connect_signals.assert_not_called() + + self.assertFalse(hasattr(app, "_celery_receiver")) + + def test_command_enabled(self): + app = apps.DjangoStructLogConfig( + "django_structlog", __import__("django_structlog") + ) + mock_receiver = create_autospec(spec=commands.DjangoCommandReceiver) + with patch( + "django_structlog.commands.DjangoCommandReceiver", + return_value=mock_receiver, + ): + with self.settings(DJANGO_STRUCTLOG_COMMAND_LOGGING_ENABLED=True): + app.ready() + mock_receiver.connect_signals.assert_called_once() + + self.assertTrue(hasattr(app, "_django_command_receiver")) + self.assertIsNotNone(app._django_command_receiver) + + def test_command_disabled(self): + app = apps.DjangoStructLogConfig( + "django_structlog", __import__("django_structlog") + ) + + mock_receiver = create_autospec(spec=commands.DjangoCommandReceiver) + with patch( + "django_structlog.commands.DjangoCommandReceiver", + return_value=mock_receiver, + ): + with self.settings(DJANGO_STRUCTLOG_COMMAND_LOGGING_ENABLED=False): + app.ready() + mock_receiver.connect_signals.assert_not_called() + + self.assertFalse(hasattr(app, "_django_command_receiver")) From a59a90ebdfc9d642fe9eef7dbd11fd8b95f44b28 Mon Sep 17 00:00:00 2001 From: Jules Robichaud-Gagnon Date: Tue, 17 Oct 2023 20:34:08 -0400 Subject: [PATCH 3/4] Fix name of test --- test_app/tests/celery/test_receivers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_app/tests/celery/test_receivers.py b/test_app/tests/celery/test_receivers.py index 69b528b1..da84d9f0 100644 --- a/test_app/tests/celery/test_receivers.py +++ b/test_app/tests/celery/test_receivers.py @@ -168,7 +168,7 @@ def test_receiver_after_task_publish(self): self.assertIn("child_task_name", record.msg) self.assertEqual(expected_task_name, record.msg["child_task_name"]) - def test_receiver_after_task_publish_celery_3(self): + def test_receiver_after_task_publish_protocol_v1(self): expected_task_id = "00000000-0000-0000-0000-000000000000" expected_task_name = "Foo" body = {"id": expected_task_id, "task": expected_task_name} From db792828c30fac944e874e3dca06b319d11e09df Mon Sep 17 00:00:00 2001 From: Jules Robichaud-Gagnon Date: Wed, 18 Oct 2023 19:54:07 -0400 Subject: [PATCH 4/4] Add priority to task_enqueued --- django_structlog/celery/receivers.py | 14 ++++- django_structlog/celery/signals.py | 4 +- django_structlog_demo_project/home/views.py | 2 +- docs/celery.rst | 2 +- docs/events.rst | 4 ++ test_app/tests/celery/test_receivers.py | 59 +++++++++++++++++++++ 6 files changed, 81 insertions(+), 4 deletions(-) diff --git a/django_structlog/celery/receivers.py b/django_structlog/celery/receivers.py index 706177b8..43ad50c3 100644 --- a/django_structlog/celery/receivers.py +++ b/django_structlog/celery/receivers.py @@ -7,6 +7,9 @@ class CeleryReceiver: + def __init__(self): + self._priority = None + def receiver_before_task_publish( self, sender=None, @@ -31,16 +34,25 @@ def receiver_before_task_publish( task_routing_key=routing_key, task_properties=properties, ) + if properties: + self._priority = properties.get("priority", None) headers["__django_structlog__"] = context def receiver_after_task_publish( - self, sender=None, headers=None, body=None, **kwargs + self, sender=None, headers=None, body=None, routing_key=None, **kwargs ): + properties = {} + if self._priority is not None: + properties["priority"] = self._priority + self._priority = None + logger.info( "task_enqueued", child_task_id=headers.get("id") if headers else body.get("id"), child_task_name=headers.get("task") if headers else body.get("task"), + routing_key=routing_key, + **properties, ) def receiver_task_prerun(self, task_id, task, *args, **kwargs): diff --git a/django_structlog/celery/signals.py b/django_structlog/celery/signals.py index 5374f413..e647924f 100644 --- a/django_structlog/celery/signals.py +++ b/django_structlog/celery/signals.py @@ -22,12 +22,14 @@ """ Signal to modify context passed over to ``celery`` task's context. You must modify the ``context`` dict. :param context: the context dict that will be passed over to the task runner's logger +:param task_routing_key: routing key of the task +:param task_properties: task's message properties >>> from django.dispatch import receiver >>> from django_structlog.celery import signals >>> >>> @receiver(signals.modify_context_before_task_publish) -... def receiver_modify_context_before_task_publish(sender, signal, context, **kwargs): +... def receiver_modify_context_before_task_publish(sender, signal, context, task_routing_key=None, task_properties=None, **kwargs): ... keys_to_keep = {"request_id", "parent_task_id"} ... new_dict = { ... key_to_keep: context[key_to_keep] diff --git a/django_structlog_demo_project/home/views.py b/django_structlog_demo_project/home/views.py index 448aabdc..6ba3093e 100644 --- a/django_structlog_demo_project/home/views.py +++ b/django_structlog_demo_project/home/views.py @@ -15,7 +15,7 @@ def enqueue_successful_task(request): logger.info("Enqueuing successful task") - successful_task.delay(foo="bar") + successful_task.apply_async(foo="bar", priority=5) return HttpResponse(status=201) diff --git a/docs/celery.rst b/docs/celery.rst index cd94a4fe..8b08fa5f 100644 --- a/docs/celery.rst +++ b/docs/celery.rst @@ -157,7 +157,7 @@ By example you can strip down the ``context`` to keep only some of the keys: .. code-block:: python @receiver(signals.modify_context_before_task_publish) - def receiver_modify_context_before_task_publish(sender, signal, context, **kwargs): + def receiver_modify_context_before_task_publish(sender, signal, context, task_routing_key=None, task_properties=None, **kwargs): keys_to_keep = {"request_id", "parent_task_id"} new_dict = {key_to_keep: context[key_to_keep] for key_to_keep in keys_to_keep if key_to_keep in context} context.clear() diff --git a/docs/events.rst b/docs/events.rst index 29093492..7ee08ff8 100644 --- a/docs/events.rst +++ b/docs/events.rst @@ -122,6 +122,10 @@ These metadata appear once along with their associated event +------------------+------------------+----------------------------------------+ | task_enqueued | child_task_name | name of the task being enqueued | +------------------+------------------+----------------------------------------+ +| task_enqueued | routing_key | task's routing key | ++------------------+------------------+----------------------------------------+ +| task_enqueued | priority | priority of task (if any) | ++------------------+------------------+----------------------------------------+ | task_retrying | reason | reason for retry | +------------------+------------------+----------------------------------------+ | task_started | task | name of the task | diff --git a/test_app/tests/celery/test_receivers.py b/test_app/tests/celery/test_receivers.py index da84d9f0..994513c9 100644 --- a/test_app/tests/celery/test_receivers.py +++ b/test_app/tests/celery/test_receivers.py @@ -447,6 +447,65 @@ def test_receiver_task_rejected(self): self.assertIn("task_id", record.msg) self.assertEqual(task_id, record.msg["task_id"]) + def test_priority(self): + expected_uuid = "00000000-0000-0000-0000-000000000000" + user_id = "1234" + expected_parent_task_uuid = "11111111-1111-1111-1111-111111111111" + expected_routing_key = "foo" + expected_priority = 6 + properties = {"priority": expected_priority} + + headers = {} + structlog.contextvars.bind_contextvars( + request_id=expected_uuid, + user_id=user_id, + task_id=expected_parent_task_uuid, + ) + receiver = receivers.CeleryReceiver() + receiver.receiver_before_task_publish( + headers=headers, + routing_key=expected_routing_key, + properties=properties, + ) + + self.assertDictEqual( + { + "__django_structlog__": { + "user_id": user_id, + "request_id": expected_uuid, + "parent_task_id": expected_parent_task_uuid, + } + }, + headers, + "Only `request_id` and `parent_task_id` are preserved", + ) + + expected_task_id = "00000000-0000-0000-0000-000000000000" + expected_task_name = "Foo" + headers = {"id": expected_task_id, "task": expected_task_name} + + with self.assertLogs( + logging.getLogger("django_structlog.celery.receivers"), logging.INFO + ) as log_results: + receiver.receiver_after_task_publish( + headers=headers, routing_key=expected_routing_key + ) + + self.assertEqual(1, len(log_results.records)) + record = log_results.records[0] + self.assertEqual("task_enqueued", record.msg["event"]) + self.assertEqual("INFO", record.levelname) + self.assertIn("child_task_id", record.msg) + self.assertEqual(expected_task_id, record.msg["child_task_id"]) + self.assertIn("child_task_name", record.msg) + self.assertEqual(expected_task_name, record.msg["child_task_name"]) + + self.assertIn("priority", record.msg) + self.assertEqual(expected_priority, record.msg["priority"]) + + self.assertIn("routing_key", record.msg) + self.assertEqual(expected_routing_key, record.msg["routing_key"]) + class TestConnectCeleryTaskSignals(TestCase): def test_call(self):