From 77d5ee1aa7737d9491d3ac5c940b40de4882369e Mon Sep 17 00:00:00 2001 From: Erle Carrara Date: Thu, 24 Oct 2024 17:07:20 -0300 Subject: [PATCH] Support functools.partial functions in AsyncioInstrumentor.trace_to_thread (#2911) --- CHANGELOG.md | 3 ++ .../instrumentation/asyncio/__init__.py | 12 ++++--- .../tests/test_asyncio_to_thread.py | 35 +++++++++++++++++++ 3 files changed, 45 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 028b4ee63e..0079547454 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -80,6 +80,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#2753](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2753)) - `opentelemetry-instrumentation-grpc` Fix grpc supported version ([#2845](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2845)) +- `opentelemetry-instrumentation-asyncio` fix `AttributeError` in + `AsyncioInstrumentor.trace_to_thread` when `func` is a `functools.partial` instance + ([#2911](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2911)) ## Version 1.26.0/0.47b0 (2024-07-23) diff --git a/instrumentation/opentelemetry-instrumentation-asyncio/src/opentelemetry/instrumentation/asyncio/__init__.py b/instrumentation/opentelemetry-instrumentation-asyncio/src/opentelemetry/instrumentation/asyncio/__init__.py index a6cc6b044f..2d1b063dfd 100644 --- a/instrumentation/opentelemetry-instrumentation-asyncio/src/opentelemetry/instrumentation/asyncio/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-asyncio/src/opentelemetry/instrumentation/asyncio/__init__.py @@ -78,6 +78,7 @@ def func(): """ import asyncio +import functools import sys from asyncio import futures from timeit import default_timer @@ -231,14 +232,15 @@ def wrap_taskgroup_create_task(method, instance, args, kwargs) -> None: def trace_to_thread(self, func: callable): """Trace a function.""" start = default_timer() + func_name = getattr(func, "__name__", None) + if func_name is None and isinstance(func, functools.partial): + func_name = func.func.__name__ span = ( - self._tracer.start_span( - f"{ASYNCIO_PREFIX} to_thread-" + func.__name__ - ) - if func.__name__ in self._to_thread_name_to_trace + self._tracer.start_span(f"{ASYNCIO_PREFIX} to_thread-" + func_name) + if func_name in self._to_thread_name_to_trace else None ) - attr = {"type": "to_thread", "name": func.__name__} + attr = {"type": "to_thread", "name": func_name} exception = None try: attr["state"] = "finished" diff --git a/instrumentation/opentelemetry-instrumentation-asyncio/tests/test_asyncio_to_thread.py b/instrumentation/opentelemetry-instrumentation-asyncio/tests/test_asyncio_to_thread.py index 3d795d8ae7..35191d3d03 100644 --- a/instrumentation/opentelemetry-instrumentation-asyncio/tests/test_asyncio_to_thread.py +++ b/instrumentation/opentelemetry-instrumentation-asyncio/tests/test_asyncio_to_thread.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import asyncio +import functools import sys from unittest import skipIf from unittest.mock import patch @@ -72,3 +73,37 @@ async def to_thread(): for point in metric.data.data_points: self.assertEqual(point.attributes["type"], "to_thread") self.assertEqual(point.attributes["name"], "multiply") + + @skipIf( + sys.version_info < (3, 9), "to_thread is only available in Python 3.9+" + ) + def test_to_thread_partial_func(self): + def multiply(x, y): + return x * y + + double = functools.partial(multiply, 2) + + async def to_thread(): + result = await asyncio.to_thread(double, 3) + assert result == 6 + + with self._tracer.start_as_current_span("root"): + asyncio.run(to_thread()) + spans = self.memory_exporter.get_finished_spans() + + self.assertEqual(len(spans), 2) + assert spans[0].name == "asyncio to_thread-multiply" + for metric in ( + self.memory_metrics_reader.get_metrics_data() + .resource_metrics[0] + .scope_metrics[0] + .metrics + ): + if metric.name == "asyncio.process.duration": + for point in metric.data.data_points: + self.assertEqual(point.attributes["type"], "to_thread") + self.assertEqual(point.attributes["name"], "multiply") + if metric.name == "asyncio.process.created": + for point in metric.data.data_points: + self.assertEqual(point.attributes["type"], "to_thread") + self.assertEqual(point.attributes["name"], "multiply")