Skip to content

Commit

Permalink
Added context propagation support to celery instrumentation
Browse files Browse the repository at this point in the history
  • Loading branch information
owais committed Sep 17, 2020
1 parent b923c52 commit 1f72b1e
Show file tree
Hide file tree
Showing 12 changed files with 227 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## Unreleased

- Span operation names now include the task type.
- Added automatic context propagation.

## Version 0.12b0

Released 2020-08-14
Expand Down
22 changes: 20 additions & 2 deletions instrumentation/opentelemetry-instrumentation-celery/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,20 @@ Usage

.. code-block:: python
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchExportSpanProcessor
from opentelemetry.instrumentation.celery import CeleryInstrumentor
CeleryInstrumentor().instrument()
from celery import Celery
from celery.signals import worker_process_init
@worker_process_init.connect(weak=False)
def init_celery_tracing(*args, **kwargs):
trace.set_tracer_provider(TracerProvider())
span_processor = BatchExportSpanProcessor(ConsoleSpanExporter())
trace.get_tracer_provider().add_span_processor(span_processor)
CeleryInstrumentor().instrument()
app = Celery("tasks", broker="amqp://localhost")
Expand All @@ -43,6 +52,15 @@ Usage
add.delay(42, 50)
Setting up tracing
--------------------

When tracing a celery worker process, tracing and instrumention both must be initialized after the celery worker
process is initialized. This is required for any tracing components that might use threading to work correctly
such as the BatchExportSpanProcessor. Celery provides a signal called ``worker_process_init`` that can be used to
accomplish this as shown in the example above.

References
----------
* `OpenTelemetry Celery Instrumentation <https://opentelemetry-python.readthedocs.io/en/latest/ext/celery/celery.html>`_
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ install_requires =
[options.extras_require]
test =
pytest
celery ~= 4.0
opentelemetry-test == 0.14.dev0

[options.packages.find]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,20 @@
.. code:: python
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchExportSpanProcessor
from opentelemetry.instrumentation.celery import CeleryInstrumentor
CeleryInstrumentor().instrument()
from celery import Celery
from celery.signals import worker_process_init
@worker_process_init.connect(weak=False)
def init_celery_tracing(*args, **kwargs):
trace.set_tracer_provider(TracerProvider())
span_processor = BatchExportSpanProcessor(ConsoleSpanExporter())
trace.get_tracer_provider().add_span_processor(span_processor)
CeleryInstrumentor().instrument()
app = Celery("tasks", broker="amqp://localhost")
Expand All @@ -50,13 +59,15 @@ def add(x, y):

import logging
import signal
from collections.abc import Iterable

from celery import signals # pylint: disable=no-name-in-module

from opentelemetry import trace
from opentelemetry import propagators, trace
from opentelemetry.instrumentation.celery import utils
from opentelemetry.instrumentation.celery.version import __version__
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.trace.propagation import get_current_span
from opentelemetry.trace.status import Status, StatusCanonicalCode

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -106,9 +117,16 @@ def _trace_prerun(self, *args, **kwargs):
if task is None or task_id is None:
return

request = task.request
tracectx = propagators.extract(carrier_extractor, request) or {}
parent = get_current_span(tracectx)

logger.debug("prerun signal start task_id=%s", task_id)

span = self._tracer.start_span(task.name, kind=trace.SpanKind.CONSUMER)
operation_name = "{0}/{1}".format(_TASK_RUN, task.name)
span = self._tracer.start_span(
operation_name, parent=parent, kind=trace.SpanKind.CONSUMER
)

activation = self._tracer.use_span(span, end_on_exit=True)
activation.__enter__()
Expand Down Expand Up @@ -146,7 +164,10 @@ def _trace_before_publish(self, *args, **kwargs):
if task is None or task_id is None:
return

span = self._tracer.start_span(task.name, kind=trace.SpanKind.PRODUCER)
operation_name = "{0}/{1}".format(_TASK_APPLY_ASYNC, task.name)
span = self._tracer.start_span(
operation_name, kind=trace.SpanKind.PRODUCER
)

# apply some attributes here because most of the data is not available
span.set_attribute(_TASK_TAG_KEY, _TASK_APPLY_ASYNC)
Expand All @@ -158,6 +179,10 @@ def _trace_before_publish(self, *args, **kwargs):
activation.__enter__()
utils.attach_span(task, task_id, (span, activation), is_publish=True)

headers = kwargs.get("headers")
if headers:
propagators.inject(type(headers).__setitem__, headers)

@staticmethod
def _trace_after_publish(*args, **kwargs):
task = utils.retrieve_task_from_sender(kwargs)
Expand Down Expand Up @@ -221,3 +246,10 @@ def _trace_retry(*args, **kwargs):
# Use `str(reason)` instead of `reason.message` in case we get
# something that isn't an `Exception`
span.set_attribute(_TASK_RETRY_REASON_KEY, str(reason))


def carrier_extractor(carrier, key):
value = getattr(carrier, key, [])
if isinstance(value, str) or not isinstance(value, Iterable):
value = (value,)
return value
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from celery import Celery


class Config:
result_backend = "rpc"
broker_backend = "memory"


app = Celery(broker="memory:///")
app.config_from_object(Config)


@app.task
def task_add(num_a, num_b):
return num_a + num_b
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
import time

from opentelemetry.instrumentation.celery import CeleryInstrumentor
from opentelemetry.test.test_base import TestBase
from opentelemetry.trace import SpanKind

from .celery_test_tasks import app, task_add


class TestCeleryInstrumentation(TestBase):
def setUp(self):
super().setUp()
self._worker = app.Worker(app=app, pool="solo", concurrency=1)
self._thread = threading.Thread(target=self._worker.start)
self._thread.daemon = True
self._thread.start()

def tearDown(self):
super().tearDown()
self._worker.stop()
self._thread.join()

def test_task(self):
CeleryInstrumentor().instrument()

result = task_add.delay(1, 2)
while not result.ready():
time.sleep(0.05)

spans = self.sorted_spans(self.memory_exporter.get_finished_spans())
self.assertEqual(len(spans), 2)

consumer, producer = spans

self.assertEqual(consumer.name, "run/tests.celery_test_tasks.task_add")
self.assertEqual(consumer.kind, SpanKind.CONSUMER)
self.assert_span_has_attributes(
consumer,
{
"celery.action": "run",
"celery.state": "SUCCESS",
"messaging.destination": "celery",
"celery.task_name": "tests.celery_test_tasks.task_add",
},
)

self.assertEqual(
producer.name, "apply_async/tests.celery_test_tasks.task_add"
)
self.assertEqual(producer.kind, SpanKind.PRODUCER)
self.assert_span_has_attributes(
producer,
{
"celery.action": "apply_async",
"celery.task_name": "tests.celery_test_tasks.task_add",
"messaging.destination_kind": "queue",
"messaging.destination": "celery",
},
)

self.assertNotEqual(consumer.parent, producer.context)
self.assertEqual(consumer.parent.span_id, producer.context.span_id)
self.assertEqual(consumer.context.trace_id, producer.context.trace_id)
Loading

0 comments on commit 1f72b1e

Please sign in to comment.