Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PI-2342 Propagate trace context over SNS for distributed tracing #3995

Merged
merged 1 commit into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package uk.gov.justice.digital.hmpps.telemetry

import com.microsoft.applicationinsights.TelemetryClient
import com.microsoft.applicationinsights.telemetry.TelemetryContext
import org.slf4j.LoggerFactory
import org.springframework.scheduling.annotation.Async
import org.springframework.stereotype.Service
import java.lang.Exception

@Service
class TelemetryService(private val telemetryClient: TelemetryClient = TelemetryClient()) {
Expand All @@ -27,4 +27,6 @@ class TelemetryService(private val telemetryClient: TelemetryClient = TelemetryC
log.debug("{} {} {}", exception.message, properties, metrics)
telemetryClient.trackException(exception, properties, metrics)
}

fun getContext(): TelemetryContext = telemetryClient.context
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
package uk.gov.justice.digital.hmpps.listener

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.kotlin.jacksonTypeRef
import io.awspring.cloud.sqs.annotation.SqsListener
import io.awspring.cloud.sqs.listener.AsyncAdapterBlockingExecutionFailedException
import io.awspring.cloud.sqs.listener.ListenerExecutionFailedException
import io.opentelemetry.api.trace.SpanKind
import io.opentelemetry.instrumentation.annotations.SpanAttribute
import io.opentelemetry.instrumentation.annotations.WithSpan
import io.sentry.Sentry
import io.sentry.spring.jakarta.tracing.SentryTransaction
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import org.springframework.boot.autoconfigure.condition.ConditionalOnExpression
import org.springframework.context.annotation.Conditional
import org.springframework.dao.CannotAcquireLockException
Expand All @@ -17,36 +22,35 @@ import org.springframework.transaction.CannotCreateTransactionException
import org.springframework.transaction.UnexpectedRollbackException
import org.springframework.web.client.RestClientException
import uk.gov.justice.digital.hmpps.config.AwsCondition
import uk.gov.justice.digital.hmpps.message.Notification
import uk.gov.justice.digital.hmpps.messaging.NotificationHandler
import uk.gov.justice.digital.hmpps.retry.retry
import uk.gov.justice.digital.hmpps.telemetry.TelemetryMessagingExtensions.extractSpanContext
import uk.gov.justice.digital.hmpps.telemetry.TelemetryMessagingExtensions.startSpan
import java.util.concurrent.CompletionException

@Component
@Conditional(AwsCondition::class)
@ConditionalOnExpression("\${messaging.consumer.enabled:true} and '\${messaging.consumer.queue:}' != ''")
class AwsNotificationListener(
private val handler: NotificationHandler<*>
private val handler: NotificationHandler<*>,
private val objectMapper: ObjectMapper
) {
@SqsListener("\${messaging.consumer.queue}")
@SentryTransaction(operation = "messaging")
@WithSpan(kind = SpanKind.CONSUMER)
fun receive(message: String) {
try {
retry(
3,
listOf(
RestClientException::class,
CannotAcquireLockException::class,
ObjectOptimisticLockingFailureException::class,
CannotCreateTransactionException::class,
CannotGetJdbcConnectionException::class,
UnexpectedRollbackException::class
)
) { handler.handle(message) }
} catch (e: Throwable) {
Sentry.captureException(unwrapSqsExceptions(e))
throw e
@SentryTransaction(operation = "messaging")
@SqsListener("\${messaging.consumer.queue}")
fun receive(@SpanAttribute message: String) {
val attributes = objectMapper.readValue(message, jacksonTypeRef<Notification<String>>()).attributes
val span = attributes.extractSpanContext().startSpan(this::class.java.name, "receive", SpanKind.CONSUMER)
span.makeCurrent().use {
try {
retry(3, RETRYABLE_EXCEPTIONS) { handler.handle(message) }
} catch (e: Throwable) {
Sentry.captureException(unwrapSqsExceptions(e))
throw e
}
}
span.end()
}

fun unwrapSqsExceptions(e: Throwable): Throwable {
Expand All @@ -63,4 +67,16 @@ class AwsNotificationListener(
}
return cause
}

companion object {
private val log: Logger = LoggerFactory.getLogger(this::class.java)
val RETRYABLE_EXCEPTIONS = listOf(
RestClientException::class,
CannotAcquireLockException::class,
ObjectOptimisticLockingFailureException::class,
CannotCreateTransactionException::class,
CannotGetJdbcConnectionException::class,
UnexpectedRollbackException::class
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import com.fasterxml.jackson.annotation.JsonAnyGetter
import com.fasterxml.jackson.annotation.JsonAnySetter
import com.fasterxml.jackson.annotation.JsonIgnore
import com.fasterxml.jackson.annotation.JsonProperty
import java.util.UUID
import java.util.*

data class Notification<T>(
@JsonProperty("Message") val message: T,
Expand All @@ -21,9 +21,14 @@ data class MessageAttributes(
constructor(eventType: String) : this(mutableMapOf("eventType" to MessageAttribute("String", eventType)))

override operator fun get(key: String): MessageAttribute? = attributes[key]

operator fun set(key: String, value: MessageAttribute) {
attributes[key] = value
}

operator fun set(key: String, value: String) {
set(key, MessageAttribute("String", value))
}
}

data class MessageAttribute(@JsonProperty("Type") val type: String, @JsonProperty("Value") val value: String)
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package uk.gov.justice.digital.hmpps.publisher

import com.fasterxml.jackson.databind.ObjectMapper
import io.awspring.cloud.sqs.operations.SqsTemplate
import io.opentelemetry.api.trace.SpanKind
import io.opentelemetry.instrumentation.annotations.WithSpan
import org.springframework.beans.factory.annotation.Value
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty
import org.springframework.context.annotation.Conditional
Expand All @@ -10,6 +12,7 @@ import org.springframework.messaging.support.MessageBuilder
import org.springframework.stereotype.Component
import uk.gov.justice.digital.hmpps.config.AwsCondition
import uk.gov.justice.digital.hmpps.message.Notification
import uk.gov.justice.digital.hmpps.telemetry.TelemetryMessagingExtensions.withSpanContext
import java.util.concurrent.Semaphore

@Component
Expand All @@ -23,6 +26,8 @@ class QueuePublisher(
) : NotificationPublisher {

private val permit = Semaphore(limit, true)

@WithSpan(kind = SpanKind.PRODUCER)
override fun publish(notification: Notification<*>) {
notification.message?.also { _ ->
permit.acquire()
Expand All @@ -35,12 +40,7 @@ class QueuePublisher(
}

private fun Notification<*>.asMessage() = MessageBuilder.createMessage(
objectMapper.writeValueAsString(
Notification(
message = objectMapper.writeValueAsString(message),
attributes
)
),
MessageHeaders(attributes.map { it.key to it.value.value }.toMap())
objectMapper.writeValueAsString(Notification(objectMapper.writeValueAsString(message), attributes)),
MessageHeaders(attributes.map { it.key to it.value.value }.toMap()).withSpanContext()
)
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package uk.gov.justice.digital.hmpps.publisher

import io.awspring.cloud.sns.core.SnsTemplate
import io.opentelemetry.api.trace.SpanKind
import io.opentelemetry.instrumentation.annotations.WithSpan
import org.springframework.beans.factory.annotation.Value
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty
import org.springframework.context.annotation.Conditional
Expand All @@ -10,6 +12,7 @@ import org.springframework.messaging.support.MessageBuilder
import org.springframework.stereotype.Component
import uk.gov.justice.digital.hmpps.config.AwsCondition
import uk.gov.justice.digital.hmpps.message.Notification
import uk.gov.justice.digital.hmpps.telemetry.TelemetryMessagingExtensions.withSpanContext

@Primary
@Component
Expand All @@ -19,12 +22,13 @@ class TopicPublisher(
private val notificationTemplate: SnsTemplate,
@Value("\${messaging.producer.topic}") private val topic: String
) : NotificationPublisher {
@WithSpan(kind = SpanKind.PRODUCER)
override fun publish(notification: Notification<*>) {
notification.message?.let { message ->
notificationTemplate.convertAndSend(topic, message) { msg ->
MessageBuilder.createMessage(
msg.payload,
MessageHeaders(notification.attributes.map { it.key to it.value.value }.toMap())
MessageHeaders(notification.attributes.map { it.key to it.value.value }.toMap()).withSpanContext(),
)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package uk.gov.justice.digital.hmpps.telemetry

import io.opentelemetry.api.GlobalOpenTelemetry
import io.opentelemetry.api.trace.Span
import io.opentelemetry.api.trace.SpanKind
import io.opentelemetry.context.Context
import io.opentelemetry.context.propagation.TextMapGetter
import org.springframework.messaging.MessageHeaders
import uk.gov.justice.digital.hmpps.message.HmppsDomainEvent
import uk.gov.justice.digital.hmpps.message.MessageAttributes
import uk.gov.justice.digital.hmpps.message.Notification

object TelemetryMessagingExtensions {
fun MessageHeaders.withSpanContext(): MessageHeaders {
val map = this.toMutableMap()
val context = Context.current().with(Span.current())
GlobalOpenTelemetry.getPropagators().textMapPropagator
.inject(context, map) { carrier, key, value -> carrier!![key] = value }
return MessageHeaders(map)
}

fun MessageAttributes.extractSpanContext(): Context {
val getter = object : TextMapGetter<MessageAttributes> {
override fun keys(carrier: MessageAttributes) = carrier.keys
override fun get(carrier: MessageAttributes?, key: String) = carrier?.get(key)?.value
}
return GlobalOpenTelemetry.getPropagators().textMapPropagator.extract(Context.current(), this, getter)
}

fun Context.startSpan(scopeName: String, spanName: String, spanKind: SpanKind = SpanKind.INTERNAL): Span {
val tracer = GlobalOpenTelemetry.getTracer(scopeName)
return tracer.spanBuilder(spanName).setParent(this).setSpanKind(spanKind).startSpan()
}

fun TelemetryService.hmppsEventReceived(hmppsEvent: HmppsDomainEvent) {
trackEvent(
"NotificationReceived",
mapOf("eventType" to hmppsEvent.eventType) +
(hmppsEvent.detailUrl?.let { mapOf("detailUrl" to it) } ?: mapOf()) +
(hmppsEvent.personReference.identifiers.associate { Pair(it.type, it.value) })
)
}

fun <T> TelemetryService.notificationReceived(notification: Notification<T>) {
if (notification.message is HmppsDomainEvent) {
hmppsEventReceived(notification.message)
} else {
trackEvent("NotificationReceived", notification.eventType?.let { mapOf("eventType" to it) } ?: mapOf())
}
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
package uk.gov.justice.digital.hmpps.listener

import com.fasterxml.jackson.core.type.TypeReference
import com.fasterxml.jackson.databind.ObjectMapper
import io.awspring.cloud.sqs.listener.AsyncAdapterBlockingExecutionFailedException
import io.awspring.cloud.sqs.listener.ListenerExecutionFailedException
import io.sentry.Sentry
import org.hamcrest.CoreMatchers.equalTo
import org.hamcrest.MatcherAssert.assertThat
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import org.junit.jupiter.api.extension.ExtendWith
import org.mockito.InjectMocks
import org.mockito.Mock
import org.mockito.Mockito.mockStatic
import org.mockito.junit.jupiter.MockitoExtension
import org.mockito.kotlin.any
import org.mockito.kotlin.verify
import org.mockito.kotlin.whenever
import org.springframework.messaging.support.GenericMessage
import uk.gov.justice.digital.hmpps.message.Notification
import uk.gov.justice.digital.hmpps.messaging.NotificationHandler
import java.util.concurrent.CompletionException

Expand All @@ -23,9 +28,18 @@ class AwsNotificationListenerTest {
@Mock
lateinit var handler: NotificationHandler<Any>

@Mock
lateinit var objectMapper: ObjectMapper

@InjectMocks
lateinit var listener: AwsNotificationListener

@BeforeEach
fun setUp() {
whenever(objectMapper.readValue(any<String>(), any<TypeReference<Notification<String>>>()))
.thenReturn(Notification("message"))
}

@Test
fun `messages are dispatched to handler`() {
listener.receive("message")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ package uk.gov.justice.digital.hmpps.telemetry

import com.microsoft.applicationinsights.TelemetryClient
import org.hamcrest.MatcherAssert.assertThat
import org.hamcrest.Matchers.equalTo
import org.hamcrest.Matchers.hasProperty
import org.hamcrest.Matchers.not
import org.hamcrest.Matchers.*
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith
Expand All @@ -14,15 +12,13 @@ import org.mockito.junit.jupiter.MockitoExtension
import org.mockito.kotlin.check
import org.mockito.kotlin.eq
import org.mockito.kotlin.verify
import uk.gov.justice.digital.hmpps.message.HmppsDomainEvent
import uk.gov.justice.digital.hmpps.message.MessageAttributes
import uk.gov.justice.digital.hmpps.message.Notification
import uk.gov.justice.digital.hmpps.message.PersonIdentifier
import uk.gov.justice.digital.hmpps.message.PersonReference
import uk.gov.justice.digital.hmpps.message.*
import uk.gov.justice.digital.hmpps.telemetry.TelemetryMessagingExtensions.hmppsEventReceived
import uk.gov.justice.digital.hmpps.telemetry.TelemetryMessagingExtensions.notificationReceived
import java.time.ZonedDateTime

@ExtendWith(MockitoExtension::class)
class TelemetryServiceTest {
class TelemetryMessagingExtensionsTest {

@Mock
private lateinit var telemetryClient: TelemetryClient
Expand Down Expand Up @@ -52,7 +48,7 @@ class TelemetryServiceTest {
)

verify(telemetryClient).trackEvent(
eq("SOME_SPECIAL_EVENT_RECEIVED"),
eq("NotificationReceived"),
check {
assertThat(it["eventType"], equalTo(eventType))
assertThat(it["detailUrl"], equalTo(detailUrl))
Expand All @@ -73,7 +69,7 @@ class TelemetryServiceTest {
)

verify(telemetryClient).trackEvent(
eq("SOME_SPECIAL_EVENT_RECEIVED"),
eq("NotificationReceived"),
check { assertThat(it, not(hasProperty("detailUrl"))) },
anyMap()
)
Expand All @@ -88,13 +84,13 @@ class TelemetryServiceTest {
)
)

verify(telemetryClient).trackEvent(eq("TEST_EVENT_RECEIVED"), anyMap(), anyMap())
verify(telemetryClient).trackEvent(eq("NotificationReceived"), anyMap(), anyMap())
}

@Test
fun `handles events with no event type`() {
telemetryService.notificationReceived(Notification(message = "this is a string"))

verify(telemetryClient).trackEvent(eq("UNKNOWN_EVENT_RECEIVED"), anyMap(), anyMap())
verify(telemetryClient).trackEvent(eq("NotificationReceived"), anyMap(), anyMap())
}
}
Loading