Skip to content

Commit

Permalink
[Tables] changes to not allow strings for credentials (#19095)
Browse files Browse the repository at this point in the history
* changes to not allow strings for credentials

* annas comments, plus skipping test that is not raising value error anymore

* mypy and lint fixes

* fixing samples
  • Loading branch information
seankane-msft authored Jun 4, 2021
1 parent a46f8da commit 0e79a8c
Show file tree
Hide file tree
Showing 10 changed files with 79 additions and 95 deletions.
34 changes: 11 additions & 23 deletions sdk/tables/azure-data-tables/azure/data/tables/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from urlparse import parse_qs, urlparse # type: ignore
from urllib2 import quote # type: ignore

import six
from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential
from azure.core.utils import parse_connection_string
from azure.core.pipeline.transport import (
Expand Down Expand Up @@ -72,7 +71,7 @@ class AccountHostsMixin(object): # pylint: disable=too-many-instance-attributes
def __init__(
self,
account_url, # type: Any
credential=None, # type: Optional[Any]
credential=None, # type: Optional[Union[AzureNamedKeyCredential, AzureSasCredential]]
**kwargs # type: Any
):
# type: (...) -> None
Expand All @@ -88,7 +87,7 @@ def __init__(
_, sas_token = parse_query(parsed_url.query)
if not sas_token and not credential:
raise ValueError(
"You need to provide either a SAS token or an account shared key to authenticate."
"You need to provide either an AzureSasCredential or AzureNamedKeyCredential"
)
self._query_str, credential = format_query_string(sas_token, credential)
self._location_mode = kwargs.get("location_mode", LocationMode.PRIMARY)
Expand All @@ -115,9 +114,9 @@ def __init__(
if self.scheme.lower() != "https" and hasattr(self.credential, "get_token"):
raise ValueError("Token credential is only supported with HTTPS.")
if hasattr(self.credential, "named_key"):
self.account_name = self.credential.named_key.name
self.account_name = self.credential.named_key.name # type: ignore
secondary_hostname = "{}-secondary.table.{}".format(
self.credential.named_key.name, SERVICE_HOST_BASE
self.credential.named_key.name, SERVICE_HOST_BASE # type: ignore
)

if not self._hosts:
Expand Down Expand Up @@ -345,9 +344,10 @@ def parse_connection_str(conn_str, credential, keyword_args):
try:
credential = AzureNamedKeyCredential(name=conn_settings["accountname"], key=conn_settings["accountkey"])
except KeyError:
credential = conn_settings.get("sharedaccesssignature")
# if "sharedaccesssignature" in conn_settings:
# credential = AzureSasCredential(conn_settings['sharedaccesssignature'])
credential = conn_settings.get("sharedaccesssignature", None)
if not credential:
raise ValueError("Connection string missing required connection details.")
credential = AzureSasCredential(credential)
primary = conn_settings.get("tableendpoint")
secondary = conn_settings.get("tablesecondaryendpoint")
if not primary:
Expand Down Expand Up @@ -394,10 +394,9 @@ def format_query_string(sas_token, credential):
"You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.")
if sas_token and not credential:
query_str += sas_token
elif is_credential_sastoken(credential):
query_str += credential.lstrip("?")
credential = None
return query_str.rstrip("?&"), credential
elif isinstance(credential, (AzureSasCredential, AzureNamedKeyCredential)):
return "", credential
return query_str.rstrip("?&"), None


def parse_query(query_str):
Expand All @@ -414,14 +413,3 @@ def parse_query(query_str):

snapshot = parsed_query.get("snapshot") or parsed_query.get("sharesnapshot")
return snapshot, sas_token


def is_credential_sastoken(credential):
if not credential or not isinstance(credential, six.string_types):
return False

sas_values = QueryStringConstants.to_list()
parsed_query = parse_qs(credential.lstrip("?"))
if parsed_query and all([k in sas_values for k in parsed_query.keys()]):
return True
return False
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ async def authentication_by_shared_access_signature(self):
# Instantiate a TableServiceClient using a connection string
# [START auth_by_sas]
from azure.data.tables.aio import TableServiceClient
from azure.core.credentials import AzureNamedKeyCredential
from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential

# Create a SAS token to use for authentication of a client
from azure.data.tables import generate_account_sas, ResourceTypes, AccountSasPermissions
Expand All @@ -82,7 +82,7 @@ async def authentication_by_shared_access_signature(self):
expiry=datetime.utcnow() + timedelta(hours=1),
)

async with TableServiceClient(endpoint=self.endpoint, credential=sas_token) as token_auth_table_service:
async with TableServiceClient(endpoint=self.endpoint, credential=AzureSasCredential(sas_token)) as token_auth_table_service:
properties = await token_auth_table_service.get_service_properties()
print("Shared Access Signature: {}".format(properties))
# [END auth_by_sas]
Expand Down
4 changes: 2 additions & 2 deletions sdk/tables/azure-data-tables/samples/sample_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def authentication_by_shared_access_signature(self):

# [START auth_from_sas]
from azure.data.tables import TableServiceClient
from azure.core.credentials import AzureNamedKeyCredential
from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential

# Create a SAS token to use for authentication of a client
from azure.data.tables import generate_account_sas, ResourceTypes, AccountSasPermissions
Expand All @@ -84,7 +84,7 @@ def authentication_by_shared_access_signature(self):
expiry=datetime.utcnow() + timedelta(hours=1),
)

with TableServiceClient(endpoint=self.endpoint, credential=sas_token) as token_auth_table_service:
with TableServiceClient(endpoint=self.endpoint, credential=AzureSasCredential(sas_token)) as token_auth_table_service:
properties = token_auth_table_service.get_service_properties()
print("Shared Access Signature: {}".format(properties))
# [END auth_from_sas]
Expand Down
6 changes: 3 additions & 3 deletions sdk/tables/azure-data-tables/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
generate_account_sas,
ResourceTypes
)
from azure.core.credentials import AzureNamedKeyCredential
from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential
from azure.core.exceptions import ResourceExistsError

