Skip to content

Commit

Permalink
Added custom_endpoint as an optional parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
Craig Treasure committed Feb 8, 2021
1 parent 2d13e26 commit a2e77f4
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@ class MixedRealityStsClient(object):
The Mixed Reality service account domain.
:param Union[TokenCredential, AzureKeyCredential] credential:
The credential used to access the Mixed Reality service.
:param str custom_endpoint:
Override the Mixed Reality STS service endpoint.
"""

def __init__(self, account_id, account_domain, credential, **kwargs):
# type: (str, str, Union[TokenCredential, AzureKeyCredential], Any) -> None
def __init__(self, account_id, account_domain, credential, custom_endpoint=None, **kwargs):
# type: (str, str, Union[TokenCredential, AzureKeyCredential], str, Any) -> None
if not account_id:
raise ValueError("account_id can not be None")

Expand All @@ -57,13 +59,10 @@ def __init__(self, account_id, account_domain, credential, **kwargs):

self._credential = credential

endpoint_url = kwargs.pop('endpoint_url', construct_endpoint_url(account_domain))

try:
if not endpoint_url.lower().startswith('http'):
endpoint_url = "https://" + endpoint_url
except AttributeError:
raise ValueError("Host URL must be a string")
if custom_endpoint:
endpoint_url = custom_endpoint
else:
endpoint_url = construct_endpoint_url(account_domain)

parsed_url = urlparse(endpoint_url.rstrip('/'))
if not parsed_url.netloc:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from typing import TYPE_CHECKING
from typing import Optional, TYPE_CHECKING

try:
from urllib.parse import urlparse
Expand Down Expand Up @@ -38,12 +38,16 @@ class MixedRealityStsClient(object):
The Mixed Reality service account domain.
:param Union[TokenCredential, AzureKeyCredential] credential:
The credential used to access the Mixed Reality service.
:param str custom_endpoint:
Override the Mixed Reality STS service endpoint.
"""

def __init__(self,
account_id: str,
account_domain: str,
credential: Union[AzureKeyCredential, "AsyncTokenCredential"],
*,
custom_endpoint: Optional[str] = None,
**kwargs) -> None:
if not account_id:
raise ValueError("account_id can not be None")
Expand All @@ -62,13 +66,10 @@ def __init__(self,

self._credential = credential

endpoint_url = kwargs.pop('endpoint_url', construct_endpoint_url(account_domain))

try:
if not endpoint_url.lower().startswith('http'):
endpoint_url = "https://" + endpoint_url
except AttributeError:
raise ValueError("Host URL must be a string")
if custom_endpoint:
endpoint_url = custom_endpoint
else:
endpoint_url = construct_endpoint_url(account_domain)

parsed_url = urlparse(endpoint_url.rstrip('/'))
if not parsed_url.netloc:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_create_client_custom_with_endpoint(self):
account_id=self.account_id,
account_domain=self.account_domain,
credential=self.key_credential,
endpoint_url=custom_endpoint_url)
custom_endpoint=custom_endpoint_url)

assert client._endpoint_url == custom_endpoint_url

Expand Down Expand Up @@ -91,7 +91,7 @@ def test_create_client_with_invalid_arguments(self):
account_id=self.account_id,
account_domain=self.account_domain,
credential=self.key_credential,
endpoint_url="#")
custom_endpoint="#")

def test_get_token(self):
client = MixedRealityStsClient(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_create_client_custom_with_endpoint(self):
account_id=self.account_id,
account_domain=self.account_domain,
credential=self.key_credential,
endpoint_url=custom_endpoint_url)
custom_endpoint=custom_endpoint_url)

assert client._endpoint_url == custom_endpoint_url

Expand Down Expand Up @@ -92,7 +92,7 @@ def test_create_client_with_invalid_arguments(self):
account_id=self.account_id,
account_domain=self.account_domain,
credential=self.key_credential,
endpoint_url="#")
custom_endpoint="#")

@AzureTestCase.await_prepared_test
async def test_get_token(self):
Expand Down

0 comments on commit a2e77f4

Please sign in to comment.