Skip to content

Commit

Permalink
Update various credential mocks (#36620)
Browse files Browse the repository at this point in the history
This is a minor change to various credential mocks to keep things
passing if/when we do additional attribute checking in the core
credential policies.

Signed-off-by: Paul Van Eck <[email protected]>
  • Loading branch information
pvaneck authored Jul 26, 2024
1 parent a27ccdb commit 85eb9ee
Show file tree
Hide file tree
Showing 37 changed files with 170 additions and 140 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def create_mgmt_aio_client(self, client, **kwargs):

credential = DefaultAzureCredential()
else:
credential = Mock(get_token=asyncio.coroutine(lambda _: AccessToken("fake-token", 0)))
credential = Mock(
spec_set=["get_token"], get_token=asyncio.coroutine(lambda _: AccessToken("fake-token", 0))
)
return client(credential=credential, subscription_id=self.settings.SUBSCRIPTION_ID)

def to_list(self, ait):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _convert_datetime_to_utc_int(input):
async def mock_get_token(*_, **__):
return AccessToken("some_token", _convert_datetime_to_utc_int(datetime.now().replace(tzinfo=timezone.utc)))

credential = Mock(get_token=mock_get_token)
credential = Mock(spec_set=["get_token"], get_token=mock_get_token)


@pytest.mark.asyncio
Expand Down Expand Up @@ -155,4 +155,4 @@ def test_get_thread_client():
chat_client = ChatClient("https://endpoint", credential)
chat_thread_client = chat_client.get_chat_thread_client(thread_id)

assert chat_thread_client.thread_id == thread_id
assert chat_thread_client.thread_id == thread_id
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _convert_datetime_to_utc_int(input):
async def mock_get_token(*_, **__):
return AccessToken("some_token", _convert_datetime_to_utc_int(datetime.now().replace(tzinfo=timezone.utc)))

credential = Mock(get_token=mock_get_token)
credential = Mock(spec_set=["get_token"], get_token=mock_get_token)


@pytest.mark.asyncio
Expand Down
2 changes: 1 addition & 1 deletion sdk/compute/azure-mgmt-compute/tests/_aio_testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def create_mgmt_aio_client(self, client, **kwargs):
from azure.identity.aio import DefaultAzureCredential
credential = DefaultAzureCredential()
else:
credential = Mock(get_token=lambda _: self.mock_completed_future(AccessToken("fake-token", 0)))
credential = Mock(spec_set=["get_token"], get_token=lambda _: self.mock_completed_future(AccessToken("fake-token", 0)))
return client(
credential=credential,
subscription_id=self.settings.SUBSCRIPTION_ID
Expand Down
27 changes: 15 additions & 12 deletions sdk/core/azure-core/tests/async_tests/test_authentication_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def get_token(*_, **__):
get_token_calls += 1
return expected_token

fake_credential = Mock(get_token=get_token)
fake_credential = Mock(spec_set=["get_token"], get_token=get_token)
policies = [AsyncBearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)]
pipeline = AsyncPipeline(transport=Mock(), policies=policies)

Expand All @@ -68,7 +68,7 @@ async def verify_request(request):
return expected_response

get_token = get_completed_future(AccessToken("***", 42))
fake_credential = Mock(get_token=lambda *_, **__: get_token)
fake_credential = Mock(spec_set=["get_token"], get_token=lambda *_, **__: get_token)
policies = [AsyncBearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_request)]
response = await AsyncPipeline(transport=Mock(), policies=policies).run(expected_request)

Expand All @@ -87,7 +87,7 @@ async def verify_request(request):
return expected_response

get_token = get_completed_future(AccessToken("***", 42))
fake_credential = Mock(get_token=lambda *_, **__: get_token)
fake_credential = Mock(spec_set=["get_token"], get_token=lambda *_, **__: get_token)
policies = [AsyncBearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_request)]
response = await AsyncPipeline(transport=Mock(), policies=policies).run(expected_request)

Expand All @@ -106,7 +106,7 @@ async def get_token(*_, **__):
get_token_calls += 1
return expected_token

