Skip to content

Commit

Permalink
Replicate parent and links on early spans
Browse files Browse the repository at this point in the history
  • Loading branch information
RafalSumislawski committed Jul 17, 2024
1 parent 97eb6cc commit dfb4aa3
Showing 1 changed file with 73 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,11 @@ def custom_event_context_extractor(lambda_event):
use_span,
)
from opentelemetry.trace.propagation import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID
from opentelemetry.trace.span import (
INVALID_SPAN_ID,
format_trace_id,
format_span_id,
)
import json
import typing
import base64
Expand Down Expand Up @@ -174,12 +178,12 @@ def _default_event_context_extractor(args: Any) -> Context:
return get_global_textmap().extract(headers)


def _determine_parent_context(
def _determine_upstream_context(
lambda_event: Any,
event_context_extractor: Callable[[Any], Context],
disable_aws_context_propagation: bool = False,
) -> Context:
"""Determine the parent context for the current Lambda invocation.
"""Determine the upstream context for the current Lambda invocation.
See more:
https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/instrumentation/aws-lambda.md#determining-the-parent-of-a-span
Expand All @@ -197,30 +201,30 @@ def _determine_parent_context(
Returns:
A Context with configuration found in the carrier.
"""
parent_context = None
upstream_context = None

if not disable_aws_context_propagation:
xray_env_var = os.environ.get(_X_AMZN_TRACE_ID)

if xray_env_var:
parent_context = AwsXRayPropagator().extract(
upstream_context = AwsXRayPropagator().extract(
{TRACE_HEADER_KEY: xray_env_var}
)

if (
parent_context
and get_current_span(parent_context)
upstream_context
and get_current_span(upstream_context)
.get_span_context()
.trace_flags.sampled
):
return parent_context
return upstream_context

if event_context_extractor:
parent_context = event_context_extractor(lambda_event)
upstream_context = event_context_extractor(lambda_event)
else:
parent_context = _default_event_context_extractor(lambda_event)
upstream_context = _default_event_context_extractor(lambda_event)

return parent_context
return upstream_context


def _set_api_gateway_v1_proxy_attributes(
Expand Down Expand Up @@ -346,7 +350,7 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches

lambda_event = args[0]

parent_context = _determine_parent_context(
upstream_context = _determine_upstream_context(
args,
event_context_extractor,
disable_aws_context_propagation,
Expand Down Expand Up @@ -374,6 +378,7 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches

tracer = get_tracer(__name__, __version__, tracer_provider)

trigger_context = None
triggerSpan = None

apiGwSpan = None
Expand All @@ -390,7 +395,7 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
if lambda_event.get("requestContext") and lambda_event["requestContext"].get("http"):
span_name = lambda_event["requestContext"]["http"].get("path")

apiGwSpan = tracer.start_span(span_name, context=parent_context, kind=SpanKind.CLIENT)
apiGwSpan = tracer.start_span(span_name, context=upstream_context, kind=SpanKind.CLIENT)
if lambda_event.get("version") == "2.0":
apiGwSpan.set_attribute("faas.trigger.type", "Api Gateway Rest")
else:
Expand All @@ -399,7 +404,7 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
apiGwSpan.set_attribute(SpanAttributes.FAAS_TRIGGER, "http")

triggerSpan = apiGwSpan
parent_context = set_span_in_context(apiGwSpan)
trigger_context = set_span_in_context(apiGwSpan)
except Exception as ex:
pass
# S3 trigger new span and request attributes
Expand All @@ -412,12 +417,12 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
if lambda_event["Records"][0].get("eventName"):
span_name = lambda_event["Records"][0].get("eventName")

s3TriggerSpan = tracer.start_span(span_name, context=parent_context, kind=SpanKind.PRODUCER)
s3TriggerSpan = tracer.start_span(span_name, context=upstream_context, kind=SpanKind.PRODUCER)
s3TriggerSpan.set_attribute(SpanAttributes.FAAS_TRIGGER, "datasource")
s3TriggerSpan.set_attribute("faas.trigger.type", "S3")

triggerSpan = s3TriggerSpan
parent_context = set_span_in_context(s3TriggerSpan)
trigger_context = set_span_in_context(s3TriggerSpan)

if lambda_event["Records"][0].get("s3"):
s3TriggerSpan.set_attribute(
Expand Down Expand Up @@ -446,7 +451,7 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
links.append(Link(span_ctx))

span_name = orig_handler_name
sqsTriggerSpan = tracer.start_span(span_name, context=parent_context, kind=SpanKind.CONSUMER, links=links)
sqsTriggerSpan = tracer.start_span(span_name, context=upstream_context, kind=SpanKind.CONSUMER, links=links)
sqsTriggerSpan.set_attribute(SpanAttributes.FAAS_TRIGGER, "pubsub")
sqsTriggerSpan.set_attribute("faas.trigger.type", "SQS")
sqsTriggerSpan.set_attribute(SpanAttributes.MESSAGING_SYSTEM, "aws.sqs")
Expand All @@ -459,7 +464,7 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
pass

triggerSpan = sqsTriggerSpan
parent_context = set_span_in_context(sqsTriggerSpan)
trigger_context = set_span_in_context(sqsTriggerSpan)

if lambda_event["Records"][0].get("body"):
sqsTriggerSpan.set_attribute(
Expand Down Expand Up @@ -491,7 +496,7 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches

span_kind = SpanKind.INTERNAL
span_name = orig_handler_name
snsTriggerSpan = tracer.start_span(span_name, context=parent_context, kind=SpanKind.CONSUMER, links=links)
snsTriggerSpan = tracer.start_span(span_name, context=upstream_context, kind=SpanKind.CONSUMER, links=links)
snsTriggerSpan.set_attribute(SpanAttributes.FAAS_TRIGGER, "pubsub")
snsTriggerSpan.set_attribute("faas.trigger.type", "SNS")
snsTriggerSpan.set_attribute(SpanAttributes.MESSAGING_SYSTEM, "aws.sns")
Expand All @@ -504,7 +509,7 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
pass

triggerSpan = snsTriggerSpan
parent_context = set_span_in_context(snsTriggerSpan)
trigger_context = set_span_in_context(snsTriggerSpan)

if lambda_event["Records"][0]["Sns"] and lambda_event["Records"][0]["Sns"].get("Message"):
snsTriggerSpan.set_attribute(
Expand Down Expand Up @@ -538,7 +543,7 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
links.append(Link(span_ctx))
span_kind = SpanKind.INTERNAL
span_name = orig_handler_name
kinesisTriggerSpan = tracer.start_span(span_name, context=parent_context, kind=SpanKind.CONSUMER, links=links)
kinesisTriggerSpan = tracer.start_span(span_name, context=upstream_context, kind=SpanKind.CONSUMER, links=links)
kinesisTriggerSpan.set_attribute(SpanAttributes.FAAS_TRIGGER, "pubsub")
kinesisTriggerSpan.set_attribute("faas.trigger.type", "Kinesis")
kinesisTriggerSpan.set_attribute(SpanAttributes.MESSAGING_SYSTEM, "aws.kinesis")
Expand All @@ -551,7 +556,7 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
pass

triggerSpan = kinesisTriggerSpan
parent_context = set_span_in_context(kinesisTriggerSpan)
trigger_context = set_span_in_context(kinesisTriggerSpan)

if lambda_event["Records"][0]["kinesis"] and lambda_event["Records"][0]["kinesis"].get("data"):
decoded_bytes = base64.b64decode(lambda_event["Records"][0]["kinesis"].get("data"))
Expand All @@ -572,12 +577,12 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
if lambda_event["Records"][0].get("eventName"):
span_name = lambda_event["Records"][0].get("eventName")

dynamoTriggerSpan = tracer.start_span(span_name, context=parent_context, kind=SpanKind.PRODUCER)
dynamoTriggerSpan = tracer.start_span(span_name, context=upstream_context, kind=SpanKind.PRODUCER)
dynamoTriggerSpan.set_attribute(SpanAttributes.FAAS_TRIGGER, "datasource")
dynamoTriggerSpan.set_attribute("faas.trigger.type", "Dynamo DB")

triggerSpan = dynamoTriggerSpan
parent_context = set_span_in_context(dynamoTriggerSpan)
trigger_context = set_span_in_context(dynamoTriggerSpan)

if lambda_event["Records"][0].get("dynamodb"):
dynamoTriggerSpan.set_attribute(
Expand All @@ -594,12 +599,12 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
if lambda_event.get("eventType"):
span_name = lambda_event.get("eventType")

cognitoTriggerSpan = tracer.start_span(span_name, context=parent_context, kind=SpanKind.PRODUCER)
cognitoTriggerSpan = tracer.start_span(span_name, context=upstream_context, kind=SpanKind.PRODUCER)
cognitoTriggerSpan.set_attribute(SpanAttributes.FAAS_TRIGGER, "datasource")
cognitoTriggerSpan.set_attribute("faas.trigger.type", "Cognito")

triggerSpan = cognitoTriggerSpan
parent_context = set_span_in_context(cognitoTriggerSpan)
trigger_context = set_span_in_context(cognitoTriggerSpan)

if lambda_event["datasetRecords"]:
cognitoTriggerSpan.set_attribute(
Expand All @@ -623,13 +628,13 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
if span_ctx.span_id != INVALID_SPAN_ID:
links.append(Link(span_ctx))

eventBridgeTriggerSpan = tracer.start_span(span_name, context=parent_context, kind=SpanKind.CONSUMER, links=links)
eventBridgeTriggerSpan = tracer.start_span(span_name, context=upstream_context, kind=SpanKind.CONSUMER, links=links)
eventBridgeTriggerSpan.set_attribute(SpanAttributes.FAAS_TRIGGER, "pubsub")
eventBridgeTriggerSpan.set_attribute("faas.trigger.type", "EventBridge")
eventBridgeTriggerSpan.set_attribute("aws.event.bridge.trigger.source", lambda_event.get("source"))

triggerSpan = eventBridgeTriggerSpan
parent_context = set_span_in_context(eventBridgeTriggerSpan)
trigger_context = set_span_in_context(eventBridgeTriggerSpan)

eventBridgeTriggerSpan.set_attribute(
"rpc.request.body",
Expand All @@ -642,14 +647,28 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
triggerSpan.set_attribute("cx.internal.span.role", "trigger")

try:
if trigger_context is not None:
invocation_parent_context = trigger_context
else:
invocation_parent_context = upstream_context

invocationSpan = tracer.start_span(
name=orig_handler_name,
context=parent_context,
context=invocation_parent_context,
kind=span_kind,
)
invocationSpan.set_attribute("cx.internal.span.role", "invocation")

_sendEarlySpans(flush_timeout, tracer, tracer_provider, meter_provider, triggerSpan, invocationSpan)
_sendEarlySpans(
flush_timeout,
tracer,
tracer_provider,
meter_provider,
trigger_parent_context=upstream_context,
trigger_span=triggerSpan,
invocation_parent_context=invocation_parent_context,
invocation_span=invocationSpan,
)

with use_span(
span=invocationSpan,
Expand Down Expand Up @@ -705,7 +724,6 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
SpanAttributes.HTTP_URL,
lambda_event["requestContext"].get("domainName") + lambda_event["requestContext"].get("http").get("path")
)
apiGwSpan.end()
try:
if lambda_event["Records"][0]["eventSource"] == "aws:sqs":
span.set_attribute(SpanAttributes.FAAS_TRIGGER, "pubsub")
Expand Down Expand Up @@ -738,7 +756,6 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
)
except Exception:
pass
s3TriggerSpan.end()

# SQS trigger response attributes
if lambda_event and sqsTriggerSpan is not None:
Expand All @@ -755,7 +772,6 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
)
except Exception:
pass
sqsTriggerSpan.end()

if lambda_event and snsTriggerSpan is not None:
try:
Expand All @@ -771,7 +787,6 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
)
except Exception:
pass
snsTriggerSpan.end()

if lambda_event and kinesisTriggerSpan is not None:
try:
Expand All @@ -783,7 +798,6 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
)
except Exception:
pass
kinesisTriggerSpan.end()

if lambda_event and dynamoTriggerSpan is not None:
try:
Expand All @@ -799,7 +813,6 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
)
except Exception:
pass
dynamoTriggerSpan.end()

if lambda_event and cognitoTriggerSpan is not None:
try:
Expand All @@ -815,7 +828,6 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
)
except Exception:
pass
cognitoTriggerSpan.end()

if lambda_event and eventBridgeTriggerSpan is not None:
try:
Expand All @@ -831,14 +843,13 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
)
except Exception:
pass
eventBridgeTriggerSpan.end()

except Exception as e:
if triggerSpan is not None:
triggerSpan.end()
raise e

finally:
if triggerSpan is not None:
triggerSpan.end()
_flush(flush_timeout, tracer_provider, meter_provider)

return result
Expand All @@ -854,27 +865,44 @@ def _sendEarlySpans(
tracer: Tracer,
tracer_provider: TracerProvider,
meter_provider: MeterProvider,
trigger_parent_context: Context,
trigger_span: Span,
invocation_parent_context: Context,
invocation_span: Span,
) -> None:
if trigger_span is not None:
early_trigger = _createEarlySpan(tracer, trigger_span)
early_trigger = _createEarlySpan(
tracer,
parent_context=trigger_parent_context,
span=trigger_span
)
early_trigger.end()

if invocation_span is not None:
early_invocation = _createEarlySpan(tracer, invocation_span)
early_invocation = _createEarlySpan(
tracer,
parent_context=invocation_parent_context,
span=invocation_span
)
early_invocation.end()

_flush(flush_timeout, tracer_provider, meter_provider)

def _createEarlySpan(
    tracer: Tracer,
    parent_context: Context,
    span: Span,
) -> Span:
    """Start an "early" copy of ``span`` under ``parent_context``.

    The copy is started with the same name, kind, attributes and links as
    ``span``. It is additionally tagged with ``cx.internal.span.state`` set to
    ``"early"`` and with the formatted trace id and span id of ``span``.

    Args:
        tracer: Tracer used to start the early span.
        parent_context: Context passed as ``context`` to ``start_span``.
        span: The span to replicate. It must expose ``name``, ``kind``,
            ``attributes`` and ``links`` in addition to ``get_span_context()``.

    Returns:
        The newly started early span. This function does not end it.
    """
    early_span = tracer.start_span(
        name=span.name,
        context=parent_context,
        kind=span.kind,
        attributes=span.attributes,
        links=span.links,
    )
    early_span.set_attribute("cx.internal.span.state", "early")

    # Read the span context once; both ids below come from the same object.
    original_ctx = span.get_span_context()
    early_span.set_attribute(
        "cx.internal.trace.id", format_trace_id(original_ctx.trace_id)
    )
    early_span.set_attribute(
        "cx.internal.span.id", format_span_id(original_ctx.span_id)
    )
    return early_span

def _flush(
Expand Down

0 comments on commit dfb4aa3

Please sign in to comment.