From 0d70e9c5617eaea6071316f6e5d161ff2df37da5 Mon Sep 17 00:00:00 2001 From: Danny Hermes Date: Sat, 31 Jan 2015 16:50:09 -0800 Subject: [PATCH] Move generate_signed_url to standalone function in credentials. This was **very** out of place as a method on Connection in storage. --- gcloud/credentials.py | 170 +++++++++++++++++ gcloud/storage/blob.py | 12 +- gcloud/storage/connection.py | 162 ---------------- gcloud/storage/test_blob.py | 77 ++++++-- gcloud/storage/test_connection.py | 298 ----------------------------- gcloud/test_credentials.py | 301 ++++++++++++++++++++++++++++++ 6 files changed, 541 insertions(+), 479 deletions(-) diff --git a/gcloud/credentials.py b/gcloud/credentials.py index 58f2daeae02f..16b5b159b4e9 100644 --- a/gcloud/credentials.py +++ b/gcloud/credentials.py @@ -14,7 +14,19 @@ """A simple wrapper around the OAuth2 credentials library.""" +import base64 +import calendar +import datetime +import urllib +import six + +from Crypto.Hash import SHA256 +from Crypto.PublicKey import RSA +from Crypto.Signature import PKCS1_v1_5 from oauth2client import client +from oauth2client import crypt +from oauth2client import service_account +import pytz def get_credentials(): @@ -91,3 +103,161 @@ def get_for_service_account_p12(client_email, private_key_path, scope=None): service_account_name=client_email, private_key=open(private_key_path, 'rb').read(), scope=scope) + + +def _get_pem_key(credentials): + """Gets RSA key for a PEM payload from a credentials object. + + :type credentials: :class:`client.SignedJwtAssertionCredentials`, + :class:`service_account._ServiceAccountCredentials` + :param credentials: The credentials used to create an RSA key + for signing text. + + :rtype: :class:`Crypto.PublicKey.RSA._RSAobj` + :returns: An RSA object used to sign text. + :raises: `TypeError` if `credentials` is the wrong type. + """ + if isinstance(credentials, client.SignedJwtAssertionCredentials): + # Take our PKCS12 (.p12) key and make it into a RSA key we can use. + pem_text = crypt.pkcs12_key_as_pem(credentials.private_key, + credentials.private_key_password) + elif isinstance(credentials, service_account._ServiceAccountCredentials): + pem_text = credentials._private_key_pkcs8_text + else: + raise TypeError((credentials, + 'not a valid service account credentials type')) + + return RSA.importKey(pem_text) + + +def _get_signed_query_params(credentials, expiration, signature_string): + """Gets query parameters for creating a signed URL. + + :type credentials: :class:`client.SignedJwtAssertionCredentials`, + :class:`service_account._ServiceAccountCredentials` + :param credentials: The credentials used to create an RSA key + for signing text. + + :type expiration: int or long + :param expiration: When the signed URL should expire. + + :type signature_string: string + :param signature_string: The string to be signed by the credentials. + + :rtype: dict + :returns: Query parameters matching the signing credentials with a + signed payload. + """ + pem_key = _get_pem_key(credentials) + # Sign the string with the RSA key. + signer = PKCS1_v1_5.new(pem_key) + signature_hash = SHA256.new(signature_string) + signature_bytes = signer.sign(signature_hash) + signature = base64.b64encode(signature_bytes) + + if isinstance(credentials, client.SignedJwtAssertionCredentials): + service_account_name = credentials.service_account_name + elif isinstance(credentials, service_account._ServiceAccountCredentials): + service_account_name = credentials._service_account_email + # We know one of the above must occur since `_get_pem_key` fails if not. + return { + 'GoogleAccessId': service_account_name, + 'Expires': str(expiration), + 'Signature': signature, + } + + +def _utcnow(): # pragma: NO COVER testing replaces + """Returns current time as UTC datetime. + + NOTE: on the module namespace so tests can replace it. + """ + return datetime.datetime.utcnow() + + +def _get_expiration_seconds(expiration): + """Convert 'expiration' to a number of seconds in the future. + + :type expiration: int, long, datetime.datetime, datetime.timedelta + :param expiration: When the signed URL should expire. + + :rtype: int + :returns: a timestamp as an absolute number of seconds. + """ + # If it's a timedelta, add it to `now` in UTC. + if isinstance(expiration, datetime.timedelta): + now = _utcnow().replace(tzinfo=pytz.utc) + expiration = now + expiration + + # If it's a datetime, convert to a timestamp. + if isinstance(expiration, datetime.datetime): + # Make sure the timezone on the value is UTC + # (either by converting or replacing the value). + if expiration.tzinfo: + expiration = expiration.astimezone(pytz.utc) + else: + expiration = expiration.replace(tzinfo=pytz.utc) + + # Turn the datetime into a timestamp (seconds, not microseconds). + expiration = int(calendar.timegm(expiration.timetuple())) + + if not isinstance(expiration, six.integer_types): + raise TypeError('Expected an integer timestamp, datetime, or ' + 'timedelta. Got %s' % type(expiration)) + return expiration + + +def generate_signed_url(credentials, resource, expiration, + api_access_endpoint='', + method='GET', content_md5=None, + content_type=None): + """Generate signed URL to provide query-string auth'n to a resource. + + :type credentials: :class:`oauth2client.appengine.AppAssertionCredentials` + :param credentials: Credentials object with an associated private key to + sign text. + + :type resource: string + :param resource: A pointer to a specific resource + (typically, ``/bucket-name/path/to/blob.txt``). + + :type expiration: int, long, datetime.datetime, datetime.timedelta + :param expiration: When the signed URL should expire. + + :type api_access_endpoint: string + :param api_access_endpoint: Optional URI base. Defaults to empty string. + + :type method: string + :param method: The HTTP verb that will be used when requesting the URL. + + :type content_md5: string + :param content_md5: The MD5 hash of the object referenced by + ``resource``. + + :type content_type: string + :param content_type: The content type of the object referenced by + ``resource``. + + :rtype: string + :returns: A signed URL you can use to access the resource + until expiration. + """ + expiration = _get_expiration_seconds(expiration) + + # Generate the string to sign. + signature_string = '\n'.join([ + method, + content_md5 or '', + content_type or '', + str(expiration), + resource]) + + # Set the right query parameters. + query_params = _get_signed_query_params(credentials, + expiration, + signature_string) + + # Return the built URL. + return '{endpoint}{resource}?{querystring}'.format( + endpoint=api_access_endpoint, resource=resource, + querystring=urllib.urlencode(query_params)) diff --git a/gcloud/storage/blob.py b/gcloud/storage/blob.py index 44f35056ae3c..31927bed640d 100644 --- a/gcloud/storage/blob.py +++ b/gcloud/storage/blob.py @@ -25,11 +25,15 @@ from _gcloud_vendor.apitools.base.py import http_wrapper from _gcloud_vendor.apitools.base.py import transfer +from gcloud.credentials import generate_signed_url from gcloud.storage._helpers import _PropertyMixin from gcloud.storage._helpers import _scalar_property from gcloud.storage.acl import ObjectACL +_API_ACCESS_ENDPOINT = 'https://storage.googleapis.com' + + class Blob(_PropertyMixin): """A wrapper around Cloud Storage's concept of an ``Object``.""" @@ -157,9 +161,11 @@ def generate_signed_url(self, expiration, method='GET'): resource = '/{bucket_name}/{quoted_name}'.format( bucket_name=self.bucket.name, quoted_name=urllib.quote(self.name, safe='')) - return self.connection.generate_signed_url(resource=resource, - expiration=expiration, - method=method) + + return generate_signed_url( + self.connection.credentials, resource=resource, + api_access_endpoint=_API_ACCESS_ENDPOINT, + expiration=expiration, method=method) def exists(self): """Determines whether or not this blob exists. diff --git a/gcloud/storage/connection.py b/gcloud/storage/connection.py index f473e5ebca43..e9b5ecfae8a5 100644 --- a/gcloud/storage/connection.py +++ b/gcloud/storage/connection.py @@ -14,20 +14,9 @@ """Create / interact with gcloud storage connections.""" -import base64 -import calendar -import datetime import json import urllib -from Crypto.Hash import SHA256 -from Crypto.PublicKey import RSA -from Crypto.Signature import PKCS1_v1_5 -from oauth2client import client -from oauth2client import crypt -from oauth2client import service_account -import pytz - from gcloud.connection import Connection as _Base from gcloud.exceptions import make_exception from gcloud.exceptions import NotFound @@ -36,76 +25,6 @@ import six -def _utcnow(): # pragma: NO COVER testing replaces - """Returns current time as UTC datetime. - - NOTE: on the module namespace so tests can replace it. - """ - return datetime.datetime.utcnow() - - -def _get_pem_key(credentials): - """Gets RSA key for a PEM payload from a credentials object. - - :type credentials: :class:`client.SignedJwtAssertionCredentials`, - :class:`service_account._ServiceAccountCredentials` - :param credentials: The credentials used to create an RSA key - for signing text. - - :rtype: :class:`Crypto.PublicKey.RSA._RSAobj` - :returns: An RSA object used to sign text. - :raises: `TypeError` if `credentials` is the wrong type. - """ - if isinstance(credentials, client.SignedJwtAssertionCredentials): - # Take our PKCS12 (.p12) key and make it into a RSA key we can use. - pem_text = crypt.pkcs12_key_as_pem(credentials.private_key, - credentials.private_key_password) - elif isinstance(credentials, service_account._ServiceAccountCredentials): - pem_text = credentials._private_key_pkcs8_text - else: - raise TypeError((credentials, - 'not a valid service account credentials type')) - - return RSA.importKey(pem_text) - - -def _get_signed_query_params(credentials, expiration, signature_string): - """Gets query parameters for creating a signed URL. - - :type credentials: :class:`client.SignedJwtAssertionCredentials`, - :class:`service_account._ServiceAccountCredentials` - :param credentials: The credentials used to create an RSA key - for signing text. - - :type expiration: int or long - :param expiration: When the signed URL should expire. - - :type signature_string: string - :param signature_string: The string to be signed by the credentials. - - :rtype: dict - :returns: Query parameters matching the signing credentials with a - signed payload. - """ - pem_key = _get_pem_key(credentials) - # Sign the string with the RSA key. - signer = PKCS1_v1_5.new(pem_key) - signature_hash = SHA256.new(signature_string) - signature_bytes = signer.sign(signature_hash) - signature = base64.b64encode(signature_bytes) - - if isinstance(credentials, client.SignedJwtAssertionCredentials): - service_account_name = credentials.service_account_name - elif isinstance(credentials, service_account._ServiceAccountCredentials): - service_account_name = credentials._service_account_email - # We know one of the above must occur since `_get_pem_key` fails if not. - return { - 'GoogleAccessId': service_account_name, - 'Expires': str(expiration), - 'Signature': signature, - } - - class Connection(_Base): """A connection to Google Cloud Storage via the JSON REST API. @@ -155,8 +74,6 @@ class Connection(_Base): API_URL_TEMPLATE = '{api_base_url}/storage/{api_version}{path}' """A template for the URL of a particular API call.""" - API_ACCESS_ENDPOINT = 'https://storage.googleapis.com' - def __init__(self, project, *args, **kwargs): """:type project: string @@ -507,53 +424,6 @@ def new_bucket(self, bucket): raise TypeError('Invalid bucket: %s' % bucket) - def generate_signed_url(self, resource, expiration, - method='GET', content_md5=None, - content_type=None): - """Generate signed URL to provide query-string auth'n to a resource. - - :type resource: string - :param resource: A pointer to a specific resource - (typically, ``/bucket-name/path/to/blob.txt``). - - :type expiration: int, long, datetime.datetime, datetime.timedelta - :param expiration: When the signed URL should expire. - - :type method: string - :param method: The HTTP verb that will be used when requesting the URL. - - :type content_md5: string - :param content_md5: The MD5 hash of the object referenced by - ``resource``. - - :type content_type: string - :param content_type: The content type of the object referenced by - ``resource``. - - :rtype: string - :returns: A signed URL you can use to access the resource - until expiration. - """ - expiration = _get_expiration_seconds(expiration) - - # Generate the string to sign. - signature_string = '\n'.join([ - method, - content_md5 or '', - content_type or '', - str(expiration), - resource]) - - # Set the right query parameters. - query_params = _get_signed_query_params(self.credentials, - expiration, - signature_string) - - # Return the built URL. - return '{endpoint}{resource}?{querystring}'.format( - endpoint=self.API_ACCESS_ENDPOINT, resource=resource, - querystring=urllib.urlencode(query_params)) - class _BucketIterator(Iterator): """An iterator listing all buckets. @@ -577,35 +447,3 @@ def get_items_from_response(self, response): """ for item in response.get('items', []): yield Bucket(properties=item, connection=self.connection) - - -def _get_expiration_seconds(expiration): - """Convert 'expiration' to a number of seconds in the future. - - :type expiration: int, long, datetime.datetime, datetime.timedelta - :param expiration: When the signed URL should expire. - - :rtype: int - :returns: a timestamp as an absolute number of seconds. - """ - # If it's a timedelta, add it to `now` in UTC. - if isinstance(expiration, datetime.timedelta): - now = _utcnow().replace(tzinfo=pytz.utc) - expiration = now + expiration - - # If it's a datetime, convert to a timestamp. - if isinstance(expiration, datetime.datetime): - # Make sure the timezone on the value is UTC - # (either by converting or replacing the value). - if expiration.tzinfo: - expiration = expiration.astimezone(pytz.utc) - else: - expiration = expiration.replace(tzinfo=pytz.utc) - - # Turn the datetime into a timestamp (seconds, not microseconds). - expiration = int(calendar.timegm(expiration.timetuple())) - - if not isinstance(expiration, six.integer_types): - raise TypeError('Expected an integer timestamp, datetime, or ' - 'timedelta. Got %s' % type(expiration)) - return expiration diff --git a/gcloud/storage/test_blob.py b/gcloud/storage/test_blob.py index bbb7147bebd5..69e676893356 100644 --- a/gcloud/storage/test_blob.py +++ b/gcloud/storage/test_blob.py @@ -113,6 +113,9 @@ def test_public_url_w_slash_in_name(self): 'http://commondatastorage.googleapis.com/name/parent%2Fchild') def test_generate_signed_url_w_default_method(self): + from gcloud._testing import _Monkey + from gcloud.storage import blob as MUT + BLOB_NAME = 'blob-name' EXPIRATION = '2014-10-16T20:34:37Z' connection = _Connection() @@ -120,12 +123,25 @@ def test_generate_signed_url_w_default_method(self): blob = self._makeOne(bucket, BLOB_NAME) URI = ('http://example.com/abucket/a-blob-name?Signature=DEADBEEF' '&Expiration=2014-10-16T20:34:37Z') - self.assertEqual(blob.generate_signed_url(EXPIRATION), URI) + + SIGNER = _Signer() + with _Monkey(MUT, generate_signed_url=SIGNER): + self.assertEqual(blob.generate_signed_url(EXPIRATION), URI) + PATH = '/name/%s' % (BLOB_NAME,) - self.assertEqual(connection._signed, - [(PATH, EXPIRATION, {'method': 'GET'})]) + EXPECTED_ARGS = (_Connection.credentials,) + EXPECTED_KWARGS = { + 'api_access_endpoint': 'https://storage.googleapis.com', + 'expiration': EXPIRATION, + 'method': 'GET', + 'resource': PATH, + } + self.assertEqual(SIGNER._signed, [(EXPECTED_ARGS, EXPECTED_KWARGS)]) def test_generate_signed_url_w_slash_in_name(self): + from gcloud._testing import _Monkey + from gcloud.storage import blob as MUT + BLOB_NAME = 'parent/child' EXPIRATION = '2014-10-16T20:34:37Z' connection = _Connection() @@ -133,12 +149,24 @@ def test_generate_signed_url_w_slash_in_name(self): blob = self._makeOne(bucket, BLOB_NAME) URI = ('http://example.com/abucket/a-blob-name?Signature=DEADBEEF' '&Expiration=2014-10-16T20:34:37Z') - self.assertEqual(blob.generate_signed_url(EXPIRATION), URI) - self.assertEqual(connection._signed, - [('/name/parent%2Fchild', - EXPIRATION, {'method': 'GET'})]) + + SIGNER = _Signer() + with _Monkey(MUT, generate_signed_url=SIGNER): + self.assertEqual(blob.generate_signed_url(EXPIRATION), URI) + + EXPECTED_ARGS = (_Connection.credentials,) + EXPECTED_KWARGS = { + 'api_access_endpoint': 'https://storage.googleapis.com', + 'expiration': EXPIRATION, + 'method': 'GET', + 'resource': '/name/parent%2Fchild', + } + self.assertEqual(SIGNER._signed, [(EXPECTED_ARGS, EXPECTED_KWARGS)]) def test_generate_signed_url_w_explicit_method(self): + from gcloud._testing import _Monkey + from gcloud.storage import blob as MUT + BLOB_NAME = 'blob-name' EXPIRATION = '2014-10-16T20:34:37Z' connection = _Connection() @@ -146,11 +174,21 @@ def test_generate_signed_url_w_explicit_method(self): blob = self._makeOne(bucket, BLOB_NAME) URI = ('http://example.com/abucket/a-blob-name?Signature=DEADBEEF' '&Expiration=2014-10-16T20:34:37Z') - self.assertEqual(blob.generate_signed_url(EXPIRATION, method='POST'), - URI) + + SIGNER = _Signer() + with _Monkey(MUT, generate_signed_url=SIGNER): + self.assertEqual( + blob.generate_signed_url(EXPIRATION, method='POST'), URI) + PATH = '/name/%s' % (BLOB_NAME,) - self.assertEqual(connection._signed, - [(PATH, EXPIRATION, {'method': 'POST'})]) + EXPECTED_ARGS = (_Connection.credentials,) + EXPECTED_KWARGS = { + 'api_access_endpoint': 'https://storage.googleapis.com', + 'expiration': EXPIRATION, + 'method': 'POST', + 'resource': PATH, + } + self.assertEqual(SIGNER._signed, [(EXPECTED_ARGS, EXPECTED_KWARGS)]) def test_exists_miss(self): NONESUCH = 'nonesuch' @@ -825,6 +863,7 @@ class _Connection(_Responder): API_BASE_URL = 'http://example.com' USER_AGENT = 'testing 1.2.3' + credentials = object() def __init__(self, *responses): super(_Connection, self).__init__(*responses) @@ -846,11 +885,6 @@ def build_api_url(self, path, query_params=None, scheme, netloc, _, _, _ = urlsplit(api_base_url) return urlunsplit((scheme, netloc, path, qs, '')) - def generate_signed_url(self, resource, expiration, **kw): - self._signed.append((resource, expiration, kw)) - return ('http://example.com/abucket/a-blob-name?Signature=DEADBEEF' - '&Expiration=%s' % expiration) - class _HTTP(_Responder): @@ -879,3 +913,14 @@ def copy_blob(self, blob, destination_bucket, new_name): def delete_blob(self, blob): del self._blobs[blob.name] self._deleted.append(blob.name) + + +class _Signer(object): + + def __init__(self): + self._signed = [] + + def __call__(self, *args, **kwargs): + self._signed.append((args, kwargs)) + return ('http://example.com/abucket/a-blob-name?Signature=DEADBEEF' + '&Expiration=%s' % kwargs.get('expiration')) diff --git a/gcloud/storage/test_connection.py b/gcloud/storage/test_connection.py index 1e10f1928cd4..ecdc0c814e2a 100644 --- a/gcloud/storage/test_connection.py +++ b/gcloud/storage/test_connection.py @@ -602,43 +602,6 @@ def test_new_bucket_w_invalid(self): conn = self._makeOne(PROJECT) self.assertRaises(TypeError, conn.new_bucket, object()) - def test_generate_signed_url_w_expiration_int(self): - import base64 - import urlparse - from gcloud._testing import _Monkey - from gcloud.test_credentials import _Credentials - from gcloud.storage import connection as MUT - - ENDPOINT = 'http://api.example.com' - RESOURCE = '/name/path' - PROJECT = 'project' - SIGNED = base64.b64encode('DEADBEEF') - conn = self._makeOne(PROJECT, _Credentials()) - conn.API_ACCESS_ENDPOINT = ENDPOINT - - def _get_signed_query_params(*args): - credentials, expiration = args[:2] - return { - 'GoogleAccessId': credentials.service_account_name, - 'Expires': str(expiration), - 'Signature': SIGNED, - } - - with _Monkey(MUT, _get_signed_query_params=_get_signed_query_params): - url = conn.generate_signed_url(RESOURCE, 1000) - - scheme, netloc, path, qs, frag = urlparse.urlsplit(url) - self.assertEqual(scheme, 'http') - self.assertEqual(netloc, 'api.example.com') - self.assertEqual(path, RESOURCE) - params = urlparse.parse_qs(qs) - self.assertEqual(len(params), 3) - self.assertEqual(params['Signature'], [SIGNED]) - self.assertEqual(params['Expires'], ['1000']) - self.assertEqual(params['GoogleAccessId'], - [_Credentials.service_account_name]) - self.assertEqual(frag, '') - class Test__BucketIterator(unittest2.TestCase): @@ -676,225 +639,6 @@ def test_get_items_from_response_non_empty(self): self.assertEqual(bucket.name, BLOB_NAME) -class Test__get_expiration_seconds(unittest2.TestCase): - - def _callFUT(self, expiration): - from gcloud.storage.connection import _get_expiration_seconds - - return _get_expiration_seconds(expiration) - - def _utc_seconds(self, when): - import calendar - - return int(calendar.timegm(when.timetuple())) - - def test__get_expiration_seconds_w_invalid(self): - self.assertRaises(TypeError, self._callFUT, object()) - self.assertRaises(TypeError, self._callFUT, None) - - def test__get_expiration_seconds_w_int(self): - self.assertEqual(self._callFUT(123), 123) - - def test__get_expiration_seconds_w_long(self): - try: - long - except NameError: # pragma: NO COVER Py3K - pass - else: - self.assertEqual(self._callFUT(long(123)), 123) - - def test__get_expiration_w_naive_datetime(self): - import datetime - - expiration_no_tz = datetime.datetime(2004, 8, 19, 0, 0, 0, 0) - utc_seconds = self._utc_seconds(expiration_no_tz) - self.assertEqual(self._callFUT(expiration_no_tz), utc_seconds) - - def test__get_expiration_w_utc_datetime(self): - import datetime - import pytz - - expiration_utc = datetime.datetime(2004, 8, 19, 0, 0, 0, 0, pytz.utc) - utc_seconds = self._utc_seconds(expiration_utc) - self.assertEqual(self._callFUT(expiration_utc), utc_seconds) - - def test__get_expiration_w_other_zone_datetime(self): - import datetime - import pytz - - zone = pytz.timezone('CET') - expiration_other = datetime.datetime(2004, 8, 19, 0, 0, 0, 0, zone) - utc_seconds = self._utc_seconds(expiration_other) - cet_seconds = utc_seconds - (60 * 60) # CET one hour earlier than UTC - self.assertEqual(self._callFUT(expiration_other), cet_seconds) - - def test__get_expiration_seconds_w_timedelta_seconds(self): - import datetime - from gcloud.storage import connection - from gcloud._testing import _Monkey - - dummy_utcnow = datetime.datetime(2004, 8, 19, 0, 0, 0, 0) - utc_seconds = self._utc_seconds(dummy_utcnow) - expiration_as_delta = datetime.timedelta(seconds=10) - - with _Monkey(connection, _utcnow=lambda: dummy_utcnow): - result = self._callFUT(expiration_as_delta) - - self.assertEqual(result, utc_seconds + 10) - - def test__get_expiration_seconds_w_timedelta_days(self): - import datetime - from gcloud.storage import connection - from gcloud._testing import _Monkey - - dummy_utcnow = datetime.datetime(2004, 8, 19, 0, 0, 0, 0) - utc_seconds = self._utc_seconds(dummy_utcnow) - expiration_as_delta = datetime.timedelta(days=1) - - with _Monkey(connection, _utcnow=lambda: dummy_utcnow): - result = self._callFUT(expiration_as_delta) - - self.assertEqual(result, utc_seconds + 86400) - - -class Test__get_pem_key(unittest2.TestCase): - - def _callFUT(self, credentials): - from gcloud.storage.connection import _get_pem_key - return _get_pem_key(credentials) - - def test_bad_argument(self): - self.assertRaises(TypeError, self._callFUT, None) - - def test_signed_jwt_for_p12(self): - import base64 - from oauth2client import client - from gcloud._testing import _Monkey - from gcloud.storage import connection as MUT - - scopes = [] - PRIVATE_KEY = 'dummy_private_key_text' - credentials = client.SignedJwtAssertionCredentials( - 'dummy_service_account_name', PRIVATE_KEY, scopes) - crypt = _Crypt() - rsa = _RSA() - with _Monkey(MUT, crypt=crypt, RSA=rsa): - result = self._callFUT(credentials) - - self.assertEqual(crypt._private_key_text, - base64.b64encode(PRIVATE_KEY)) - self.assertEqual(crypt._private_key_password, 'notasecret') - self.assertEqual(result, 'imported:__PEM__') - - def test_service_account_via_json_key(self): - from oauth2client import service_account - from gcloud._testing import _Monkey - from gcloud.storage import connection as MUT - - scopes = [] - - PRIVATE_TEXT = 'dummy_private_key_pkcs8_text' - - def _get_private_key(private_key_pkcs8_text): - return private_key_pkcs8_text - - with _Monkey(service_account, _get_private_key=_get_private_key): - credentials = service_account._ServiceAccountCredentials( - 'dummy_service_account_id', 'dummy_service_account_email', - 'dummy_private_key_id', PRIVATE_TEXT, scopes) - - rsa = _RSA() - with _Monkey(MUT, RSA=rsa): - result = self._callFUT(credentials) - - expected = 'imported:%s' % (PRIVATE_TEXT,) - self.assertEqual(result, expected) - - -class Test__get_signed_query_params(unittest2.TestCase): - - def _callFUT(self, credentials, expiration, signature_string): - from gcloud.storage.connection import _get_signed_query_params - return _get_signed_query_params(credentials, expiration, - signature_string) - - def test_wrong_type(self): - from gcloud._testing import _Monkey - from gcloud.storage import connection as MUT - - pkcs_v1_5 = _PKCS1_v1_5() - rsa = _RSA() - sha256 = _SHA256() - - def _get_pem_key(credentials): - return credentials - - BAD_CREDENTIALS = None - EXPIRATION = '100' - SIGNATURE_STRING = 'dummy_signature' - with _Monkey(MUT, RSA=rsa, PKCS1_v1_5=pkcs_v1_5, - SHA256=sha256, _get_pem_key=_get_pem_key): - self.assertRaises(NameError, self._callFUT, - BAD_CREDENTIALS, EXPIRATION, SIGNATURE_STRING) - - def _run_test_with_credentials(self, credentials, account_name): - import base64 - from gcloud._testing import _Monkey - from gcloud.storage import connection as MUT - - crypt = _Crypt() - pkcs_v1_5 = _PKCS1_v1_5() - rsa = _RSA() - sha256 = _SHA256() - - EXPIRATION = '100' - SIGNATURE_STRING = 'dummy_signature' - with _Monkey(MUT, crypt=crypt, RSA=rsa, PKCS1_v1_5=pkcs_v1_5, - SHA256=sha256): - result = self._callFUT(credentials, EXPIRATION, SIGNATURE_STRING) - - if crypt._pkcs12_key_as_pem_called: - self.assertEqual(crypt._private_key_text, - base64.b64encode('dummy_private_key_text')) - self.assertEqual(crypt._private_key_password, 'notasecret') - self.assertEqual(sha256._signature_string, SIGNATURE_STRING) - SIGNED = base64.b64encode('DEADBEEF') - expected_query = { - 'Expires': EXPIRATION, - 'GoogleAccessId': account_name, - 'Signature': SIGNED, - } - self.assertEqual(result, expected_query) - - def test_signed_jwt_for_p12(self): - from oauth2client import client - - scopes = [] - ACCOUNT_NAME = 'dummy_service_account_name' - credentials = client.SignedJwtAssertionCredentials( - ACCOUNT_NAME, 'dummy_private_key_text', scopes) - self._run_test_with_credentials(credentials, ACCOUNT_NAME) - - def test_service_account_via_json_key(self): - from oauth2client import service_account - from gcloud._testing import _Monkey - - scopes = [] - - PRIVATE_TEXT = 'dummy_private_key_pkcs8_text' - - def _get_private_key(private_key_pkcs8_text): - return private_key_pkcs8_text - - ACCOUNT_NAME = 'dummy_service_account_email' - with _Monkey(service_account, _get_private_key=_get_private_key): - credentials = service_account._ServiceAccountCredentials( - 'dummy_service_account_id', ACCOUNT_NAME, - 'dummy_private_key_id', PRIVATE_TEXT, scopes) - - self._run_test_with_credentials(credentials, ACCOUNT_NAME) - - class Http(object): _called_with = None @@ -907,45 +651,3 @@ def __init__(self, headers, content): def request(self, **kw): self._called_with = kw return self._response, self._content - - -class _Crypt(object): - - _pkcs12_key_as_pem_called = False - - def pkcs12_key_as_pem(self, private_key_text, private_key_password): - self._pkcs12_key_as_pem_called = True - self._private_key_text = private_key_text - self._private_key_password = private_key_password - return '__PEM__' - - -class _RSA(object): - - _imported = None - - def importKey(self, pem): - self._imported = pem - return 'imported:%s' % pem - - -class _PKCS1_v1_5(object): - - _pem_key = _signature_hash = None - - def new(self, pem_key): - self._pem_key = pem_key - return self - - def sign(self, signature_hash): - self._signature_hash = signature_hash - return 'DEADBEEF' - - -class _SHA256(object): - - _signature_string = None - - def new(self, signature_string): - self._signature_string = signature_string - return self diff --git a/gcloud/test_credentials.py b/gcloud/test_credentials.py index 50cb246c2e05..0f0176bded11 100644 --- a/gcloud/test_credentials.py +++ b/gcloud/test_credentials.py @@ -61,6 +61,265 @@ def test_get_for_service_account_p12_w_scope(self): self.assertEqual(client._called_with, expected_called_with) +class Test_generate_signed_url(unittest2.TestCase): + + def _callFUT(self, *args, **kwargs): + from gcloud.credentials import generate_signed_url + return generate_signed_url(*args, **kwargs) + + def test_w_expiration_int(self): + import base64 + import urlparse + from gcloud._testing import _Monkey + from gcloud import credentials as MUT + + ENDPOINT = 'http://api.example.com' + RESOURCE = '/name/path' + SIGNED = base64.b64encode('DEADBEEF') + CREDENTIALS = _Credentials() + + def _get_signed_query_params(*args): + credentials, expiration = args[:2] + return { + 'GoogleAccessId': credentials.service_account_name, + 'Expires': str(expiration), + 'Signature': SIGNED, + } + + with _Monkey(MUT, _get_signed_query_params=_get_signed_query_params): + url = self._callFUT(CREDENTIALS, RESOURCE, 1000, + api_access_endpoint=ENDPOINT) + + scheme, netloc, path, qs, frag = urlparse.urlsplit(url) + self.assertEqual(scheme, 'http') + self.assertEqual(netloc, 'api.example.com') + self.assertEqual(path, RESOURCE) + params = urlparse.parse_qs(qs) + self.assertEqual(len(params), 3) + self.assertEqual(params['Signature'], [SIGNED]) + self.assertEqual(params['Expires'], ['1000']) + self.assertEqual(params['GoogleAccessId'], + [_Credentials.service_account_name]) + self.assertEqual(frag, '') + + +class Test__get_signed_query_params(unittest2.TestCase): + + def _callFUT(self, credentials, expiration, signature_string): + from gcloud.credentials import _get_signed_query_params + return _get_signed_query_params(credentials, expiration, + signature_string) + + def test_wrong_type(self): + from gcloud._testing import _Monkey + from gcloud import credentials as MUT + + pkcs_v1_5 = _PKCS1_v1_5() + rsa = _RSA() + sha256 = _SHA256() + + def _get_pem_key(credentials): + return credentials + + BAD_CREDENTIALS = None + EXPIRATION = '100' + SIGNATURE_STRING = 'dummy_signature' + with _Monkey(MUT, RSA=rsa, PKCS1_v1_5=pkcs_v1_5, + SHA256=sha256, _get_pem_key=_get_pem_key): + self.assertRaises(NameError, self._callFUT, + BAD_CREDENTIALS, EXPIRATION, SIGNATURE_STRING) + + def _run_test_with_credentials(self, credentials, account_name): + import base64 + from gcloud._testing import _Monkey + from gcloud import credentials as MUT + + crypt = _Crypt() + pkcs_v1_5 = _PKCS1_v1_5() + rsa = _RSA() + sha256 = _SHA256() + + EXPIRATION = '100' + SIGNATURE_STRING = 'dummy_signature' + with _Monkey(MUT, crypt=crypt, RSA=rsa, PKCS1_v1_5=pkcs_v1_5, + SHA256=sha256): + result = self._callFUT(credentials, EXPIRATION, SIGNATURE_STRING) + + if crypt._pkcs12_key_as_pem_called: + self.assertEqual(crypt._private_key_text, + base64.b64encode('dummy_private_key_text')) + self.assertEqual(crypt._private_key_password, 'notasecret') + self.assertEqual(sha256._signature_string, SIGNATURE_STRING) + SIGNED = base64.b64encode('DEADBEEF') + expected_query = { + 'Expires': EXPIRATION, + 'GoogleAccessId': account_name, + 'Signature': SIGNED, + } + self.assertEqual(result, expected_query) + + def test_signed_jwt_for_p12(self): + from oauth2client import client + + scopes = [] + ACCOUNT_NAME = 'dummy_service_account_name' + credentials = client.SignedJwtAssertionCredentials( + ACCOUNT_NAME, 'dummy_private_key_text', scopes) + self._run_test_with_credentials(credentials, ACCOUNT_NAME) + + def test_service_account_via_json_key(self): + from oauth2client import service_account + from gcloud._testing import _Monkey + + scopes = [] + + PRIVATE_TEXT = 'dummy_private_key_pkcs8_text' + + def _get_private_key(private_key_pkcs8_text): + return private_key_pkcs8_text + + ACCOUNT_NAME = 'dummy_service_account_email' + with _Monkey(service_account, _get_private_key=_get_private_key): + credentials = service_account._ServiceAccountCredentials( + 'dummy_service_account_id', ACCOUNT_NAME, + 'dummy_private_key_id', PRIVATE_TEXT, scopes) + + self._run_test_with_credentials(credentials, ACCOUNT_NAME) + + +class Test__get_pem_key(unittest2.TestCase): + + def _callFUT(self, credentials): + from gcloud.credentials import _get_pem_key + return _get_pem_key(credentials) + + def test_bad_argument(self): + self.assertRaises(TypeError, self._callFUT, None) + + def test_signed_jwt_for_p12(self): + import base64 + from oauth2client import client + from gcloud._testing import _Monkey + from gcloud import credentials as MUT + + scopes = [] + PRIVATE_KEY = 'dummy_private_key_text' + credentials = client.SignedJwtAssertionCredentials( + 'dummy_service_account_name', PRIVATE_KEY, scopes) + crypt = _Crypt() + rsa = _RSA() + with _Monkey(MUT, crypt=crypt, RSA=rsa): + result = self._callFUT(credentials) + + self.assertEqual(crypt._private_key_text, + base64.b64encode(PRIVATE_KEY)) + self.assertEqual(crypt._private_key_password, 'notasecret') + self.assertEqual(result, 'imported:__PEM__') + + def test_service_account_via_json_key(self): + from oauth2client import service_account + from gcloud._testing import _Monkey + from gcloud import credentials as MUT + + scopes = [] + + PRIVATE_TEXT = 'dummy_private_key_pkcs8_text' + + def _get_private_key(private_key_pkcs8_text): + return private_key_pkcs8_text + + with _Monkey(service_account, _get_private_key=_get_private_key): + credentials = service_account._ServiceAccountCredentials( + 'dummy_service_account_id', 'dummy_service_account_email', + 'dummy_private_key_id', PRIVATE_TEXT, scopes) + + rsa = _RSA() + with _Monkey(MUT, RSA=rsa): + result = self._callFUT(credentials) + + expected = 'imported:%s' % (PRIVATE_TEXT,) + self.assertEqual(result, expected) + + +class Test__get_expiration_seconds(unittest2.TestCase): + + def _callFUT(self, expiration): + from gcloud.credentials import _get_expiration_seconds + return _get_expiration_seconds(expiration) + + def _utc_seconds(self, when): + import calendar + return int(calendar.timegm(when.timetuple())) + + def test_w_invalid(self): + self.assertRaises(TypeError, self._callFUT, object()) + self.assertRaises(TypeError, self._callFUT, None) + + def test_w_int(self): + self.assertEqual(self._callFUT(123), 123) + + def test_w_long(self): + try: + long + except NameError: # pragma: NO COVER Py3K + pass + else: + self.assertEqual(self._callFUT(long(123)), 123) + + def test_w_naive_datetime(self): + import datetime + + expiration_no_tz = datetime.datetime(2004, 8, 19, 0, 0, 0, 0) + utc_seconds = self._utc_seconds(expiration_no_tz) + self.assertEqual(self._callFUT(expiration_no_tz), utc_seconds) + + def test_w_utc_datetime(self): + import datetime + import pytz + + expiration_utc = datetime.datetime(2004, 8, 19, 0, 0, 0, 0, pytz.utc) + utc_seconds = self._utc_seconds(expiration_utc) + self.assertEqual(self._callFUT(expiration_utc), utc_seconds) + + def test_w_other_zone_datetime(self): + import datetime + import pytz + + zone = pytz.timezone('CET') + expiration_other = datetime.datetime(2004, 8, 19, 0, 0, 0, 0, zone) + utc_seconds = self._utc_seconds(expiration_other) + cet_seconds = utc_seconds - (60 * 60) # CET one hour earlier than UTC + self.assertEqual(self._callFUT(expiration_other), cet_seconds) + + def test_w_timedelta_seconds(self): + import datetime + from gcloud._testing import _Monkey + from gcloud import credentials as MUT + + dummy_utcnow = datetime.datetime(2004, 8, 19, 0, 0, 0, 0) + utc_seconds = self._utc_seconds(dummy_utcnow) + expiration_as_delta = datetime.timedelta(seconds=10) + + with _Monkey(MUT, _utcnow=lambda: dummy_utcnow): + result = self._callFUT(expiration_as_delta) + + self.assertEqual(result, utc_seconds + 10) + + def test_w_timedelta_days(self): + import datetime + from gcloud._testing import _Monkey + from gcloud import credentials as MUT + + dummy_utcnow = datetime.datetime(2004, 8, 19, 0, 0, 0, 0) + utc_seconds = self._utc_seconds(dummy_utcnow) + expiration_as_delta = datetime.timedelta(days=1) + + with _Monkey(MUT, _utcnow=lambda: dummy_utcnow): + result = self._callFUT(expiration_as_delta) + + self.assertEqual(result, utc_seconds + 86400) + + class _Credentials(object): service_account_name = 'testing@example.com' @@ -85,3 +344,45 @@ def get_application_default(): def SignedJwtAssertionCredentials(self, **kw): self._called_with = kw return self._signed + + +class _Crypt(object): + + _pkcs12_key_as_pem_called = False + + def pkcs12_key_as_pem(self, private_key_text, private_key_password): + self._pkcs12_key_as_pem_called = True + self._private_key_text = private_key_text + self._private_key_password = private_key_password + return '__PEM__' + + +class _RSA(object): + + _imported = None + + def importKey(self, pem): + self._imported = pem + return 'imported:%s' % pem + + +class _PKCS1_v1_5(object): + + _pem_key = _signature_hash = None + + def new(self, pem_key): + self._pem_key = pem_key + return self + + def sign(self, signature_hash): + self._signature_hash = signature_hash + return 'DEADBEEF' + + +class _SHA256(object): + + _signature_string = None + + def new(self, signature_string): + self._signature_string = signature_string + return self