credential = Mock(get_token=get_token)
credential = Mock(spec_set=["get_token"], get_token=get_token)
policies = [
AsyncBearerTokenCredentialPolicy(credential, "scope"),
Mock(send=Mock(return_value=get_completed_future(Mock()))),
Expand Down Expand Up @@ -144,7 +144,7 @@ async def assert_option_popped(request, **kwargs):
assert "enforce_https" not in kwargs, "AsyncBearerTokenCredentialPolicy didn't pop the 'enforce_https' option"
return Mock()

credential = Mock(get_token=lambda *_, **__: get_completed_future(AccessToken("***", 42)))
credential = Mock(spec_set=["get_token"], get_token=lambda *_, **__: get_completed_future(AccessToken("***", 42)))
pipeline = AsyncPipeline(
transport=Mock(send=assert_option_popped), policies=[AsyncBearerTokenCredentialPolicy(credential, "scope")]
)
Expand Down Expand Up @@ -175,7 +175,7 @@ def on_request(self, request):
return Mock()

get_token = get_completed_future(AccessToken("***", 42))
credential = Mock(get_token=lambda *_, **__: get_token)
credential = Mock(spec_set=["get_token"], get_token=lambda *_, **__: get_token)
policies = [AsyncBearerTokenCredentialPolicy(credential, "scope"), ContextValidator()]
pipeline = AsyncPipeline(transport=Mock(send=lambda *_, **__: get_completed_future(Mock())), policies=policies)

Expand All @@ -193,7 +193,7 @@ def on_request(self, request):
return Mock()

get_token = get_completed_future(AccessToken("***", 42))
credential = Mock(get_token=lambda *_, **__: get_token)
credential = Mock(spec_set=["get_token"], get_token=lambda *_, **__: get_token)
policies = [AsyncBearerTokenCredentialPolicy(credential, "scope"), ContextValidator()]
pipeline = AsyncPipeline(transport=Mock(send=lambda *_, **__: get_completed_future(Mock())), policies=policies)

Expand All @@ -217,7 +217,10 @@ async def send(self, request):
self.response = await super().send(request)
return self.response

credential = Mock(get_token=Mock(return_value=get_completed_future(AccessToken("***", int(time.time()) + 3600))))
credential = Mock(
spec_set=["get_token"],
get_token=Mock(return_value=get_completed_future(AccessToken("***", int(time.time()) + 3600))),
)
policy = TestPolicy(credential, "scope")
transport = Mock(send=Mock(return_value=get_completed_future(Mock(status_code=200))))

Expand Down Expand Up @@ -300,7 +303,7 @@ async def get_token(*_, **__):
token = AccessToken(auth_headder, 0)
return token

credential = Mock(get_token=get_token)
credential = Mock(spec_set=["get_token"], get_token=get_token)
auth_policy = AsyncBearerTokenCredentialPolicy(credential, expected_scope)
redirect_policy = AsyncRedirectPolicy()
header_clean_up_policy = SensitiveHeaderCleanupPolicy()
Expand Down Expand Up @@ -344,7 +347,7 @@ async def get_token(*_, **__):
token = AccessToken(auth_headder, 0)
return token

credential = Mock(get_token=get_token)
credential = Mock(spec_set=["get_token"], get_token=get_token)
auth_policy = AsyncBearerTokenCredentialPolicy(credential, expected_scope)
redirect_policy = AsyncRedirectPolicy()
header_clean_up_policy = SensitiveHeaderCleanupPolicy()
Expand Down Expand Up @@ -388,7 +391,7 @@ async def get_token(*_, **__):
token = AccessToken(auth_headder, 0)
return token

credential = Mock(get_token=get_token)
credential = Mock(spec_set=["get_token"], get_token=get_token)
auth_policy = AsyncBearerTokenCredentialPolicy(credential, expected_scope)
redirect_policy = AsyncRedirectPolicy()
header_clean_up_policy = SensitiveHeaderCleanupPolicy(disable_redirect_cleanup=True)
Expand Down Expand Up @@ -432,7 +435,7 @@ async def get_token(*_, **__):
token = AccessToken(auth_headder, 0)
return token

credential = Mock(get_token=get_token)
credential = Mock(spec_set=["get_token"], get_token=get_token)
auth_policy = AsyncBearerTokenCredentialPolicy(credential, expected_scope)
redirect_policy = AsyncRedirectPolicy()
header_clean_up_policy = SensitiveHeaderCleanupPolicy(blocked_redirect_headers=["x-ms-authorization-auxiliary"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ async def run(p):
response = await p.run(request, enforce_https=False)
assert isinstance(response.http_response.status_code, int)

fake_credential = Mock(get_token=get_token)
fake_credential = Mock(spec_set=["get_token"], get_token=get_token)
policies = [AsyncBearerTokenCredentialPolicy(fake_credential, "scope")]
async with AsyncPipeline(TrioRequestsTransport(), policies=policies) as pipeline, trio.open_nursery() as nursery:
nursery.start_soon(run, pipeline)
Expand Down
30 changes: 15 additions & 15 deletions sdk/core/azure-core/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def verify_authorization_header(request):
assert request.http_request.headers["Authorization"] == "Bearer {}".format(expected_token.token)
return Mock()

fake_credential = Mock(get_token=Mock(return_value=expected_token))
fake_credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=expected_token))
policies = [BearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)]

