Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to Redis #3110

Merged
merged 3 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,9 @@ def response_hook(span, instance, response):
---
"""

import typing
from typing import Any, Collection
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Collection

import redis
from wrapt import wrap_function_wrapper
Expand All @@ -109,18 +110,43 @@ def response_hook(span, instance, response):
from opentelemetry.instrumentation.redis.version import __version__
from opentelemetry.instrumentation.utils import unwrap
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Span, StatusCode
from opentelemetry.trace import Span, StatusCode, Tracer

_DEFAULT_SERVICE = "redis"
if TYPE_CHECKING:
from typing import Awaitable, TypeVar

_RequestHookT = typing.Optional[
typing.Callable[
[Span, redis.connection.Connection, typing.List, typing.Dict], None
import redis.asyncio.client
import redis.asyncio.cluster
import redis.client
import redis.cluster
import redis.connection

_RequestHookT = Callable[
[Span, redis.connection.Connection, list[Any], dict[str, Any]], None
]
]
_ResponseHookT = typing.Optional[
typing.Callable[[Span, redis.connection.Connection, Any], None]
]
_ResponseHookT = Callable[[Span, redis.connection.Connection, Any], None]

AsyncPipelineInstance = TypeVar(
"AsyncPipelineInstance",
redis.asyncio.client.Pipeline,
redis.asyncio.cluster.ClusterPipeline,
)
AsyncRedisInstance = TypeVar(
"AsyncRedisInstance", redis.asyncio.Redis, redis.asyncio.RedisCluster
)
PipelineInstance = TypeVar(
"PipelineInstance",
redis.client.Pipeline,
redis.cluster.ClusterPipeline,
)
RedisInstance = TypeVar(
"RedisInstance", redis.client.Redis, redis.cluster.RedisCluster
)
R = TypeVar("R")


_DEFAULT_SERVICE = "redis"


_REDIS_ASYNCIO_VERSION = (4, 2, 0)
if redis.VERSION >= _REDIS_ASYNCIO_VERSION:
Expand All @@ -132,7 +158,9 @@ def response_hook(span, instance, response):
_FIELD_TYPES = ["NUMERIC", "TEXT", "GEO", "TAG", "VECTOR"]


def _set_connection_attributes(span, conn):
def _set_connection_attributes(
span: Span, conn: RedisInstance | AsyncRedisInstance
) -> None:
if not span.is_recording() or not hasattr(conn, "connection_pool"):
return
for key, value in _extract_conn_attributes(
Expand All @@ -141,7 +169,9 @@ def _set_connection_attributes(span, conn):
span.set_attribute(key, value)


def _build_span_name(instance, cmd_args):
def _build_span_name(
instance: RedisInstance | AsyncRedisInstance, cmd_args: tuple[Any, ...]
) -> str:
if len(cmd_args) > 0 and cmd_args[0]:
if cmd_args[0] == "FT.SEARCH":
name = "redis.search"
Expand All @@ -154,7 +184,9 @@ def _build_span_name(instance, cmd_args):
return name


def _build_span_meta_data_for_pipeline(instance):
def _build_span_meta_data_for_pipeline(
instance: PipelineInstance | AsyncPipelineInstance,
) -> tuple[list[Any], str, str]:
try:
command_stack = (
instance.command_stack
Expand Down Expand Up @@ -184,11 +216,16 @@ def _build_span_meta_data_for_pipeline(instance):

# pylint: disable=R0915
def _instrument(
tracer,
request_hook: _RequestHookT = None,
response_hook: _ResponseHookT = None,
tracer: Tracer,
request_hook: _RequestHookT | None = None,
response_hook: _ResponseHookT | None = None,
):
def _traced_execute_command(func, instance, args, kwargs):
def _traced_execute_command(
func: Callable[..., R],
instance: RedisInstance,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> R:
query = _format_command_args(args)
name = _build_span_name(instance, args)
with tracer.start_as_current_span(
Expand All @@ -210,7 +247,12 @@ def _traced_execute_command(func, instance, args, kwargs):
response_hook(span, instance, response)
return response

def _traced_execute_pipeline(func, instance, args, kwargs):
def _traced_execute_pipeline(
func: Callable[..., R],
instance: PipelineInstance,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> R:
(
command_stack,
resource,
Expand Down Expand Up @@ -242,7 +284,7 @@ def _traced_execute_pipeline(func, instance, args, kwargs):

return response

def _add_create_attributes(span, args):
def _add_create_attributes(span: Span, args: tuple[Any, ...]):
_set_span_attribute_if_value(
span, "redis.create_index.index", _value_or_none(args, 1)
)
Expand All @@ -266,7 +308,7 @@ def _add_create_attributes(span, args):
field_attribute,
)

def _add_search_attributes(span, response, args):
def _add_search_attributes(span: Span, response, args):
_set_span_attribute_if_value(
span, "redis.search.index", _value_or_none(args, 1)
)
Expand Down Expand Up @@ -326,7 +368,12 @@ def _add_search_attributes(span, response, args):
_traced_execute_pipeline,
)

async def _async_traced_execute_command(func, instance, args, kwargs):
async def _async_traced_execute_command(
func: Callable[..., Awaitable[R]],
instance: AsyncRedisInstance,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> Awaitable[R]:
query = _format_command_args(args)
name = _build_span_name(instance, args)

Expand All @@ -344,7 +391,12 @@ async def _async_traced_execute_command(func, instance, args, kwargs):
response_hook(span, instance, response)
return response

async def _async_traced_execute_pipeline(func, instance, args, kwargs):
async def _async_traced_execute_pipeline(
func: Callable[..., Awaitable[R]],
instance: AsyncPipelineInstance,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> Awaitable[R]:
(
command_stack,
resource,
Expand Down Expand Up @@ -408,14 +460,15 @@ async def _async_traced_execute_pipeline(func, instance, args, kwargs):


class RedisInstrumentor(BaseInstrumentor):
"""An instrumentor for Redis
"""An instrumentor for Redis.

See `BaseInstrumentor`
"""

def instrumentation_dependencies(self) -> Collection[str]:
return _instruments

def _instrument(self, **kwargs):
def _instrument(self, **kwargs: Any):
"""Instruments the redis module

Args:
Expand All @@ -436,7 +489,7 @@ def _instrument(self, **kwargs):
response_hook=kwargs.get("response_hook"),
)

def _uninstrument(self, **kwargs):
def _uninstrument(self, **kwargs: Any):
if redis.VERSION < (3, 0, 0):
unwrap(redis.StrictRedis, "execute_command")
unwrap(redis.StrictRedis, "pipeline")
Expand Down
Loading