Skip to content

Commit

Permalink
added decorators to keyvault-keys (#6381)
Browse files Browse the repository at this point in the history
  • Loading branch information
SuyogSoti authored Jul 23, 2019
1 parent 0e0d17e commit 3ab986a
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 30 deletions.
25 changes: 8 additions & 17 deletions sdk/core/azure-core/tests/test_tracing_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,37 +4,27 @@
# ------------------------------------
"""The tests for decorators.py and common.py"""

import unittest

try:
from unittest import mock
except ImportError:
import mock

import sys
import os
import time

import pytest
from azure.core import HttpRequest
from azure.core.pipeline import Pipeline, PipelineResponse
from azure.core.pipeline.policies import HTTPPolicy
from azure.core.pipeline.transport import HttpTransport
from azure.core.settings import settings
from azure.core.tracing import common
from azure.core.tracing.context import tracing_context
from azure.core.tracing.decorator import distributed_trace
from azure.core.settings import settings
from azure.core.tracing.ext.opencensus_span import OpenCensusSpan
from opencensus.trace import tracer as tracer_module
from opencensus.trace.samplers import AlwaysOnSampler
from tracing_common import ContextHelper, MockExporter
import time
import pytest

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
from typing import List


class MockClient:
Expand Down Expand Up @@ -73,6 +63,7 @@ def get_foo(self):
time.sleep(0.001)
return 5


class TestCommon(object):
def test_set_span_context(self):
with ContextHelper(environ={"AZURE_SDK_TRACING_IMPLEMENTATION": "opencensus"}):
Expand Down Expand Up @@ -113,11 +104,11 @@ def test_get_parent_span(self):
def test_should_use_trace(self):
with ContextHelper(environ={"AZURE_TRACING_ONLY_PROPAGATE": "yes"}):
parent_span = OpenCensusSpan()
assert common.should_use_trace(parent_span) == False
assert common.should_use_trace(None) == False
assert not common.should_use_trace(parent_span)
assert not common.should_use_trace(None)
parent_span = OpenCensusSpan()
assert common.should_use_trace(parent_span)
assert common.should_use_trace(None) == False
assert not common.should_use_trace(None)


class TestDecorator(object):
Expand Down
2 changes: 1 addition & 1 deletion sdk/core/azure-core/tests/test_tracing_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from opencensus.trace.span import SpanKind
from opencensus.trace.samplers import AlwaysOnSampler
from opencensus.trace.base_exporter import Exporter
from opencensus.common.utils import timestamp_to_microseconds
from tracing_common import MockExporter, ContextHelper
import os

Expand Down Expand Up @@ -79,6 +78,7 @@ def test_start_finish(self):
assert wrapped_class.span_instance.start_time is not None
assert wrapped_class.span_instance.end_time is not None
parent.finish()
tracer.finish()

def test_to_header(self):
with ContextHelper() as ctx:
Expand Down
18 changes: 7 additions & 11 deletions sdk/core/azure-core/tests/tracing_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,14 @@
# ------------------------------------
"""Code shared between the async and the sync test_decorator files."""

import sys
import os
from azure.core import HttpRequest
from azure.core.pipeline import Pipeline, PipelineResponse
from azure.core.pipeline.policies import HTTPPolicy
from azure.core.pipeline.transport import HttpTransport
from azure.core.tracing import common
from azure.core.tracing.context import tracing_context

from azure.core.settings import settings
from azure.core.tracing.context import tracing_context
from azure.core.tracing.ext.opencensus_span import OpenCensusSpan
from opencensus.trace import tracer as tracer_module
from opencensus.trace.span_data import SpanData
from opencensus.trace.samplers import AlwaysOnSampler
from opencensus.trace import execution_context
from opencensus.trace.base_exporter import Exporter
from opencensus.trace.span_data import SpanData
from collections import defaultdict
from opencensus.trace import execution_context

Expand Down Expand Up @@ -47,6 +41,8 @@ def __enter__(self):
if self.should_only_propagate is not None:
settings.tracing_should_only_propagate.set_value(self.should_only_propagate)
self.os_env.start()
execution_context.clear()
tracing_context.current_span.clear()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
Expand All @@ -67,7 +63,7 @@ def __init__(self, span_data):

class MockExporter(Exporter):
def __init__(self):
self.root = None
self.root = None # type: SpanData
self._all_nodes = []
self.parent_dict = defaultdict(list)

Expand Down
18 changes: 18 additions & 0 deletions sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Any, Dict, Generator, Mapping, Optional

from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from azure.core.tracing.decorator import distributed_trace

from ._shared import KeyVaultClientBase
from ._models import Key, KeyBase, DeletedKey, KeyOperationResult
Expand All @@ -32,6 +33,7 @@ class KeyClient(KeyVaultClientBase):

# pylint:disable=protected-access

@distributed_trace
def create_key(
self,
name,
Expand Down Expand Up @@ -107,6 +109,7 @@ def create_key(
)
return Key._from_key_bundle(bundle)

@distributed_trace
def create_rsa_key(
self,
name,
Expand Down Expand Up @@ -171,6 +174,7 @@ def create_rsa_key(
**kwargs
)

@distributed_trace
def create_ec_key(
self,
name,
Expand Down Expand Up @@ -238,6 +242,7 @@ def create_ec_key(
**kwargs
)

@distributed_trace
def delete_key(self, name, **kwargs):
# type: (str, Mapping[str, Any]) -> DeletedKey
"""Deletes a key from the Key Vault.
Expand Down Expand Up @@ -265,6 +270,7 @@ def delete_key(self, name, **kwargs):
bundle = self._client.delete_key(self.vault_url, name, error_map={404: ResourceNotFoundError}, **kwargs)
return DeletedKey._from_deleted_key_bundle(bundle)

@distributed_trace
def get_key(self, name, version=None, **kwargs):
# type: (str, Optional[str], Mapping[str, Any]) -> Key
"""Gets the public part of a stored key.
Expand Down Expand Up @@ -295,6 +301,7 @@ def get_key(self, name, version=None, **kwargs):
)
return Key._from_key_bundle(bundle)

@distributed_trace
def get_deleted_key(self, name, **kwargs):
# type: (str, Mapping[str, Any]) -> DeletedKey
"""Gets a deleted key from the Key Vault
Expand All @@ -320,6 +327,7 @@ def get_deleted_key(self, name, **kwargs):
bundle = self._client.get_deleted_key(self.vault_url, name, error_map={404: ResourceNotFoundError}, **kwargs)
return DeletedKey._from_deleted_key_bundle(bundle)

@distributed_trace
def list_deleted_keys(self, **kwargs):
# type: (Mapping[str, Any]) -> Generator[DeletedKey]
"""Lists the deleted keys in the Key Vault
Expand Down Expand Up @@ -348,6 +356,7 @@ def list_deleted_keys(self, **kwargs):
pages = self._client.get_deleted_keys(self._vault_url, maxresults=max_page_size, **kwargs)
return (DeletedKey._from_deleted_key_item(item) for item in pages)

@distributed_trace
def list_keys(self, **kwargs):
# type: (Mapping[str, Any]) -> Generator[KeyBase]
"""List the keys in the Key Vault
Expand Down Expand Up @@ -375,6 +384,7 @@ def list_keys(self, **kwargs):
pages = self._client.get_keys(self._vault_url, maxresults=max_page_size, **kwargs)
return (KeyBase._from_key_item(item) for item in pages)

@distributed_trace
def list_key_versions(self, name, **kwargs):
# type: (str, Mapping[str, Any]) -> Generator[KeyBase]
"""Retrieves a list of individual key versions with the same key name.
Expand All @@ -400,6 +410,7 @@ def list_key_versions(self, name, **kwargs):
pages = self._client.get_key_versions(self._vault_url, name, maxresults=max_page_size, **kwargs)
return (KeyBase._from_key_item(item) for item in pages)

@distributed_trace
def purge_deleted_key(self, name, **kwargs):
# type: (str, Mapping[str, Any]) -> None
"""Permanently deletes the specified key.
Expand All @@ -424,6 +435,7 @@ def purge_deleted_key(self, name, **kwargs):
"""
self._client.purge_deleted_key(self.vault_url, name, kwargs)

@distributed_trace
def recover_deleted_key(self, name, **kwargs):
# type: (str, Mapping[str, Any]) -> Key
"""Recovers the deleted key to its latest version.
Expand Down Expand Up @@ -451,6 +463,7 @@ def recover_deleted_key(self, name, **kwargs):
bundle = self._client.recover_deleted_key(self.vault_url, name, kwargs)
return Key._from_key_bundle(bundle)

@distributed_trace
def update_key(
self, name, version=None, key_operations=None, enabled=None, expires=None, not_before=None, tags=None, **kwargs
):
Expand Down Expand Up @@ -508,6 +521,7 @@ def update_key(
)
return Key._from_key_bundle(bundle)

@distributed_trace
def backup_key(self, name, **kwargs):
# type: (str, Mapping[str, Any]) -> bytes
"""Backs up the specified key.
Expand Down Expand Up @@ -545,6 +559,7 @@ def backup_key(self, name, **kwargs):
backup_result = self._client.backup_key(self.vault_url, name, error_map={404: ResourceNotFoundError}, **kwargs)
return backup_result.value

@distributed_trace
def restore_key(self, backup, **kwargs):
# type: (bytes, Mapping[str, Any]) -> Key
"""Restores a backed up key to the Key Vault
Expand Down Expand Up @@ -581,6 +596,7 @@ def restore_key(self, backup, **kwargs):
bundle = self._client.restore_key(self.vault_url, backup, error_map={409: ResourceExistsError}, **kwargs)
return Key._from_key_bundle(bundle)

@distributed_trace
def import_key(self, name, key, hsm=None, enabled=None, not_before=None, expires=None, tags=None, **kwargs):
# type: (str, List[str], Optional[bool], Optional[bool], Optional[datetime], Optional[datetime], Optional[Dict[str, str]], Mapping[str, Any]) -> Key
"""Imports an externally created key, stores it, and returns the key to the client.
Expand Down Expand Up @@ -618,6 +634,7 @@ def import_key(self, name, key, hsm=None, enabled=None, not_before=None, expires
)
return Key._from_key_bundle(bundle)

@distributed_trace
def wrap_key(self, name, algorithm, value, version=None, **kwargs):
# type: (str, str, Optional[str], bytes, Mapping[str, Any]) -> KeyOperationResult
"""Wraps a symmetric key using a specified key.
Expand Down Expand Up @@ -653,6 +670,7 @@ def wrap_key(self, name, algorithm, value, version=None, **kwargs):
)
return KeyOperationResult(id=bundle.kid, value=bundle.result)

@distributed_trace
def unwrap_key(self, name, algorithm, value, version=None, **kwargs):
# type: (str, str, Optional[str], bytes, Mapping[str, Any]) -> KeyOperationResult
"""Unwraps a symmetric key using the specified key that was initially used
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from azure.core.async_paging import AsyncPagedMixin
from azure.core.configuration import Configuration
from azure.core.pipeline import AsyncPipeline
from azure.core.pipeline.policies.distributed_tracing import DistributedTracingPolicy
from azure.core.pipeline.transport import AsyncioRequestsTransport, HttpTransport
from msrest.serialization import Model

Expand Down Expand Up @@ -101,6 +102,7 @@ def _build_pipeline(config: Configuration, transport: HttpTransport, **kwargs: A
config.retry_policy,
config.authentication_policy,
config.logging_policy,
DistributedTracingPolicy()
]

if transport is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from azure.core import Configuration
from azure.core.pipeline import Pipeline
from azure.core.pipeline.transport import RequestsTransport
from azure.core.pipeline.policies.distributed_tracing import DistributedTracingPolicy
from ._generated import KeyVaultClient

if TYPE_CHECKING:
Expand Down Expand Up @@ -73,6 +74,7 @@ def _build_pipeline(self, config, transport, **kwargs):
config.retry_policy,
config.authentication_policy,
config.logging_policy,
DistributedTracingPolicy(),
]

if transport is None:
Expand Down
Loading

0 comments on commit 3ab986a

Please sign in to comment.