from _shared.testcase import TableTestCase, TEST_TABLE_PREFIX
Expand Down Expand Up @@ -386,14 +386,14 @@ def test_account_sas(self, tables_storage_account_name, tables_primary_storage_a
entity['RowKey'] = u'test2'
table.upsert_entity(mode=UpdateMode.MERGE, entity=entity)

token = self.generate_sas(
token = AzureSasCredential(self.generate_sas(
generate_account_sas,
tables_primary_storage_account_key,
resource_types=ResourceTypes(object=True),
permission=AccountSasPermissions(read=True),
expiry=datetime.utcnow() + timedelta(hours=1),
start=datetime.utcnow() - timedelta(minutes=1),
)
))

account_url = self.account_url(tables_storage_account_name, "table")

Expand Down
4 changes: 3 additions & 1 deletion sdk/tables/azure-data-tables/tests/test_table_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from devtools_testutils import AzureTestCase

from azure.core.credentials import AzureNamedKeyCredential
from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential
from azure.core.exceptions import ResourceExistsError
from azure.data.tables import (
TableAccessPolicy,
Expand Down Expand Up @@ -336,6 +336,8 @@ async def test_account_sas(self, tables_storage_account_name, tables_primary_sto
start=datetime.utcnow() - timedelta(minutes=1),
)

token = AzureSasCredential(token)

account_url = self.account_url(tables_storage_account_name, "table")

service = self.create_client_from_credential(TableServiceClient, token, endpoint=account_url)
Expand Down
33 changes: 14 additions & 19 deletions sdk/tables/azure-data-tables/tests/test_table_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from azure.data.tables import TableServiceClient, TableClient
from azure.data.tables import __version__ as VERSION
from azure.core.credentials import AzureNamedKeyCredential
from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential

from _shared.testcase import (
TableTestCase
Expand Down Expand Up @@ -144,7 +144,7 @@ def test_create_service_with_sas(self):
# Arrange
url = self.account_url(self.tables_storage_account_name, "table")
suffix = '.table.core.windows.net'
token = self.generate_sas_token()
token = AzureSasCredential(self.generate_sas_token())
for service_type in SERVICES:
# Act
service = service_type(
Expand All @@ -154,8 +154,7 @@ def test_create_service_with_sas(self):
assert service is not None
assert service.account_name == self.tables_storage_account_name
assert service.url.startswith('https://' + self.tables_storage_account_name + suffix)
assert service.url.endswith(token)
assert service.credential is None
assert isinstance(service.credential, AzureSasCredential)

def test_create_service_china(self):
# Arrange
Expand Down Expand Up @@ -198,7 +197,7 @@ def test_create_service_empty_key(self):
with pytest.raises(ValueError):
test_service = service_type(endpoint=123456, credential=self.credential, table_name='foo')

assert str(e.value) == "You need to provide either a SAS token or an account shared key to authenticate."
assert str(e.value) == "You need to provide either an AzureSasCredential or AzureNamedKeyCredential"

def test_create_service_with_socket_timeout(self):
# Arrange
Expand Down Expand Up @@ -242,9 +241,8 @@ def test_create_service_with_connection_string_key(self):

def test_create_service_with_connection_string_sas(self):
# Arrange
token = self.generate_sas_token()
conn_string = 'AccountName={};SharedAccessSignature={};'.format(
self.tables_storage_account_name, token)
token = AzureSasCredential(self.generate_sas_token())
conn_string = 'AccountName={};SharedAccessSignature={};'.format(self.tables_storage_account_name, token.signature)

for service_type in SERVICES:
# Act
Expand All @@ -253,10 +251,8 @@ def test_create_service_with_connection_string_sas(self):
# Assert
assert service is not None
assert service.account_name == self.tables_storage_account_name
assert service.url.startswith(
'https://' + self.tables_storage_account_name + '.table.core.windows.net')
assert service.url.endswith(token)
assert service.credential is None
assert service.url.startswith('https://' + self.tables_storage_account_name + '.table.core.windows.net')
assert isinstance(service.credential , AzureSasCredential)

def test_create_service_with_connection_string_cosmos(self):
# Arrange
Expand Down Expand Up @@ -380,8 +376,8 @@ def test_create_service_with_conn_str_succeeds_if_sec_with_primary(self):
assert service._primary_endpoint.startswith('https://www.mydomain.com')

def test_create_service_with_custom_account_endpoint_path(self):
token = self.generate_sas_token()
custom_account_url = "http://local-machine:11002/custom/account/path/" + token
token = AzureSasCredential(self.generate_sas_token())
custom_account_url = "http://local-machine:11002/custom/account/path/" + token.signature
for service_type in SERVICES.items():
conn_string = 'DefaultEndpointsProtocol=http;AccountName={};AccountKey={};TableEndpoint={};'.format(
self.tables_storage_account_name, self.tables_primary_storage_account_key, custom_account_url)
Expand All @@ -408,7 +404,7 @@ def test_create_service_with_custom_account_endpoint_path(self):
assert service._primary_hostname == 'local-machine:11002/custom/account/path'
assert service.url.startswith('http://local-machine:11002/custom/account/path')

service = TableClient.from_table_url("http://local-machine:11002/custom/account/path/foo" + token)
service = TableClient.from_table_url("http://local-machine:11002/custom/account/path/foo" + token.signature)
assert service.account_name == "custom"
assert service.table_name == "foo"
assert service.credential == None
Expand Down Expand Up @@ -480,17 +476,18 @@ def test_closing_pipeline_client_simple(self):
self.account_url(self.tables_storage_account_name, "table"), credential=self.credential, table_name='table')
service.close()

@pytest.mark.skip("HTTP prefix does not raise an error")
def test_create_service_with_token_and_http(self):
for service_type in SERVICES:

with pytest.raises(ValueError):
url = self.account_url(self.tables_storage_account_name, "table").replace('https', 'http')
service_type(url, credential=self.generate_fake_token(), table_name='foo')
service_type(url, credential=AzureSasCredential("fake_sas_credential"), table_name='foo')

def test_create_service_with_token(self):
url = self.account_url(self.tables_storage_account_name, "table")
suffix = '.table.core.windows.net'
self.token_credential = self.generate_fake_token()
self.token_credential = AzureSasCredential("fake_sas_credential")

service = TableClient(url, credential=self.token_credential, table_name='foo')

Expand All @@ -500,7 +497,6 @@ def test_create_service_with_token(self):
assert service.url.startswith('https://' + self.tables_storage_account_name + suffix)
assert service.credential == self.token_credential
assert not hasattr(service.credential, 'account_key')
assert hasattr(service.credential, 'get_token')

service = TableServiceClient(url, credential=self.token_credential, table_name='foo')

Expand All @@ -510,7 +506,6 @@ def test_create_service_with_token(self):
assert service.url.startswith('https://' + self.tables_storage_account_name + suffix)
assert service.credential == self.token_credential
assert not hasattr(service.credential, 'account_key')
assert hasattr(service.credential, 'get_token')

def test_create_client_with_api_version(self):
url = self.account_url(self.tables_storage_account_name, "table")
Expand Down
Loading

0 comments on commit 0e79a8c

Please sign in to comment.