Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to Starlette instrumentation #3045

Merged
merged 7 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
emdneto marked this conversation as resolved.
Show resolved Hide resolved
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,11 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A
API
---
"""
# pyright: reportPrivateUsage=false
Kludex marked this conversation as resolved.
Show resolved Hide resolved

from typing import Collection
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Collection, cast

from starlette import applications
from starlette.routing import Match
Expand All @@ -184,18 +187,29 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.starlette.package import _instruments
from opentelemetry.instrumentation.starlette.version import __version__
from opentelemetry.metrics import get_meter
from opentelemetry.metrics import MeterProvider, get_meter
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import get_tracer
from opentelemetry.trace import TracerProvider, get_tracer
from opentelemetry.util.http import get_excluded_urls

if TYPE_CHECKING:
from typing import NotRequired, TypedDict, Unpack

class InstrumentKwargs(TypedDict):
Kludex marked this conversation as resolved.
Show resolved Hide resolved
tracer_provider: NotRequired[TracerProvider]
meter_provider: NotRequired[MeterProvider]
emdneto marked this conversation as resolved.
Show resolved Hide resolved
server_request_hook: NotRequired[ServerRequestHook]
client_request_hook: NotRequired[ClientRequestHook]
client_response_hook: NotRequired[ClientResponseHook]


_excluded_urls = get_excluded_urls("STARLETTE")


class StarletteInstrumentor(BaseInstrumentor):
"""An instrumentor for starlette
"""An instrumentor for Starlette.

See `BaseInstrumentor`
See `BaseInstrumentor`.
"""

_original_starlette = None
Expand All @@ -206,8 +220,8 @@ def instrument_app(
server_request_hook: ServerRequestHook = None,
client_request_hook: ClientRequestHook = None,
client_response_hook: ClientResponseHook = None,
meter_provider=None,
tracer_provider=None,
meter_provider: MeterProvider | None = None,
tracer_provider: TracerProvider | None = None,
):
"""Instrument an uninstrumented Starlette application."""
tracer = get_tracer(
Expand Down Expand Up @@ -253,7 +267,7 @@ def uninstrument_app(app: applications.Starlette):
def instrumentation_dependencies(self) -> Collection[str]:
return _instruments

def _instrument(self, **kwargs):
def _instrument(self, **kwargs: Unpack[InstrumentKwargs]):
emdneto marked this conversation as resolved.
Show resolved Hide resolved
self._original_starlette = applications.Starlette
_InstrumentedStarlette._tracer_provider = kwargs.get("tracer_provider")
_InstrumentedStarlette._server_request_hook = kwargs.get(
Expand All @@ -269,7 +283,7 @@ def _instrument(self, **kwargs):

applications.Starlette = _InstrumentedStarlette

def _uninstrument(self, **kwargs):
def _uninstrument(self, **kwargs: Any):
"""uninstrumenting all created apps by user"""
for instance in _InstrumentedStarlette._instrumented_starlette_apps:
self.uninstrument_app(instance)
Expand All @@ -278,14 +292,14 @@ def _uninstrument(self, **kwargs):


class _InstrumentedStarlette(applications.Starlette):
_tracer_provider = None
_meter_provider = None
_tracer_provider: TracerProvider | None = None
_meter_provider: MeterProvider | None = None
_server_request_hook: ServerRequestHook = None
_client_request_hook: ClientRequestHook = None
_client_response_hook: ClientResponseHook = None
_instrumented_starlette_apps = set()
_instrumented_starlette_apps: set[applications.Starlette] = set()

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
emdneto marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(*args, **kwargs)
tracer = get_tracer(
__name__,
Expand Down Expand Up @@ -318,21 +332,22 @@ def __del__(self):
_InstrumentedStarlette._instrumented_starlette_apps.remove(self)


def _get_route_details(scope):
def _get_route_details(scope: dict[str, Any]) -> str | None:
"""
Function to retrieve Starlette route from scope.
Function to retrieve Starlette route from ASGI scope.

TODO: there is currently no way to retrieve http.route from
a starlette application from scope.
See: https://github.com/encode/starlette/pull/804

Args:
scope: A Starlette scope
scope: The ASGI scope that contains the Starlette application in the "app" key.

Returns:
A string containing the route or None
The path to the route if found, otherwise None.
"""
app = scope["app"]
route = None
app = cast(applications.Starlette, scope["app"])
route: str | None = None

for starlette_route in app.routes:
match, _ = starlette_route.matches(scope)
Expand All @@ -344,18 +359,20 @@ def _get_route_details(scope):
return route


def _get_default_span_details(scope):
"""
Callback to retrieve span name and attributes from scope.
def _get_default_span_details(
scope: dict[str, Any],
) -> tuple[str, dict[str, Any]]:
Kludex marked this conversation as resolved.
Show resolved Hide resolved
"""Callback to retrieve span name and attributes from ASGI scope.

Args:
scope: A Starlette scope
scope: The ASGI scope that contains the Starlette application in the "app" key.

Returns:
A tuple of span name and attributes
A tuple of span name and attributes.
"""
route = _get_route_details(scope)
method = scope.get("method", "")
attributes = {}
method: str = scope.get("method", "")
attributes: dict[str, Any] = {}
if route:
attributes[SpanAttributes.HTTP_ROUTE] = route
if method and route: # http
Expand Down
Loading