diff --git a/CHANGELOG.md b/CHANGELOG.md index ed4671d559..7597e60641 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,6 +41,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#2871](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2871)) - `opentelemetry-instrumentation` Don't fail distro loading if instrumentor raises ImportError, instead skip them ([#2923](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2923)) +- `opentelemetry-instrumentation-httpx` Rewrote instrumentation to use wrapt instead of subclassing + ([#2909](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2909)) ## Version 1.27.0/0.48b0 (2024-08-28) diff --git a/instrumentation/opentelemetry-instrumentation-httpx/pyproject.toml b/instrumentation/opentelemetry-instrumentation-httpx/pyproject.toml index 599091716b..c986fac4a1 100644 --- a/instrumentation/opentelemetry-instrumentation-httpx/pyproject.toml +++ b/instrumentation/opentelemetry-instrumentation-httpx/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "opentelemetry-instrumentation == 0.49b0.dev", "opentelemetry-semantic-conventions == 0.49b0.dev", "opentelemetry-util-http == 0.49b0.dev", + "wrapt >= 1.0.0, < 2.0.0", ] [project.optional-dependencies] diff --git a/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py b/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py index b9b9a31d3e..d3a2cecfe6 100644 --- a/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# pylint: disable=too-many-lines """ Usage ----- @@ -194,9 +195,11 @@ async def async_response_hook(span, request, response): import logging import typing from asyncio import iscoroutinefunction +from functools import partial from types import TracebackType import httpx +from wrapt import wrap_function_wrapper from opentelemetry.instrumentation._semconv import ( _get_schema_url, @@ -217,6 +220,7 @@ async def async_response_hook(span, request, response): from opentelemetry.instrumentation.utils import ( http_status_to_status_code, is_http_instrumentation_enabled, + unwrap, ) from opentelemetry.propagate import inject from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE @@ -731,44 +735,211 @@ def _instrument(self, **kwargs): ``async_request_hook``: Async ``request_hook`` for ``httpx.AsyncClient`` ``async_response_hook``: Async``response_hook`` for ``httpx.AsyncClient`` """ - self._original_client = httpx.Client - self._original_async_client = httpx.AsyncClient + tracer_provider = kwargs.get("tracer_provider") request_hook = kwargs.get("request_hook") response_hook = kwargs.get("response_hook") async_request_hook = kwargs.get("async_request_hook") - async_response_hook = kwargs.get("async_response_hook") - if callable(request_hook): - _InstrumentedClient._request_hook = request_hook - if callable(async_request_hook) and iscoroutinefunction( + async_request_hook = ( async_request_hook - ): - _InstrumentedAsyncClient._request_hook = async_request_hook - if callable(response_hook): - _InstrumentedClient._response_hook = response_hook - if callable(async_response_hook) and iscoroutinefunction( + if iscoroutinefunction(async_request_hook) + else None + ) + async_response_hook = kwargs.get("async_response_hook") + async_response_hook = ( async_response_hook - ): - _InstrumentedAsyncClient._response_hook = async_response_hook - tracer_provider = kwargs.get("tracer_provider") - _InstrumentedClient._tracer_provider = tracer_provider - _InstrumentedAsyncClient._tracer_provider = tracer_provider - # Intentionally using a private attribute here, see: - # https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2538#discussion_r1610603719 - httpx.Client = httpx._api.Client = _InstrumentedClient - httpx.AsyncClient = _InstrumentedAsyncClient + if iscoroutinefunction(async_response_hook) + else None + ) + + _OpenTelemetrySemanticConventionStability._initialize() + sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode( + _OpenTelemetryStabilitySignalType.HTTP, + ) + tracer = get_tracer( + __name__, + instrumenting_library_version=__version__, + tracer_provider=tracer_provider, + schema_url=_get_schema_url(sem_conv_opt_in_mode), + ) + + wrap_function_wrapper( + "httpx", + "HTTPTransport.handle_request", + partial( + self._handle_request_wrapper, + tracer=tracer, + sem_conv_opt_in_mode=sem_conv_opt_in_mode, + request_hook=request_hook, + response_hook=response_hook, + ), + ) + wrap_function_wrapper( + "httpx", + "AsyncHTTPTransport.handle_async_request", + partial( + self._handle_async_request_wrapper, + tracer=tracer, + sem_conv_opt_in_mode=sem_conv_opt_in_mode, + async_request_hook=async_request_hook, + async_response_hook=async_response_hook, + ), + ) def _uninstrument(self, **kwargs): - httpx.Client = httpx._api.Client = self._original_client - httpx.AsyncClient = self._original_async_client - _InstrumentedClient._tracer_provider = None - _InstrumentedClient._request_hook = None - _InstrumentedClient._response_hook = None - _InstrumentedAsyncClient._tracer_provider = None - _InstrumentedAsyncClient._request_hook = None - _InstrumentedAsyncClient._response_hook = None + unwrap(httpx.HTTPTransport, "handle_request") + unwrap(httpx.AsyncHTTPTransport, "handle_async_request") @staticmethod + def _handle_request_wrapper( # pylint: disable=too-many-locals + wrapped, + instance, + args, + kwargs, + tracer, + sem_conv_opt_in_mode, + request_hook, + response_hook, + ): + if not is_http_instrumentation_enabled(): + return wrapped(*args, **kwargs) + + method, url, headers, stream, extensions = _extract_parameters( + args, kwargs + ) + method_original = method.decode() + span_name = _get_default_span_name(method_original) + span_attributes = {} + # apply http client response attributes according to semconv + _apply_request_client_attributes_to_span( + span_attributes, + url, + method_original, + sem_conv_opt_in_mode, + ) + + request_info = RequestInfo(method, url, headers, stream, extensions) + + with tracer.start_as_current_span( + span_name, kind=SpanKind.CLIENT, attributes=span_attributes + ) as span: + exception = None + if callable(request_hook): + request_hook(span, request_info) + + _inject_propagation_headers(headers, args, kwargs) + + try: + response = wrapped(*args, **kwargs) + except Exception as exc: # pylint: disable=W0703 + exception = exc + response = getattr(exc, "response", None) + + if isinstance(response, (httpx.Response, tuple)): + status_code, headers, stream, extensions, http_version = ( + _extract_response(response) + ) + + if span.is_recording(): + # apply http client response attributes according to semconv + _apply_response_client_attributes_to_span( + span, + status_code, + http_version, + sem_conv_opt_in_mode, + ) + if callable(response_hook): + response_hook( + span, + request_info, + ResponseInfo(status_code, headers, stream, extensions), + ) + + if exception: + if span.is_recording() and _report_new(sem_conv_opt_in_mode): + span.set_attribute( + ERROR_TYPE, type(exception).__qualname__ + ) + raise exception.with_traceback(exception.__traceback__) + + return response + + @staticmethod + async def _handle_async_request_wrapper( # pylint: disable=too-many-locals + wrapped, + instance, + args, + kwargs, + tracer, + sem_conv_opt_in_mode, + async_request_hook, + async_response_hook, + ): + if not is_http_instrumentation_enabled(): + return await wrapped(*args, **kwargs) + + method, url, headers, stream, extensions = _extract_parameters( + args, kwargs + ) + method_original = method.decode() + span_name = _get_default_span_name(method_original) + span_attributes = {} + # apply http client response attributes according to semconv + _apply_request_client_attributes_to_span( + span_attributes, + url, + method_original, + sem_conv_opt_in_mode, + ) + + request_info = RequestInfo(method, url, headers, stream, extensions) + + with tracer.start_as_current_span( + span_name, kind=SpanKind.CLIENT, attributes=span_attributes + ) as span: + exception = None + if callable(async_request_hook): + await async_request_hook(span, request_info) + + _inject_propagation_headers(headers, args, kwargs) + + try: + response = await wrapped(*args, **kwargs) + except Exception as exc: # pylint: disable=W0703 + exception = exc + response = getattr(exc, "response", None) + + if isinstance(response, (httpx.Response, tuple)): + status_code, headers, stream, extensions, http_version = ( + _extract_response(response) + ) + + if span.is_recording(): + # apply http client response attributes according to semconv + _apply_response_client_attributes_to_span( + span, + status_code, + http_version, + sem_conv_opt_in_mode, + ) + + if callable(async_response_hook): + await async_response_hook( + span, + request_info, + ResponseInfo(status_code, headers, stream, extensions), + ) + + if exception: + if span.is_recording() and _report_new(sem_conv_opt_in_mode): + span.set_attribute( + ERROR_TYPE, type(exception).__qualname__ + ) + raise exception.with_traceback(exception.__traceback__) + + return response + def instrument_client( + self, client: typing.Union[httpx.Client, httpx.AsyncClient], tracer_provider: TracerProvider = None, request_hook: typing.Union[ @@ -788,67 +959,88 @@ def instrument_client( response_hook: A hook that receives the span, request, and response that is called right before the span ends """ - # pylint: disable=protected-access - if not hasattr(client, "_is_instrumented_by_opentelemetry"): - client._is_instrumented_by_opentelemetry = False - if not client._is_instrumented_by_opentelemetry: - if isinstance(client, httpx.Client): - client._original_transport = client._transport - client._original_mounts = client._mounts.copy() - transport = client._transport or httpx.HTTPTransport() - client._transport = SyncOpenTelemetryTransport( - transport, - tracer_provider=tracer_provider, - request_hook=request_hook, - response_hook=response_hook, - ) - client._is_instrumented_by_opentelemetry = True - client._mounts.update( - { - url_pattern: ( - SyncOpenTelemetryTransport( - transport, - tracer_provider=tracer_provider, - request_hook=request_hook, - response_hook=response_hook, - ) - if transport is not None - else transport - ) - for url_pattern, transport in client._original_mounts.items() - } - ) + if getattr(client, "_is_instrumented_by_opentelemetry", False): + _logger.warning( + "Attempting to instrument Httpx client while already instrumented" + ) + return - if isinstance(client, httpx.AsyncClient): - transport = client._transport or httpx.AsyncHTTPTransport() - client._original_mounts = client._mounts.copy() - client._transport = AsyncOpenTelemetryTransport( - transport, - tracer_provider=tracer_provider, + _OpenTelemetrySemanticConventionStability._initialize() + sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode( + _OpenTelemetryStabilitySignalType.HTTP, + ) + tracer = get_tracer( + __name__, + instrumenting_library_version=__version__, + tracer_provider=tracer_provider, + schema_url=_get_schema_url(sem_conv_opt_in_mode), + ) + + if iscoroutinefunction(request_hook): + async_request_hook = request_hook + request_hook = None + else: + # request_hook already set + async_request_hook = None + + if iscoroutinefunction(response_hook): + async_response_hook = response_hook + response_hook = None + else: + # response_hook already set + async_response_hook = None + + if hasattr(client._transport, "handle_request"): + wrap_function_wrapper( + client._transport, + "handle_request", + partial( + self._handle_request_wrapper, + tracer=tracer, + sem_conv_opt_in_mode=sem_conv_opt_in_mode, request_hook=request_hook, response_hook=response_hook, + ), + ) + for transport in client._mounts.values(): + wrap_function_wrapper( + transport, + "handle_request", + partial( + self._handle_request_wrapper, + tracer=tracer, + sem_conv_opt_in_mode=sem_conv_opt_in_mode, + request_hook=request_hook, + response_hook=response_hook, + ), ) - client._is_instrumented_by_opentelemetry = True - client._mounts.update( - { - url_pattern: ( - AsyncOpenTelemetryTransport( - transport, - tracer_provider=tracer_provider, - request_hook=request_hook, - response_hook=response_hook, - ) - if transport is not None - else transport - ) - for url_pattern, transport in client._original_mounts.items() - } - ) - else: - _logger.warning( - "Attempting to instrument Httpx client while already instrumented" + client._is_instrumented_by_opentelemetry = True + if hasattr(client._transport, "handle_async_request"): + wrap_function_wrapper( + client._transport, + "handle_async_request", + partial( + self._handle_async_request_wrapper, + tracer=tracer, + sem_conv_opt_in_mode=sem_conv_opt_in_mode, + async_request_hook=async_request_hook, + async_response_hook=async_response_hook, + ), ) + for transport in client._mounts.values(): + wrap_function_wrapper( + transport, + "handle_async_request", + partial( + self._handle_async_request_wrapper, + tracer=tracer, + sem_conv_opt_in_mode=sem_conv_opt_in_mode, + async_request_hook=async_request_hook, + async_response_hook=async_response_hook, + ), + ) + client._is_instrumented_by_opentelemetry = True @staticmethod def uninstrument_client( @@ -859,15 +1051,13 @@ def uninstrument_client( Args: client: The httpx Client or AsyncClient instance """ - if hasattr(client, "_original_transport"): - client._transport = client._original_transport - del client._original_transport + if hasattr(client._transport, "handle_request"): + unwrap(client._transport, "handle_request") + for transport in client._mounts.values(): + unwrap(transport, "handle_request") + client._is_instrumented_by_opentelemetry = False + elif hasattr(client._transport, "handle_async_request"): + unwrap(client._transport, "handle_async_request") + for transport in client._mounts.values(): + unwrap(transport, "handle_async_request") client._is_instrumented_by_opentelemetry = False - if hasattr(client, "_original_mounts"): - client._mounts = client._original_mounts.copy() - del client._original_mounts - else: - _logger.warning( - "Attempting to uninstrument Httpx " - "client while already uninstrumented" - ) diff --git a/instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py b/instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py index 0d055515e0..07699700c4 100644 --- a/instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py +++ b/instrumentation/opentelemetry-instrumentation-httpx/tests/test_httpx_integration.py @@ -21,6 +21,7 @@ import httpx import respx +from wrapt import ObjectProxy import opentelemetry.instrumentation.httpx from opentelemetry import trace @@ -171,6 +172,7 @@ def tearDown(self): super().tearDown() self.env_patch.stop() respx.stop() + HTTPXClientInstrumentor().uninstrument() def assert_span( self, exporter: "SpanExporter" = None, num_spans: int = 1 @@ -204,7 +206,7 @@ def test_basic(self): self.assertEqual(span.name, "GET") self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_URL: self.URL, @@ -228,7 +230,7 @@ def test_nonstandard_http_method(self): self.assertIs(span.kind, trace.SpanKind.CLIENT) self.assertEqual(span.name, "HTTP") self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "_OTHER", SpanAttributes.HTTP_URL: self.URL, @@ -252,7 +254,7 @@ def test_nonstandard_http_method_new_semconv(self): self.assertIs(span.kind, trace.SpanKind.CLIENT) self.assertEqual(span.name, "HTTP") self.assertEqual( - span.attributes, + dict(span.attributes), { HTTP_REQUEST_METHOD: "_OTHER", URL_FULL: self.URL, @@ -292,7 +294,7 @@ def test_basic_new_semconv(self): SpanAttributes.SCHEMA_URL, ) self.assertEqual( - span.attributes, + dict(span.attributes), { HTTP_REQUEST_METHOD: "GET", URL_FULL: url, @@ -327,7 +329,7 @@ def test_basic_both_semconv(self): ) self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", HTTP_REQUEST_METHOD: "GET", @@ -454,7 +456,7 @@ def test_requests_500_error(self): span = self.assert_span() self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_URL: self.URL, @@ -510,7 +512,7 @@ def test_requests_timeout_exception_new_semconv(self): span = self.assert_span() self.assertEqual( - span.attributes, + dict(span.attributes), { HTTP_REQUEST_METHOD: "GET", URL_FULL: url, @@ -531,7 +533,7 @@ def test_requests_timeout_exception_both_semconv(self): span = self.assert_span() self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", HTTP_REQUEST_METHOD: "GET", @@ -632,7 +634,7 @@ def test_response_hook(self): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_URL: self.URL, @@ -741,8 +743,10 @@ def create_proxy_transport(self, url: str): def setUp(self): super().setUp() - HTTPXClientInstrumentor().instrument() self.client = self.create_client() + HTTPXClientInstrumentor().instrument_client(self.client) + + def tearDown(self): HTTPXClientInstrumentor().uninstrument() def create_proxy_mounts(self): @@ -755,14 +759,25 @@ def create_proxy_mounts(self): ), } - def assert_proxy_mounts(self, mounts, num_mounts, transport_type): + def assert_proxy_mounts(self, mounts, num_mounts, transport_type=None): self.assertEqual(len(mounts), num_mounts) for transport in mounts: with self.subTest(transport): - self.assertIsInstance( - transport, - transport_type, - ) + if transport_type: + self.assertIsInstance( + transport, + transport_type, + ) + else: + handler = getattr(transport, "handle_request", None) + if not handler: + handler = getattr( + transport, "handle_async_request" + ) + self.assertTrue( + isinstance(handler, ObjectProxy) + and getattr(handler, "__wrapped__") + ) def test_custom_tracer_provider(self): resource = resources.Resource.create({}) @@ -778,7 +793,6 @@ def test_custom_tracer_provider(self): self.assertEqual(result.text, "Hello!") span = self.assert_span(exporter=exporter) self.assertIs(span.resource, resource) - HTTPXClientInstrumentor().uninstrument() def test_response_hook(self): response_hook_key = ( @@ -797,7 +811,7 @@ def test_response_hook(self): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_URL: self.URL, @@ -805,7 +819,6 @@ def test_response_hook(self): HTTP_RESPONSE_BODY: "Hello!", }, ) - HTTPXClientInstrumentor().uninstrument() def test_response_hook_sync_async_kwargs(self): HTTPXClientInstrumentor().instrument( @@ -819,7 +832,7 @@ def test_response_hook_sync_async_kwargs(self): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual( - span.attributes, + dict(span.attributes), { SpanAttributes.HTTP_METHOD: "GET", SpanAttributes.HTTP_URL: self.URL, @@ -827,7 +840,6 @@ def test_response_hook_sync_async_kwargs(self): HTTP_RESPONSE_BODY: "Hello!", }, ) - HTTPXClientInstrumentor().uninstrument() def test_request_hook(self): request_hook_key = ( @@ -846,7 +858,6 @@ def test_request_hook(self): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual(span.name, "GET" + self.URL) - HTTPXClientInstrumentor().uninstrument() def test_request_hook_sync_async_kwargs(self): HTTPXClientInstrumentor().instrument( @@ -860,7 +871,6 @@ def test_request_hook_sync_async_kwargs(self): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual(span.name, "GET" + self.URL) - HTTPXClientInstrumentor().uninstrument() def test_request_hook_no_span_update(self): HTTPXClientInstrumentor().instrument( @@ -873,7 +883,6 @@ def test_request_hook_no_span_update(self): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual(span.name, "GET") - HTTPXClientInstrumentor().uninstrument() def test_not_recording(self): with mock.patch("opentelemetry.trace.INVALID_SPAN") as mock_span: @@ -891,7 +900,6 @@ def test_not_recording(self): self.assertTrue(mock_span.is_recording.called) self.assertFalse(mock_span.set_attribute.called) self.assertFalse(mock_span.set_status.called) - HTTPXClientInstrumentor().uninstrument() def test_suppress_instrumentation_new_client(self): HTTPXClientInstrumentor().instrument() @@ -901,7 +909,6 @@ def test_suppress_instrumentation_new_client(self): self.assertEqual(result.text, "Hello!") self.assert_span(num_spans=0) - HTTPXClientInstrumentor().uninstrument() def test_instrument_client(self): client = self.create_client() @@ -929,8 +936,6 @@ def test_instrumentation_without_client(self): self.URL, ) - HTTPXClientInstrumentor().uninstrument() - def test_uninstrument(self): HTTPXClientInstrumentor().instrument() HTTPXClientInstrumentor().uninstrument() @@ -980,9 +985,7 @@ def test_instrument_proxy(self): self.assert_proxy_mounts( client._mounts.values(), 2, - (SyncOpenTelemetryTransport, AsyncOpenTelemetryTransport), ) - HTTPXClientInstrumentor().uninstrument() def test_instrument_client_with_proxy(self): proxy_mounts = self.create_proxy_mounts() @@ -999,7 +1002,6 @@ def test_instrument_client_with_proxy(self): self.assert_proxy_mounts( client._mounts.values(), 2, - (SyncOpenTelemetryTransport, AsyncOpenTelemetryTransport), ) HTTPXClientInstrumentor().uninstrument_client(client) @@ -1010,7 +1012,6 @@ def test_uninstrument_client_with_proxy(self): self.assert_proxy_mounts( client._mounts.values(), 2, - (SyncOpenTelemetryTransport, AsyncOpenTelemetryTransport), ) HTTPXClientInstrumentor().uninstrument_client(client) @@ -1180,6 +1181,21 @@ def perform_request( def create_proxy_transport(self, url): return httpx.HTTPTransport(proxy=httpx.Proxy(url)) + def test_can_instrument_subclassed_client(self): + class CustomClient(httpx.Client): + pass + + client = CustomClient() + self.assertFalse( + isinstance(client._transport.handle_request, ObjectProxy) + ) + + HTTPXClientInstrumentor().instrument() + + self.assertTrue( + isinstance(client._transport.handle_request, ObjectProxy) + ) + class TestAsyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest): response_hook = staticmethod(_async_response_hook) @@ -1188,10 +1204,8 @@ class TestAsyncInstrumentationIntegration(BaseTestCases.BaseInstrumentorTest): def setUp(self): super().setUp() - HTTPXClientInstrumentor().instrument() - self.client = self.create_client() self.client2 = self.create_client() - HTTPXClientInstrumentor().uninstrument() + HTTPXClientInstrumentor().instrument_client(self.client2) def create_client( self, @@ -1245,7 +1259,6 @@ def test_async_response_hook_does_nothing_if_not_coroutine(self): SpanAttributes.HTTP_STATUS_CODE: 200, }, ) - HTTPXClientInstrumentor().uninstrument() def test_async_request_hook_does_nothing_if_not_coroutine(self): HTTPXClientInstrumentor().instrument( @@ -1258,4 +1271,18 @@ def test_async_request_hook_does_nothing_if_not_coroutine(self): self.assertEqual(result.text, "Hello!") span = self.assert_span() self.assertEqual(span.name, "GET") - HTTPXClientInstrumentor().uninstrument() + + def test_can_instrument_subclassed_async_client(self): + class CustomAsyncClient(httpx.AsyncClient): + pass + + client = CustomAsyncClient() + self.assertFalse( + isinstance(client._transport.handle_async_request, ObjectProxy) + ) + + HTTPXClientInstrumentor().instrument() + + self.assertTrue( + isinstance(client._transport.handle_async_request, ObjectProxy) + )