pipeline = Pipeline(transport=Mock(), policies=policies)
Expand All @@ -63,7 +63,7 @@ def verify_request(request):
def get_token(*_, **__):
return AccessToken("***", 42)

fake_credential = Mock(get_token=get_token)
fake_credential = Mock(spec_set=["get_token"], get_token=get_token)
policies = [BearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_request)]
response = Pipeline(transport=Mock(), policies=policies).run(expected_request)

Expand All @@ -73,7 +73,7 @@ def get_token(*_, **__):
@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
def test_bearer_policy_token_caching(http_request):
good_for_one_hour = AccessToken("token", time.time() + 3600)
credential = Mock(get_token=Mock(return_value=good_for_one_hour))
credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=good_for_one_hour))
pipeline = Pipeline(transport=Mock(), policies=[BearerTokenCredentialPolicy(credential, "scope")])

pipeline.run(http_request("GET", "https://spam.eggs"))
Expand Down Expand Up @@ -105,7 +105,7 @@ def assert_option_popped(request, **kwargs):
def get_token(*_, **__):
return AccessToken("***", 42)

credential = Mock(get_token=get_token)
credential = Mock(spec_set=["get_token"], get_token=get_token)
pipeline = Pipeline(
transport=Mock(send=assert_option_popped), policies=[BearerTokenCredentialPolicy(credential, "scope")]
)
Expand Down Expand Up @@ -134,7 +134,7 @@ def on_request(self, request):
assert "enforce_https" in request.context, "'enforce_https' is not in the request's context"
return Mock()

credential = Mock(get_token=Mock(return_value=AccessToken("***", 42)))
credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=AccessToken("***", 42)))
policies = [BearerTokenCredentialPolicy(credential, "scope"), ContextValidator()]
pipeline = Pipeline(transport=Mock(), policies=policies)

Expand All @@ -146,7 +146,7 @@ def test_bearer_policy_default_context(http_request):
"""The policy should call get_token with the scopes given at construction, and no keyword arguments, by default"""
expected_scope = "scope"
token = AccessToken("", 0)
credential = Mock(get_token=Mock(return_value=token))
credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=token))
policy = BearerTokenCredentialPolicy(credential, expected_scope)
pipeline = Pipeline(transport=Mock(), policies=[policy])

Expand All @@ -160,7 +160,7 @@ def test_bearer_policy_enable_cae(http_request):
"""The policy should set enable_cae to True in the get_token request if it is set in constructor."""
expected_scope = "scope"
token = AccessToken("", 0)
credential = Mock(get_token=Mock(return_value=token))
credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=token))
policy = BearerTokenCredentialPolicy(credential, expected_scope, enable_cae=True)
pipeline = Pipeline(transport=Mock(), policies=[policy])

Expand All @@ -177,7 +177,7 @@ class ContextValidator(SansIOHTTPPolicy):
def on_request(self, request):
assert not any(request.context), "the policy shouldn't add to the request's context"

credential = Mock(get_token=Mock(return_value=AccessToken("***", 42)))
credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=AccessToken("***", 42)))
policies = [BearerTokenCredentialPolicy(credential, "scope"), ContextValidator()]
pipeline = Pipeline(transport=Mock(), policies=policies)

Expand All @@ -195,7 +195,7 @@ def on_challenge(self, request, challenge):
self.__class__.called = True
return False

credential = Mock(get_token=Mock(return_value=AccessToken("***", int(time.time()) + 3600)))
credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=AccessToken("***", int(time.time()) + 3600)))
policies = [TestPolicy(credential, "scope")]
response = Mock(status_code=401, headers={"WWW-Authenticate": 'Basic realm="localhost"'})
transport = Mock(send=Mock(return_value=response))
Expand All @@ -212,7 +212,7 @@ def test_bearer_policy_cannot_complete_challenge(http_request):

expected_scope = "scope"
expected_token = AccessToken("***", int(time.time()) + 3600)
credential = Mock(get_token=Mock(return_value=expected_token))
credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=expected_token))
expected_response = Mock(status_code=401, headers={"WWW-Authenticate": 'Basic realm="localhost"'})
transport = Mock(send=Mock(return_value=expected_response))
policies = [BearerTokenCredentialPolicy(credential, expected_scope)]
Expand Down Expand Up @@ -241,7 +241,7 @@ def send(self, request):
self.response = super(TestPolicy, self).send(request)
return self.response

