diff --git a/sentry_sdk/integrations/celery/__init__.py b/sentry_sdk/integrations/celery/__init__.py index 5b8a90fdb9..88a2119c09 100644 --- a/sentry_sdk/integrations/celery/__init__.py +++ b/sentry_sdk/integrations/celery/__init__.py @@ -41,6 +41,7 @@ try: from celery import VERSION as CELERY_VERSION # type: ignore + from celery.app.task import Task # type: ignore from celery.app.trace import task_has_custom from celery.exceptions import ( # type: ignore Ignore, @@ -83,6 +84,7 @@ def setup_once(): _patch_build_tracer() _patch_task_apply_async() + _patch_celery_send_task() _patch_worker_exit() _patch_producer_publish() @@ -243,7 +245,7 @@ def __exit__(self, exc_type, exc_value, traceback): return None -def _wrap_apply_async(f): +def _wrap_task_run(f): # type: (F) -> F @wraps(f) @ensure_integration_enabled(CeleryIntegration, f) @@ -260,14 +262,19 @@ def apply_async(*args, **kwargs): if not propagate_traces: return f(*args, **kwargs) - task = args[0] + if isinstance(args[0], Task): + task_name = args[0].name # type: str + elif len(args) > 1 and isinstance(args[1], str): + task_name = args[1] + else: + task_name = "" task_started_from_beat = sentry_sdk.get_isolation_scope()._name == "celery-beat" span_mgr = ( sentry_sdk.start_span( op=OP.QUEUE_SUBMIT_CELERY, - description=task.name, + description=task_name, origin=CeleryIntegration.origin, ) if not task_started_from_beat @@ -437,9 +444,14 @@ def sentry_build_tracer(name, task, *args, **kwargs): def _patch_task_apply_async(): # type: () -> None - from celery.app.task import Task # type: ignore + Task.apply_async = _wrap_task_run(Task.apply_async) + + +def _patch_celery_send_task(): + # type: () -> None + from celery import Celery - Task.apply_async = _wrap_apply_async(Task.apply_async) + Celery.send_task = _wrap_task_run(Celery.send_task) def _patch_worker_exit(): diff --git a/tests/integrations/celery/test_celery.py b/tests/integrations/celery/test_celery.py index cc0bfd0390..ffd3f0db62 100644 --- a/tests/integrations/celery/test_celery.py +++ b/tests/integrations/celery/test_celery.py @@ -10,7 +10,7 @@ from sentry_sdk import start_transaction, get_current_span from sentry_sdk.integrations.celery import ( CeleryIntegration, - _wrap_apply_async, + _wrap_task_run, ) from sentry_sdk.integrations.celery.beat import _get_headers from tests.conftest import ApproxDict @@ -568,7 +568,7 @@ def dummy_function(*args, **kwargs): assert "sentry-trace" in headers assert "baggage" in headers - wrapped = _wrap_apply_async(dummy_function) + wrapped = _wrap_task_run(dummy_function) wrapped(mock.MagicMock(), (), headers={}) @@ -783,3 +783,51 @@ def task(): ... assert span["origin"] == "auto.queue.celery" monkeypatch.setattr(kombu.messaging.Producer, "_publish", old_publish) + + +@pytest.mark.forked +@mock.patch("celery.Celery.send_task") +def test_send_task_wrapped( + patched_send_task, + sentry_init, + capture_events, + reset_integrations, +): + sentry_init(integrations=[CeleryIntegration()], enable_tracing=True) + celery = Celery(__name__, broker="redis://example.com") # noqa: E231 + + events = capture_events() + + with sentry_sdk.start_transaction(name="custom_transaction"): + celery.send_task("very_creative_task_name", args=(1, 2), kwargs={"foo": "bar"}) + + (call,) = patched_send_task.call_args_list # We should have exactly one call + (args, kwargs) = call + + assert args == (celery, "very_creative_task_name") + assert kwargs["args"] == (1, 2) + assert kwargs["kwargs"] == {"foo": "bar"} + assert set(kwargs["headers"].keys()) == { + "sentry-task-enqueued-time", + "sentry-trace", + "baggage", + "headers", + } + assert set(kwargs["headers"]["headers"].keys()) == { + "sentry-trace", + "baggage", + "sentry-task-enqueued-time", + } + assert ( + kwargs["headers"]["sentry-trace"] + == kwargs["headers"]["headers"]["sentry-trace"] + ) + + (event,) = events # We should have exactly one event (the transaction) + assert event["type"] == "transaction" + assert event["transaction"] == "custom_transaction" + + (span,) = event["spans"] # We should have exactly one span + assert span["description"] == "very_creative_task_name" + assert span["op"] == "queue.submit.celery" + assert span["trace_id"] == kwargs["headers"]["sentry-trace"].split("-")[0] diff --git a/tox.ini b/tox.ini index fcab3ad1ed..dd1dbf1156 100644 --- a/tox.ini +++ b/tox.ini @@ -371,8 +371,9 @@ deps = celery-v5.4: Celery~=5.4.0 celery-latest: Celery - {py3.7}-celery: importlib-metadata<5.0 {py3.6,py3.7,py3.8,py3.9,py3.10,py3.11,py3.12}-celery: newrelic + celery: pytest<7 + {py3.7}-celery: importlib-metadata<5.0 # Chalice chalice-v1.16: chalice~=1.16.0