diff --git a/sdk/core/azure-core/tests/test_tracing_decorator.py b/sdk/core/azure-core/tests/test_tracing_decorator.py index a1e95c63b588..385bda8e4f92 100644 --- a/sdk/core/azure-core/tests/test_tracing_decorator.py +++ b/sdk/core/azure-core/tests/test_tracing_decorator.py @@ -4,37 +4,27 @@ # ------------------------------------ """The tests for decorators.py and common.py""" -import unittest - try: from unittest import mock except ImportError: import mock import sys -import os +import time + +import pytest from azure.core import HttpRequest from azure.core.pipeline import Pipeline, PipelineResponse from azure.core.pipeline.policies import HTTPPolicy from azure.core.pipeline.transport import HttpTransport +from azure.core.settings import settings from azure.core.tracing import common from azure.core.tracing.context import tracing_context from azure.core.tracing.decorator import distributed_trace -from azure.core.settings import settings from azure.core.tracing.ext.opencensus_span import OpenCensusSpan from opencensus.trace import tracer as tracer_module from opencensus.trace.samplers import AlwaysOnSampler from tracing_common import ContextHelper, MockExporter -import time -import pytest - -try: - from typing import TYPE_CHECKING -except ImportError: - TYPE_CHECKING = False - -if TYPE_CHECKING: - from typing import List class MockClient: @@ -73,6 +63,7 @@ def get_foo(self): time.sleep(0.001) return 5 + class TestCommon(object): def test_set_span_context(self): with ContextHelper(environ={"AZURE_SDK_TRACING_IMPLEMENTATION": "opencensus"}): @@ -113,11 +104,11 @@ def test_get_parent_span(self): def test_should_use_trace(self): with ContextHelper(environ={"AZURE_TRACING_ONLY_PROPAGATE": "yes"}): parent_span = OpenCensusSpan() - assert common.should_use_trace(parent_span) == False - assert common.should_use_trace(None) == False + assert not common.should_use_trace(parent_span) + assert not common.should_use_trace(None) parent_span = OpenCensusSpan() assert common.should_use_trace(parent_span) - assert common.should_use_trace(None) == False + assert not common.should_use_trace(None) class TestDecorator(object): diff --git a/sdk/core/azure-core/tests/test_tracing_implementations.py b/sdk/core/azure-core/tests/test_tracing_implementations.py index 23df8cff33a6..07f9348004da 100644 --- a/sdk/core/azure-core/tests/test_tracing_implementations.py +++ b/sdk/core/azure-core/tests/test_tracing_implementations.py @@ -16,7 +16,6 @@ from opencensus.trace.span import SpanKind from opencensus.trace.samplers import AlwaysOnSampler from opencensus.trace.base_exporter import Exporter -from opencensus.common.utils import timestamp_to_microseconds from tracing_common import MockExporter, ContextHelper import os @@ -79,6 +78,7 @@ def test_start_finish(self): assert wrapped_class.span_instance.start_time is not None assert wrapped_class.span_instance.end_time is not None parent.finish() + tracer.finish() def test_to_header(self): with ContextHelper() as ctx: diff --git a/sdk/core/azure-core/tests/tracing_common.py b/sdk/core/azure-core/tests/tracing_common.py index a1dfabdca8e8..18f9b4e9f839 100644 --- a/sdk/core/azure-core/tests/tracing_common.py +++ b/sdk/core/azure-core/tests/tracing_common.py @@ -4,20 +4,14 @@ # ------------------------------------ """Code shared between the async and the sync test_decorator files.""" -import sys import os -from azure.core import HttpRequest -from azure.core.pipeline import Pipeline, PipelineResponse -from azure.core.pipeline.policies import HTTPPolicy -from azure.core.pipeline.transport import HttpTransport -from azure.core.tracing import common -from azure.core.tracing.context import tracing_context + from azure.core.settings import settings +from azure.core.tracing.context import tracing_context from azure.core.tracing.ext.opencensus_span import OpenCensusSpan -from opencensus.trace import tracer as tracer_module -from opencensus.trace.span_data import SpanData -from opencensus.trace.samplers import AlwaysOnSampler +from opencensus.trace import execution_context from opencensus.trace.base_exporter import Exporter +from opencensus.trace.span_data import SpanData from collections import defaultdict from opencensus.trace import execution_context @@ -47,6 +41,8 @@ def __enter__(self): if self.should_only_propagate is not None: settings.tracing_should_only_propagate.set_value(self.should_only_propagate) self.os_env.start() + execution_context.clear() + tracing_context.current_span.clear() return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -67,7 +63,7 @@ def __init__(self, span_data): class MockExporter(Exporter): def __init__(self): - self.root = None + self.root = None # type: SpanData self._all_nodes = [] self.parent_dict = defaultdict(list) diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_client.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_client.py index 90aeb33200d9..ee0e81d07f73 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_client.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_client.py @@ -13,6 +13,7 @@ from typing import Any, Dict, Generator, Mapping, Optional from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError +from azure.core.tracing.decorator import distributed_trace from ._shared import KeyVaultClientBase from ._models import Key, KeyBase, DeletedKey, KeyOperationResult @@ -32,6 +33,7 @@ class KeyClient(KeyVaultClientBase): # pylint:disable=protected-access + @distributed_trace def create_key( self, name, @@ -107,6 +109,7 @@ def create_key( ) return Key._from_key_bundle(bundle) + @distributed_trace def create_rsa_key( self, name, @@ -171,6 +174,7 @@ def create_rsa_key( **kwargs ) + @distributed_trace def create_ec_key( self, name, @@ -238,6 +242,7 @@ def create_ec_key( **kwargs ) + @distributed_trace def delete_key(self, name, **kwargs): # type: (str, Mapping[str, Any]) -> DeletedKey """Deletes a key from the Key Vault. @@ -265,6 +270,7 @@ def delete_key(self, name, **kwargs): bundle = self._client.delete_key(self.vault_url, name, error_map={404: ResourceNotFoundError}, **kwargs) return DeletedKey._from_deleted_key_bundle(bundle) + @distributed_trace def get_key(self, name, version=None, **kwargs): # type: (str, Optional[str], Mapping[str, Any]) -> Key """Gets the public part of a stored key. @@ -295,6 +301,7 @@ def get_key(self, name, version=None, **kwargs): ) return Key._from_key_bundle(bundle) + @distributed_trace def get_deleted_key(self, name, **kwargs): # type: (str, Mapping[str, Any]) -> DeletedKey """Gets a deleted key from the Key Vault @@ -320,6 +327,7 @@ def get_deleted_key(self, name, **kwargs): bundle = self._client.get_deleted_key(self.vault_url, name, error_map={404: ResourceNotFoundError}, **kwargs) return DeletedKey._from_deleted_key_bundle(bundle) + @distributed_trace def list_deleted_keys(self, **kwargs): # type: (Mapping[str, Any]) -> Generator[DeletedKey] """Lists the deleted keys in the Key Vault @@ -348,6 +356,7 @@ def list_deleted_keys(self, **kwargs): pages = self._client.get_deleted_keys(self._vault_url, maxresults=max_page_size, **kwargs) return (DeletedKey._from_deleted_key_item(item) for item in pages) + @distributed_trace def list_keys(self, **kwargs): # type: (Mapping[str, Any]) -> Generator[KeyBase] """List the keys in the Key Vault @@ -375,6 +384,7 @@ def list_keys(self, **kwargs): pages = self._client.get_keys(self._vault_url, maxresults=max_page_size, **kwargs) return (KeyBase._from_key_item(item) for item in pages) + @distributed_trace def list_key_versions(self, name, **kwargs): # type: (str, Mapping[str, Any]) -> Generator[KeyBase] """Retrieves a list of individual key versions with the same key name. @@ -400,6 +410,7 @@ def list_key_versions(self, name, **kwargs): pages = self._client.get_key_versions(self._vault_url, name, maxresults=max_page_size, **kwargs) return (KeyBase._from_key_item(item) for item in pages) + @distributed_trace def purge_deleted_key(self, name, **kwargs): # type: (str, Mapping[str, Any]) -> None """Permanently deletes the specified key. @@ -424,6 +435,7 @@ def purge_deleted_key(self, name, **kwargs): """ self._client.purge_deleted_key(self.vault_url, name, kwargs) + @distributed_trace def recover_deleted_key(self, name, **kwargs): # type: (str, Mapping[str, Any]) -> Key """Recovers the deleted key to its latest version. @@ -451,6 +463,7 @@ def recover_deleted_key(self, name, **kwargs): bundle = self._client.recover_deleted_key(self.vault_url, name, kwargs) return Key._from_key_bundle(bundle) + @distributed_trace def update_key( self, name, version=None, key_operations=None, enabled=None, expires=None, not_before=None, tags=None, **kwargs ): @@ -508,6 +521,7 @@ def update_key( ) return Key._from_key_bundle(bundle) + @distributed_trace def backup_key(self, name, **kwargs): # type: (str, Mapping[str, Any]) -> bytes """Backs up the specified key. @@ -545,6 +559,7 @@ def backup_key(self, name, **kwargs): backup_result = self._client.backup_key(self.vault_url, name, error_map={404: ResourceNotFoundError}, **kwargs) return backup_result.value + @distributed_trace def restore_key(self, backup, **kwargs): # type: (bytes, Mapping[str, Any]) -> Key """Restores a backed up key to the Key Vault @@ -581,6 +596,7 @@ def restore_key(self, backup, **kwargs): bundle = self._client.restore_key(self.vault_url, backup, error_map={409: ResourceExistsError}, **kwargs) return Key._from_key_bundle(bundle) + @distributed_trace def import_key(self, name, key, hsm=None, enabled=None, not_before=None, expires=None, tags=None, **kwargs): # type: (str, List[str], Optional[bool], Optional[bool], Optional[datetime], Optional[datetime], Optional[Dict[str, str]], Mapping[str, Any]) -> Key """Imports an externally created key, stores it, and returns the key to the client. @@ -618,6 +634,7 @@ def import_key(self, name, key, hsm=None, enabled=None, not_before=None, expires ) return Key._from_key_bundle(bundle) + @distributed_trace def wrap_key(self, name, algorithm, value, version=None, **kwargs): # type: (str, str, Optional[str], bytes, Mapping[str, Any]) -> KeyOperationResult """Wraps a symmetric key using a specified key. @@ -653,6 +670,7 @@ def wrap_key(self, name, algorithm, value, version=None, **kwargs): ) return KeyOperationResult(id=bundle.kid, value=bundle.result) + @distributed_trace def unwrap_key(self, name, algorithm, value, version=None, **kwargs): # type: (str, str, Optional[str], bytes, Mapping[str, Any]) -> KeyOperationResult """Unwraps a symmetric key using the specified key that was initially used diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_client_base.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_client_base.py index da245f70bd07..b60a51f5244c 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_client_base.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_client_base.py @@ -6,6 +6,7 @@ from azure.core.async_paging import AsyncPagedMixin from azure.core.configuration import Configuration from azure.core.pipeline import AsyncPipeline +from azure.core.pipeline.policies.distributed_tracing import DistributedTracingPolicy from azure.core.pipeline.transport import AsyncioRequestsTransport, HttpTransport from msrest.serialization import Model @@ -101,6 +102,7 @@ def _build_pipeline(config: Configuration, transport: HttpTransport, **kwargs: A config.retry_policy, config.authentication_policy, config.logging_policy, + DistributedTracingPolicy() ] if transport is None: diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/client_base.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/client_base.py index a5ac40189656..bb249c861e28 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/client_base.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/client_base.py @@ -6,6 +6,7 @@ from azure.core import Configuration from azure.core.pipeline import Pipeline from azure.core.pipeline.transport import RequestsTransport +from azure.core.pipeline.policies.distributed_tracing import DistributedTracingPolicy from ._generated import KeyVaultClient if TYPE_CHECKING: @@ -73,6 +74,7 @@ def _build_pipeline(self, config, transport, **kwargs): config.retry_policy, config.authentication_policy, config.logging_policy, + DistributedTracingPolicy(), ] if transport is None: diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/aio/_client.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/aio/_client.py index 9f81bbdc5a7b..51f916d2938f 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/aio/_client.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/aio/_client.py @@ -6,7 +6,8 @@ from typing import Any, AsyncIterable, Mapping, Optional, Dict, List from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError - +from azure.core.tracing.decorator import distributed_trace +from azure.core.tracing.decorator_async import distributed_trace_async from azure.keyvault.keys._models import Key, DeletedKey, KeyBase, KeyOperationResult from azure.keyvault.keys._shared import AsyncKeyVaultClientBase, AsyncPagingAdapter @@ -25,6 +26,7 @@ class KeyClient(AsyncKeyVaultClientBase): # pylint:disable=protected-access + @distributed_trace_async async def get_key(self, name: str, version: Optional[str] = None, **kwargs: Mapping[str, Any]) -> Key: """Gets the public part of a stored key. @@ -57,6 +59,7 @@ async def get_key(self, name: str, version: Optional[str] = None, **kwargs: Mapp ) return Key._from_key_bundle(bundle) + @distributed_trace_async async def create_key( self, name: str, @@ -131,6 +134,7 @@ async def create_key( ) return Key._from_key_bundle(bundle) + @distributed_trace_async async def create_rsa_key( self, name: str, @@ -194,6 +198,7 @@ async def create_rsa_key( **kwargs, ) + @distributed_trace_async async def create_ec_key( self, name: str, @@ -259,6 +264,7 @@ async def create_ec_key( **kwargs, ) + @distributed_trace_async async def update_key( self, name: str, @@ -324,6 +330,7 @@ async def update_key( ) return Key._from_key_bundle(bundle) + @distributed_trace def list_keys(self, **kwargs: Mapping[str, Any]) -> AsyncIterable[KeyBase]: """List keys in the specified vault. @@ -351,6 +358,7 @@ def list_keys(self, **kwargs: Mapping[str, Any]) -> AsyncIterable[KeyBase]: iterable = AsyncPagingAdapter(pages, KeyBase._from_key_item) return iterable + @distributed_trace def list_key_versions(self, name: str, **kwargs: Mapping[str, Any]) -> AsyncIterable[KeyBase]: """Retrieves a list of individual key versions with the same key name. @@ -376,6 +384,7 @@ def list_key_versions(self, name: str, **kwargs: Mapping[str, Any]) -> AsyncIter iterable = AsyncPagingAdapter(pages, KeyBase._from_key_item) return iterable + @distributed_trace_async async def backup_key(self, name: str, **kwargs: Mapping[str, Any]) -> bytes: """Requests a backup of the specified key to the client. @@ -413,6 +422,7 @@ async def backup_key(self, name: str, **kwargs: Mapping[str, Any]) -> bytes: ) return backup_result.value + @distributed_trace_async async def restore_key(self, backup: bytes, **kwargs: Mapping[str, Any]) -> Key: """Restores a backed up key to a vault. @@ -448,6 +458,7 @@ async def restore_key(self, backup: bytes, **kwargs: Mapping[str, Any]) -> Key: bundle = await self._client.restore_key(self.vault_url, backup, error_map={409: ResourceExistsError}, **kwargs) return Key._from_key_bundle(bundle) + @distributed_trace_async async def delete_key(self, name: str, **kwargs: Mapping[str, Any]) -> DeletedKey: """Deletes a key from the Key Vault. @@ -474,6 +485,7 @@ async def delete_key(self, name: str, **kwargs: Mapping[str, Any]) -> DeletedKey bundle = await self._client.delete_key(self.vault_url, name, error_map={404: ResourceNotFoundError}, **kwargs) return DeletedKey._from_deleted_key_bundle(bundle) + @distributed_trace_async async def get_deleted_key(self, name: str, **kwargs: Mapping[str, Any]) -> DeletedKey: """Gets a deleted key from the Key Vault @@ -501,6 +513,7 @@ async def get_deleted_key(self, name: str, **kwargs: Mapping[str, Any]) -> Delet ) return DeletedKey._from_deleted_key_bundle(bundle) + @distributed_trace def list_deleted_keys(self, **kwargs: Mapping[str, Any]) -> AsyncIterable[DeletedKey]: """Lists the deleted keys in the specified vault. @@ -529,6 +542,7 @@ def list_deleted_keys(self, **kwargs: Mapping[str, Any]) -> AsyncIterable[Delete iterable = AsyncPagingAdapter(pages, DeletedKey._from_deleted_key_item) return iterable + @distributed_trace_async async def purge_deleted_key(self, name: str, **kwargs: Mapping[str, Any]) -> None: """Permanently deletes the specified key. @@ -552,6 +566,7 @@ async def purge_deleted_key(self, name: str, **kwargs: Mapping[str, Any]) -> Non """ await self._client.purge_deleted_key(self.vault_url, name, **kwargs) + @distributed_trace_async async def recover_deleted_key(self, name: str, **kwargs: Mapping[str, Any]) -> Key: """Recovers the deleted key to its latest version. @@ -578,6 +593,7 @@ async def recover_deleted_key(self, name: str, **kwargs: Mapping[str, Any]) -> K bundle = await self._client.recover_deleted_key(self.vault_url, name, **kwargs) return Key._from_key_bundle(bundle) + @distributed_trace_async async def import_key( self, name: str, @@ -624,6 +640,7 @@ async def import_key( ) return Key._from_key_bundle(bundle) + @distributed_trace_async async def wrap_key( self, name: str, algorithm: str, value: bytes, version: Optional[str] = None, **kwargs: Mapping[str, Any] ) -> KeyOperationResult: @@ -660,6 +677,7 @@ async def wrap_key( ) return KeyOperationResult(id=bundle.kid, value=bundle.result) + @distributed_trace_async async def unwrap_key( self, name: str, algorithm: str, value: bytes, version: Optional[str] = None, **kwargs: Mapping[str, Any] ) -> KeyOperationResult: