Skip to content

Commit

Permalink
[KeyVault] Keyvault Keys to Test Proxy (#24165)
Browse files Browse the repository at this point in the history
* move conftest into the tests folder

* test proxy changes

* new recordings

* more recordings for crud

* sync test recordings

* move over to test proxy

* kv async recordings

* simple clean ups

* recordings

* clean up imports

* pick right vault name

* clean up

* fix test parse id offline test

* override pytest default event loop

* fix for async tests, change to aiohttp request

* remove commented code

* formatting fixes

* Delete vcrpy recordings

* with block for async client

* clean up

* code clean ups

* move keys specific methods in to a separate class

* PR comments

* refactor test to use preparer
  • Loading branch information
kashifkhan authored May 4, 2022
1 parent b8bcbd5 commit 9c28b76
Show file tree
Hide file tree
Showing 944 changed files with 417,642 additions and 320,902 deletions.
8 changes: 0 additions & 8 deletions sdk/keyvault/azure-keyvault-keys/conftest.py

This file was deleted.

117 changes: 117 additions & 0 deletions sdk/keyvault/azure-keyvault-keys/tests/_async_test_case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import json
import os

import pytest
from azure.core.pipeline import AsyncPipeline
from azure.core.pipeline.transport import AioHttpTransport, HttpRequest
from azure.keyvault.keys import KeyReleasePolicy
from azure.keyvault.keys._shared.client_base import DEFAULT_VERSION, ApiVersion
from devtools_testutils import AzureRecordedTestCase


async def get_attestation_token(attestation_uri):
request = HttpRequest("GET", "{}/generate-test-token".format(attestation_uri))
async with AsyncPipeline(transport=AioHttpTransport()) as pipeline:
response = await pipeline.run(request)
return json.loads(response.http_response.text())["token"]


def get_decorator(only_hsm=False, only_vault=False, api_versions=None, **kwargs):
"""returns a test decorator for test parameterization"""
params = [
pytest.param(p[0],p[1], id=p[0] + ("_mhsm" if p[1] else "_vault" ))
for p in get_test_parameters(only_hsm, only_vault, api_versions=api_versions)
]
return params


def get_release_policy(attestation_uri, **kwargs):
release_policy_json = {
"anyOf": [
{
"anyOf": [
{
"claim": "sdk-test",
"equals": True
}
],
"authority": attestation_uri.rstrip("/") + "/"
}
],
"version": "1.0.0"
}
policy_string = json.dumps(release_policy_json).encode()
return KeyReleasePolicy(policy_string, **kwargs)


def get_test_parameters(only_hsm=False, only_vault=False, api_versions=None):
"""generates a list of parameter pairs for test case parameterization, where [x, y] = [api_version, is_hsm]"""
combinations = []
versions = api_versions or ApiVersion
hsm_supported_versions = {ApiVersion.V7_2, ApiVersion.V7_3}

for api_version in versions:
if not only_vault and api_version in hsm_supported_versions:
combinations.append([api_version, True])
if not only_hsm:
combinations.append([api_version, False])
return combinations


def is_public_cloud():
return (".microsoftonline.com" in os.getenv('AZURE_AUTHORITY_HOST', ''))


class AsyncKeysClientPreparer(AzureRecordedTestCase):
def __init__(self, *args, **kwargs):
vault_playback_url = "https://vaultname.vault.azure.net"
hsm_playback_url = "https://managedhsmvaultname.vault.azure.net"
self.is_logging_enabled = kwargs.pop("logging_enable", True)

if self.is_live:
self.vault_url = os.environ["AZURE_KEYVAULT_URL"]
self.managed_hsm_url = os.environ.get("AZURE_MANAGEDHSM_URL")
else:
self.vault_url = vault_playback_url
self.managed_hsm_url = hsm_playback_url

self._set_mgmt_settings_real_values()

def __call__(self, fn):
async def _preparer(test_class, api_version, is_hsm, **kwargs):

self._skip_if_not_configured(api_version, is_hsm)
if not self.is_logging_enabled:
kwargs.update({"logging_enable": False})
endpoint_url = self.managed_hsm_url if is_hsm else self.vault_url
client = self.create_key_client(endpoint_url, api_version=api_version, **kwargs)
async with client:
await fn(test_class, client, is_hsm=is_hsm, managed_hsm_url = self.managed_hsm_url, vault_url = self.vault_url)

return _preparer



def create_key_client(self, vault_uri, **kwargs):

from azure.keyvault.keys.aio import KeyClient

credential = self.get_credential(KeyClient, is_async=True)

return self.create_client_from_credential(KeyClient, credential=credential, vault_url=vault_uri, **kwargs)

def _set_mgmt_settings_real_values(self):
if self.is_live:
os.environ["AZURE_TENANT_ID"] = os.environ["KEYVAULT_TENANT_ID"]
os.environ["AZURE_CLIENT_ID"] = os.environ["KEYVAULT_CLIENT_ID"]
os.environ["AZURE_CLIENT_SECRET"] = os.environ["KEYVAULT_CLIENT_SECRET"]

def _skip_if_not_configured(self, api_version, is_hsm):
if self.is_live and api_version != DEFAULT_VERSION:
pytest.skip("This test only uses the default API version for live tests")
if self.is_live and is_hsm and self.managed_hsm_url is None:
pytest.skip("No HSM endpoint for live testing")
26 changes: 26 additions & 0 deletions sdk/keyvault/azure-keyvault-keys/tests/_keys_test_case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os

import pytest
from devtools_testutils import AzureRecordedTestCase


class KeysTestCase(AzureRecordedTestCase):
def _get_attestation_uri(self):
playback_uri = "https://fakeattestation.azurewebsites.net"
if self.is_live:
real_uri = os.environ.get("AZURE_KEYVAULT_ATTESTATION_URL")
real_uri = real_uri.rstrip('/')
if real_uri is None:
pytest.skip("No AZURE_KEYVAULT_ATTESTATION_URL environment variable")
return real_uri
return playback_uri

def create_crypto_client(self, key, **kwargs):
if kwargs.pop("is_async", False):
from azure.keyvault.keys.crypto.aio import CryptographyClient
credential = self.get_credential(CryptographyClient,is_async=True)
else:
from azure.keyvault.keys.crypto import CryptographyClient
credential = self.get_credential(CryptographyClient)

return self.create_client_from_credential(CryptographyClient, credential=credential, key=key, **kwargs)
19 changes: 7 additions & 12 deletions sdk/keyvault/azure-keyvault-keys/tests/_shared/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,13 @@
# ------------------------------------
import time

from azure_devtools.scenario_tests.patches import patch_time_sleep_api
from devtools_testutils import AzureTestCase
from azure.keyvault.keys._shared import HttpChallengeCache
from devtools_testutils import AzureRecordedTestCase


class KeyVaultTestCase(AzureTestCase):
def __init__(self, *args, **kwargs):
if "match_body" not in kwargs:
kwargs["match_body"] = True

super(KeyVaultTestCase, self).__init__(*args, **kwargs)
self.replay_patches.append(patch_time_sleep_api)

def setUp(self):
self.list_test_size = 7
super(KeyVaultTestCase, self).setUp()

class KeyVaultTestCase(AzureRecordedTestCase):
def get_resource_name(self, name):
"""helper to create resources with a consistent, test-indicative prefix"""
return super(KeyVaultTestCase, self).get_resource_name("livekvtest{}".format(name))
Expand Down Expand Up @@ -48,3 +39,7 @@ def _poll_until_exception(self, fn, expected_exception, max_retries=20, retry_de
return

self.fail("expected exception {expected_exception} was not raised")

def teardown_method(self, method):
HttpChallengeCache.clear()
assert len(HttpChallengeCache._cache) == 0
25 changes: 7 additions & 18 deletions sdk/keyvault/azure-keyvault-keys/tests/_shared/test_case_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,11 @@
# ------------------------------------
import asyncio

from azure_devtools.scenario_tests.patches import mock_in_unit_test
from devtools_testutils import AzureTestCase
from devtools_testutils import AzureRecordedTestCase
from azure.keyvault.keys._shared import HttpChallengeCache


def skip_sleep(unit_test):
async def immediate_return(_):
return

return mock_in_unit_test(unit_test, "asyncio.sleep", immediate_return)


class KeyVaultTestCase(AzureTestCase):
def __init__(self, *args, match_body=True, **kwargs):
super().__init__(*args, match_body=match_body, **kwargs)
self.replay_patches.append(skip_sleep)

def setUp(self):
self.list_test_size = 7
super(KeyVaultTestCase, self).setUp()

class KeyVaultTestCase(AzureRecordedTestCase):
def get_resource_name(self, name):
"""helper to create resources with a consistent, test-indicative prefix"""
return super(KeyVaultTestCase, self).get_resource_name("livekvtest{}".format(name))
Expand Down Expand Up @@ -51,3 +36,7 @@ async def _poll_until_exception(self, fn, expected_exception, max_retries=20, re
except expected_exception:
return
self.fail("expected exception {expected_exception} was not raised")

def teardown_method(self, method):
HttpChallengeCache.clear()
assert len(HttpChallengeCache._cache) == 0
110 changes: 29 additions & 81 deletions sdk/keyvault/azure-keyvault-keys/tests/_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import functools
import json
import os

import pytest
from azure.core.pipeline import Pipeline
from azure.core.pipeline.transport import HttpRequest, RequestsTransport
from azure.keyvault.keys import KeyReleasePolicy
from azure.keyvault.keys._shared import HttpChallengeCache
from azure.keyvault.keys._shared.client_base import ApiVersion, DEFAULT_VERSION
from devtools_testutils import AzureTestCase
from parameterized import parameterized, param
import pytest
from six.moves.urllib_parse import urlparse


def client_setup(testcase_func):
"""decorator that creates a client to be passed in to a test method"""

@functools.wraps(testcase_func)
def wrapper(test_class_instance, api_version, is_hsm=False, **kwargs):
test_class_instance._skip_if_not_configured(api_version, is_hsm)
endpoint_url = test_class_instance.managed_hsm_url if is_hsm else test_class_instance.vault_url
client = test_class_instance.create_key_client(endpoint_url, api_version=api_version, **kwargs)

if kwargs.get("is_async"):
import asyncio

coroutine = testcase_func(test_class_instance, client, is_hsm=is_hsm)
loop = asyncio.get_event_loop()
loop.run_until_complete(coroutine)
else:
testcase_func(test_class_instance, client, is_hsm=is_hsm)

return wrapper
from azure.keyvault.keys._shared.client_base import DEFAULT_VERSION, ApiVersion
from devtools_testutils import AzureRecordedTestCase


def get_attestation_token(attestation_uri):
Expand All @@ -48,10 +23,10 @@ def get_attestation_token(attestation_uri):
def get_decorator(only_hsm=False, only_vault=False, api_versions=None, **kwargs):
"""returns a test decorator for test parameterization"""
params = [
param(api_version=p[0], is_hsm=p[1], **kwargs)
pytest.param(p[0],p[1], id=p[0] + ("_mhsm" if p[1] else "_vault" ))
for p in get_test_parameters(only_hsm, only_vault, api_versions=api_versions)
]
return functools.partial(parameterized.expand, params, name_func=suffixed_test_name)
return params


def get_release_policy(attestation_uri, **kwargs):
Expand Down Expand Up @@ -87,78 +62,51 @@ def get_test_parameters(only_hsm=False, only_vault=False, api_versions=None):
return combinations


def suffixed_test_name(testcase_func, param_num, param):
api_version = param.kwargs.get("api_version")
suffix = "mhsm" if param.kwargs.get("is_hsm") else "vault"
return "{}_{}_{}".format(
testcase_func.__name__, parameterized.to_safe_name(api_version), parameterized.to_safe_name(suffix)
)


def is_public_cloud():
return (".microsoftonline.com" in os.getenv('AZURE_AUTHORITY_HOST', ''))


class KeysTestCase(AzureTestCase):
def setUp(self, *args, **kwargs):
class KeysClientPreparer(AzureRecordedTestCase):
def __init__(self, *args, **kwargs):
vault_playback_url = "https://vaultname.vault.azure.net"
hsm_playback_url = "https://managedhsmname.managedhsm.azure.net"
hsm_playback_url = "https://managedhsmvaultname.vault.azure.net"
self.is_logging_enabled = kwargs.pop("logging_enable", True)

if self.is_live:
self.vault_url = os.environ["AZURE_KEYVAULT_URL"]
self._scrub_url(real_url=self.vault_url, playback_url=vault_playback_url)

self.managed_hsm_url = os.environ.get("AZURE_MANAGEDHSM_URL")
self.vault_url = self.vault_url.rstrip("/")
self.managed_hsm_url = os.environ.get("AZURE_MANAGEDHSM_URL", None)
if self.managed_hsm_url:
self._scrub_url(real_url=self.managed_hsm_url, playback_url=hsm_playback_url)
self.managed_hsm_url = self.managed_hsm_url.rstrip("/")
else:
self.vault_url = vault_playback_url
self.managed_hsm_url = hsm_playback_url

self._set_mgmt_settings_real_values()
super(KeysTestCase, self).setUp(*args, **kwargs)

def tearDown(self):
HttpChallengeCache.clear()
assert len(HttpChallengeCache._cache) == 0
super(KeysTestCase, self).tearDown()
def __call__(self, fn):
def _preparer(test_class, api_version, is_hsm, **kwargs):

def create_key_client(self, vault_uri, **kwargs):
if kwargs.pop("is_async", False):
from azure.keyvault.keys.aio import KeyClient

credential = self.get_credential(KeyClient, is_async=True)
else:
from azure.keyvault.keys import KeyClient
#self._skip_if_not_configured(api_version, is_hsm)
if not self.is_logging_enabled:
kwargs.update({"logging_enable": False})
endpoint_url = self.managed_hsm_url if is_hsm else self.vault_url
client = self.create_key_client(endpoint_url, api_version=api_version, **kwargs)

credential = self.get_credential(KeyClient)
return self.create_client_from_credential(KeyClient, credential=credential, vault_url=vault_uri, **kwargs)
with client:
fn(test_class, client, is_hsm=is_hsm, managed_hsm_url = self.managed_hsm_url, vault_url = self.vault_url)
return _preparer


def create_crypto_client(self, key, **kwargs):
if kwargs.pop("is_async", False):
from azure.keyvault.keys.crypto.aio import CryptographyClient

credential = self.get_credential(CryptographyClient, is_async=True)
else:
from azure.keyvault.keys.crypto import CryptographyClient
def create_key_client(self, vault_uri, **kwargs):

from azure.keyvault.keys import KeyClient

credential = self.get_credential(CryptographyClient)
return self.create_client_from_credential(CryptographyClient, credential=credential, key=key, **kwargs)
credential = self.get_credential(KeyClient)

return self.create_client_from_credential(KeyClient, credential=credential, vault_url=vault_uri, **kwargs)

def _get_attestation_uri(self):
playback_uri = "https://fakeattestation.azurewebsites.net"
if self.is_live:
real_uri = os.environ.get("AZURE_KEYVAULT_ATTESTATION_URL")
if real_uri is None:
pytest.skip("No AZURE_KEYVAULT_ATTESTATION_URL environment variable")
self._scrub_url(real_uri, playback_uri)
return real_uri
return playback_uri

def _scrub_url(self, real_url, playback_url):
real = urlparse(real_url)
playback = urlparse(playback_url)
self.scrubber.register_name_pair(real.netloc, playback.netloc)

def _set_mgmt_settings_real_values(self):
if self.is_live:
Expand Down
Loading

0 comments on commit 9c28b76

Please sign in to comment.