Skip to content

Commit

Permalink
fix baggage, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dimastbk committed Sep 17, 2024
1 parent c048b09 commit 34b76cc
Showing 1 changed file with 90 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import uuid
from typing import List, Tuple
from typing import List, Sequence, Tuple
from unittest import IsolatedAsyncioTestCase, mock

from aiokafka import (
Expand All @@ -24,12 +24,13 @@
)
from wrapt import BoundFunctionWrapper

from opentelemetry import baggage, context
from opentelemetry.instrumentation.aiokafka import AIOKafkaInstrumentor
from opentelemetry.sdk.trace import Span
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.semconv._incubating.attributes import messaging_attributes
from opentelemetry.semconv.attributes import server_attributes
from opentelemetry.test.test_base import TestBase
from opentelemetry.trace import SpanKind, format_trace_id
from opentelemetry.trace import SpanKind, format_trace_id, set_span_in_context


class TestAIOKafka(TestBase, IsolatedAsyncioTestCase):
Expand All @@ -51,6 +52,19 @@ def consumer_record_factory(
headers=headers,
)

@staticmethod
def producer_factory() -> AIOKafkaProducer:
producer = AIOKafkaProducer(api_version="1.0")

add_message_mock = mock.AsyncMock()
producer.client._wait_on_metadata = mock.AsyncMock()
producer.client.bootstrap = mock.AsyncMock()
producer._message_accumulator.add_message = add_message_mock
producer._sender.start = mock.AsyncMock()
producer._partition = mock.Mock(return_value=1)

return producer

def test_instrument_api(self) -> None:
instrumentation = AIOKafkaInstrumentor()

Expand Down Expand Up @@ -147,7 +161,46 @@ async def test_anext(self) -> None:
span_list = self.memory_exporter.get_finished_spans()
self._compare_spans(span_list, expected_spans)

async def test_anext_consumer_hook(self) -> None:
async def test_anext_baggage(self) -> None:
received_baggage = None

async def async_consume_hook(span, *_) -> None:
nonlocal received_baggage
received_baggage = baggage.get_all(set_span_in_context(span))

AIOKafkaInstrumentor().uninstrument()
AIOKafkaInstrumentor().instrument(
tracer_provider=self.tracer_provider,
async_consume_hook=async_consume_hook,
)

consumer = AIOKafkaConsumer()

self.memory_exporter.clear()

getone_mock = mock.AsyncMock()
consumer.getone = getone_mock

getone_mock.side_effect = [
self.consumer_record_factory(
1,
headers=(
(
"traceparent",
b"00-03afa25236b8cd948fa853d67038ac79-405ff022e8247c46-01",
),
("baggage", b"foo=bar"),
),
),
self.consumer_record_factory(2, headers=()),
]

await consumer.__anext__()
getone_mock.assert_awaited_with()

self.assertEqual(received_baggage, {"foo": "bar"})

async def test_anext_consume_hook(self) -> None:
async_consume_hook_mock = mock.AsyncMock()

AIOKafkaInstrumentor().uninstrument()
Expand All @@ -171,14 +224,10 @@ async def test_send(self) -> None:
AIOKafkaInstrumentor().uninstrument()
AIOKafkaInstrumentor().instrument(tracer_provider=self.tracer_provider)

producer = AIOKafkaProducer(api_version="1.0")

add_message_mock = mock.AsyncMock()
producer.client._wait_on_metadata = mock.AsyncMock()
producer.client.bootstrap = mock.AsyncMock()
producer._message_accumulator.add_message = add_message_mock
producer._sender.start = mock.AsyncMock()
producer._partition = mock.Mock(return_value=1)
producer = self.producer_factory()
add_message_mock: mock.AsyncMock = (
producer._message_accumulator.add_message
)

await producer.start()

Expand Down Expand Up @@ -208,6 +257,33 @@ async def test_send(self) -> None:
headers=[("traceparent", mock.ANY)],
)

async def test_send_baggage(self) -> None:
AIOKafkaInstrumentor().uninstrument()
AIOKafkaInstrumentor().instrument(tracer_provider=self.tracer_provider)

producer = self.producer_factory()
add_message_mock: mock.AsyncMock = (
producer._message_accumulator.add_message
)

await producer.start()

tracer = self.tracer_provider.get_tracer(__name__)
ctx = baggage.set_baggage("foo", "bar")
context.attach(ctx)

with tracer.start_as_current_span("test_span", context=ctx):
await producer.send("topic_1", b"value_1")

add_message_mock.assert_awaited_with(
TopicPartition(topic="topic_1", partition=1),
None,
b"value_1",
40.0,
timestamp_ms=None,
headers=[("traceparent", mock.ANY), ("baggage", b"foo=bar")],
)

async def test_send_produce_hook(self) -> None:
async_produce_hook_mock = mock.AsyncMock()

Expand All @@ -217,13 +293,7 @@ async def test_send_produce_hook(self) -> None:
async_produce_hook=async_produce_hook_mock,
)

producer = AIOKafkaProducer(api_version="1.0")

producer.client._wait_on_metadata = mock.AsyncMock()
producer.client.bootstrap = mock.AsyncMock()
producer._message_accumulator.add_message = mock.AsyncMock()
producer._sender.start = mock.AsyncMock()
producer._partition = mock.Mock(return_value=1)
producer = self.producer_factory()

await producer.start()

Expand All @@ -232,7 +302,7 @@ async def test_send_produce_hook(self) -> None:
async_produce_hook_mock.assert_awaited_once()

def _compare_spans(
self, spans: List[Span], expected_spans: List[dict]
self, spans: Sequence[ReadableSpan], expected_spans: List[dict]
) -> None:
self.assertEqual(len(spans), len(expected_spans))
for span, expected_span in zip(spans, expected_spans):
Expand Down

0 comments on commit 34b76cc

Please sign in to comment.