credential = Mock(get_token=Mock(return_value=AccessToken("***", int(time.time()) + 3600)))
credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=AccessToken("***", int(time.time()) + 3600)))
policy = TestPolicy(credential, "scope")
transport = Mock(send=Mock(return_value=Mock(status_code=200)))

Expand Down Expand Up @@ -495,7 +495,7 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe
auth_headder = "token"
expected_scope = "scope"
token = AccessToken(auth_headder, 0)
credential = Mock(get_token=Mock(return_value=token))
credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=token))
auth_policy = BearerTokenCredentialPolicy(credential, expected_scope)
redirect_policy = RedirectPolicy()
header_clean_up_policy = SensitiveHeaderCleanupPolicy()
Expand Down Expand Up @@ -534,7 +534,7 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe
auth_headder = "token"
expected_scope = "scope"
token = AccessToken(auth_headder, 0)
credential = Mock(get_token=Mock(return_value=token))
credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=token))
auth_policy = BearerTokenCredentialPolicy(credential, expected_scope)
redirect_policy = RedirectPolicy()
header_clean_up_policy = SensitiveHeaderCleanupPolicy()
Expand Down Expand Up @@ -573,7 +573,7 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe
auth_headder = "token"
expected_scope = "scope"
token = AccessToken(auth_headder, 0)
credential = Mock(get_token=Mock(return_value=token))
credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=token))
auth_policy = BearerTokenCredentialPolicy(credential, expected_scope)
redirect_policy = RedirectPolicy()
header_clean_up_policy = SensitiveHeaderCleanupPolicy(disable_redirect_cleanup=True)
Expand Down Expand Up @@ -612,7 +612,7 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe
auth_headder = "token"
expected_scope = "scope"
token = AccessToken(auth_headder, 0)
credential = Mock(get_token=Mock(return_value=token))
credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=token))
auth_policy = BearerTokenCredentialPolicy(credential, expected_scope)
redirect_policy = RedirectPolicy()
header_clean_up_policy = SensitiveHeaderCleanupPolicy(blocked_redirect_headers=["x-ms-authorization-auxiliary"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def get_token(*scopes, **kwargs):
assert scopes == (expected_scope,)
return next(tokens)

credential = Mock(get_token=Mock(wraps=get_token))
credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token))
transport = Mock(send=Mock(wraps=send))
policies = [AsyncARMChallengeAuthenticationPolicy(credential, expected_scope)]
pipeline = AsyncPipeline(transport=transport, policies=policies)
Expand Down Expand Up @@ -113,7 +113,7 @@ async def get_token(*_, **__):
return AccessToken("***", 42)

transport = Mock(send=Mock(wraps=send))
credential = Mock(get_token=Mock(wraps=get_token))
credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token))
policies = [AsyncARMChallengeAuthenticationPolicy(credential, "scope")]
pipeline = AsyncPipeline(transport=transport, policies=policies)

Expand Down Expand Up @@ -151,8 +151,8 @@ async def get_token2(_):
get_token_calls2 += 1
return second_token

fake_credential1 = Mock(get_token=get_token1)
fake_credential2 = Mock(get_token=get_token2)
fake_credential1 = Mock(spec_set=["get_token"], get_token=get_token1)
fake_credential2 = Mock(spec_set=["get_token"], get_token=get_token2)
policies = [
AsyncAuxiliaryAuthenticationPolicy([fake_credential1, fake_credential2], "scope"),
Mock(send=verify_authorization_header),
Expand Down
6 changes: 3 additions & 3 deletions sdk/core/azure-mgmt-core/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def verify_authorization_header(request):
)
return Mock()

fake_credential1 = Mock(get_token=Mock(return_value=first_token))
fake_credential2 = Mock(get_token=Mock(return_value=second_token))
fake_credential1 = Mock(spec_set=["get_token"], get_token=Mock(return_value=first_token))
fake_credential2 = Mock(spec_set=["get_token"], get_token=Mock(return_value=second_token))
policies = [
AuxiliaryAuthenticationPolicy([fake_credential1, fake_credential2], "scope"),
Mock(send=verify_authorization_header),
Expand Down Expand Up @@ -136,7 +136,7 @@ def get_token(*scopes, **kwargs):
assert scopes == (expected_scope,)
return next(tokens)

credential = Mock(get_token=Mock(wraps=get_token))
credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token))
transport = Mock(send=Mock(wraps=send))
policies = [ARMChallengeAuthenticationPolicy(credential, expected_scope)]
pipeline = Pipeline(transport=transport, policies=policies)
Expand Down
Loading

0 comments on commit 85eb9ee

Please sign in to comment.