From e6ff48f4147690cccf4564deaa5a5c41651ea3ee Mon Sep 17 00:00:00 2001 From: Danny Hermes Date: Mon, 10 Aug 2015 14:39:16 -0700 Subject: [PATCH] Unifying all conversions to bytes. Also making sure the _urlsafe_b64encode returns bytes and updating dependent code with the change in return type. --- oauth2client/_helpers.py | 37 ++++++++++++++++++++++++++++----- oauth2client/_openssl_crypt.py | 16 ++++++-------- oauth2client/_pycrypto_crypt.py | 10 ++++----- oauth2client/client.py | 25 +++++++++++----------- oauth2client/crypt.py | 10 +++++---- oauth2client/devshell.py | 4 ++-- oauth2client/service_account.py | 14 +++++-------- tests/test__helpers.py | 23 ++++++++++++++++++-- tests/test_jwt.py | 2 +- 9 files changed, 90 insertions(+), 51 deletions(-) diff --git a/oauth2client/_helpers.py b/oauth2client/_helpers.py index 974d21023..0e20fcd2e 100644 --- a/oauth2client/_helpers.py +++ b/oauth2client/_helpers.py @@ -39,15 +39,42 @@ def _json_encode(data): return json.dumps(data, separators=(',', ':')) +def _to_bytes(value, encoding='ascii'): + """Converts a string value to bytes, if necessary. + + Unfortunately, ``six.b`` is insufficient for this task since in + Python2 it does not modify ``unicode`` objects. + + Args: + value: The string/bytes value to be converted. + encoding: The encoding to use to convert unicode to bytes. Defaults + to "ascii", which will not allow any characters from ordinals + larger than 127. Other useful values are "latin-1", which + which will only allows byte ordinals (up to 255) and "utf-8", + which will encode any unicode that needs to be. + + Returns: + The original value converted to bytes (if unicode) or as passed in + if it started out as bytes. + + Raises: + ValueError if the value could not be converted to bytes. + """ + result = value + if isinstance(value, six.text_type): + result = value.encode(encoding) + if not isinstance(result, six.binary_type): + raise ValueError('%r could not be converted to bytes' % (value,)) + return result + + def _urlsafe_b64encode(raw_bytes): - if isinstance(raw_bytes, six.text_type): - raw_bytes = raw_bytes.encode('utf-8') - return base64.urlsafe_b64encode(raw_bytes).decode('ascii').rstrip('=') + raw_bytes = _to_bytes(raw_bytes, encoding='utf-8') + return base64.urlsafe_b64encode(raw_bytes).rstrip(b'=') def _urlsafe_b64decode(b64string): # Guard against unicode strings, which base64 can't handle. - if isinstance(b64string, six.text_type): - b64string = b64string.encode('ascii') + b64string = _to_bytes(b64string) padded = b64string + b'=' * (4 - len(b64string) % 4) return base64.urlsafe_b64decode(padded) diff --git a/oauth2client/_openssl_crypt.py b/oauth2client/_openssl_crypt.py index 6da310924..9fcd996e8 100644 --- a/oauth2client/_openssl_crypt.py +++ b/oauth2client/_openssl_crypt.py @@ -18,6 +18,7 @@ from OpenSSL import crypto from oauth2client._helpers import _parse_pem_key +from oauth2client._helpers import _to_bytes class OpenSSLVerifier(object): @@ -44,10 +45,8 @@ def verify(self, message, signature): True if message was signed by the private key associated with the public key that this object was constructed with. """ - if isinstance(message, six.text_type): - message = message.encode('utf-8') - if isinstance(signature, six.text_type): - signature = signature.encode('utf-8') + message = _to_bytes(message, encoding='utf-8') + signature = _to_bytes(signature, encoding='utf-8') try: crypto.verify(self._pubkey, signature, message, 'sha256') return True @@ -96,8 +95,7 @@ def sign(self, message): Returns: string, The signature of the message for the given key. """ - if isinstance(message, six.text_type): - message = message.encode('utf-8') + message = _to_bytes(message, encoding='utf-8') return crypto.sign(self._key, message, 'sha256') @staticmethod @@ -118,8 +116,7 @@ def from_string(key, password=b'notasecret'): if parsed_pem_key: pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, parsed_pem_key) else: - if isinstance(password, six.text_type): - password = password.encode('utf-8') + password = _to_bytes(password, encoding='utf-8') pkey = crypto.load_pkcs12(key, password).get_privatekey() return OpenSSLSigner(pkey) @@ -135,8 +132,7 @@ def pkcs12_key_as_pem(private_key_text, private_key_password): String. PEM contents of ``private_key_text``. """ decoded_body = base64.b64decode(private_key_text) - if isinstance(private_key_password, six.text_type): - private_key_password = private_key_password.encode('ascii') + private_key_password = _to_bytes(private_key_password) pkcs12 = crypto.load_pkcs12(decoded_body, private_key_password) return crypto.dump_privatekey(crypto.FILETYPE_PEM, diff --git a/oauth2client/_pycrypto_crypt.py b/oauth2client/_pycrypto_crypt.py index 530eea22d..957643810 100644 --- a/oauth2client/_pycrypto_crypt.py +++ b/oauth2client/_pycrypto_crypt.py @@ -20,6 +20,7 @@ import six from oauth2client._helpers import _parse_pem_key +from oauth2client._helpers import _to_bytes from oauth2client._helpers import _urlsafe_b64decode @@ -46,8 +47,7 @@ def verify(self, message, signature): True if message was signed by the private key associated with the public key that this object was constructed with. """ - if isinstance(message, six.text_type): - message = message.encode('utf-8') + message = _to_bytes(message, encoding='utf-8') return PKCS1_v1_5.new(self._pubkey).verify( SHA256.new(message), signature) @@ -64,8 +64,7 @@ def from_string(key_pem, is_x509_cert): Verifier instance. """ if is_x509_cert: - if isinstance(key_pem, six.text_type): - key_pem = key_pem.encode('ascii') + key_pem = _to_bytes(key_pem) pemLines = key_pem.replace(b' ', b'').split() certDer = _urlsafe_b64decode(b''.join(pemLines[1:-1])) certSeq = DerSequence() @@ -98,8 +97,7 @@ def sign(self, message): Returns: string, The signature of the message for the given key. """ - if isinstance(message, six.text_type): - message = message.encode('utf-8') + message = _to_bytes(message, encoding='utf-8') return PKCS1_v1_5.new(self._key).sign(SHA256.new(message)) @staticmethod diff --git a/oauth2client/client.py b/oauth2client/client.py index 3952d4098..dd3964e0d 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -40,6 +40,7 @@ from oauth2client import GOOGLE_REVOKE_URI from oauth2client import GOOGLE_TOKEN_URI from oauth2client import GOOGLE_TOKEN_INFO_URI +from oauth2client._helpers import _to_bytes from oauth2client._helpers import _urlsafe_b64decode from oauth2client import clientsecrets from oauth2client import util @@ -278,7 +279,7 @@ def new_from_json(cls, s): An instance of the subclass of Credentials that was serialized with to_json(). """ - if six.PY3 and isinstance(s, bytes): + if isinstance(s, bytes): s = s.decode('utf-8') data = json.loads(s) # Find and call the right classmethod from_json() to restore the object. @@ -425,11 +426,13 @@ def clean_headers(headers): clean = {} try: for k, v in six.iteritems(headers): - clean_k = k if isinstance(k, bytes) else str(k).encode('ascii') - clean_v = v if isinstance(v, bytes) else str(v).encode('ascii') - clean[clean_k] = clean_v + if not isinstance(k, six.binary_type): + k = str(k) + if not isinstance(v, six.binary_type): + v = str(v) + clean[_to_bytes(k)] = _to_bytes(v) except UnicodeEncodeError: - raise NonAsciiHeaderError(k + ': ' + v) + raise NonAsciiHeaderError(k, ': ', v) return clean @@ -670,7 +673,7 @@ def from_json(cls, s): Returns: An instance of a Credentials subclass. """ - if six.PY3 and isinstance(s, bytes): + if isinstance(s, bytes): s = s.decode('utf-8') data = json.loads(s) if (data.get('token_expiry') and @@ -842,7 +845,7 @@ def _do_refresh_request(self, http_request): logger.info('Refreshing access_token') resp, content = http_request( self.token_uri, method='POST', body=body, headers=headers) - if six.PY3 and isinstance(content, bytes): + if isinstance(content, bytes): content = content.decode('utf-8') if resp.status == 200: d = json.loads(content) @@ -903,7 +906,7 @@ def _do_revoke(self, http_request, token): token_revoke_uri = _update_query_params(self.revoke_uri, query_params) resp, content = http_request(token_revoke_uri) - if six.PY3 and isinstance(content, bytes): + if isinstance(content, bytes): content = content.decode('utf-8') if resp.status == 200: @@ -1015,7 +1018,7 @@ def __init__(self, access_token, user_agent, revoke_uri=None): @classmethod def from_json(cls, s): - if six.PY3 and isinstance(s, bytes): + if isinstance(s, bytes): s = s.decode('utf-8') data = json.loads(s) retval = AccessTokenCredentials( @@ -1602,9 +1605,7 @@ def __init__(self, # Keep base64 encoded so it can be stored in JSON. self.private_key = base64.b64encode(private_key) - if isinstance(self.private_key, six.text_type): - self.private_key = self.private_key.encode('utf-8') - + self.private_key = _to_bytes(self.private_key, encoding='utf-8') self.private_key_password = private_key_password self.service_account_name = service_account_name self.kwargs = kwargs diff --git a/oauth2client/crypt.py b/oauth2client/crypt.py index dc7e7608f..f88c790b4 100644 --- a/oauth2client/crypt.py +++ b/oauth2client/crypt.py @@ -20,6 +20,7 @@ import time from oauth2client._helpers import _json_encode +from oauth2client._helpers import _to_bytes from oauth2client._helpers import _urlsafe_b64decode from oauth2client._helpers import _urlsafe_b64encode @@ -84,14 +85,14 @@ def make_signed_jwt(signer, payload): _urlsafe_b64encode(_json_encode(header)), _urlsafe_b64encode(_json_encode(payload)), ] - signing_input = '.'.join(segments) + signing_input = b'.'.join(segments) signature = signer.sign(signing_input) segments.append(_urlsafe_b64encode(signature)) logger.debug(str(segments)) - return '.'.join(segments) + return b'.'.join(segments) def verify_signed_jwt_with_certs(jwt, certs, audience): @@ -111,11 +112,12 @@ def verify_signed_jwt_with_certs(jwt, certs, audience): Raises: AppIdentityError if any checks are failed. """ - segments = jwt.split('.') + jwt = _to_bytes(jwt) + segments = jwt.split(b'.') if len(segments) != 3: raise AppIdentityError('Wrong number of segments in token: %s' % jwt) - signed = '%s.%s' % (segments[0], segments[1]) + signed = segments[0] + b'.' + segments[1] signature = _urlsafe_b64decode(segments[2]) diff --git a/oauth2client/devshell.py b/oauth2client/devshell.py index a33de8717..52eb26041 100644 --- a/oauth2client/devshell.py +++ b/oauth2client/devshell.py @@ -17,6 +17,7 @@ import json import os +from oauth2client._helpers import _to_bytes from oauth2client import client @@ -76,7 +77,7 @@ def _SendRecv(): data = CREDENTIAL_INFO_REQUEST_JSON msg = '%s\n%s' % (len(data), data) - sock.sendall(msg.encode()) + sock.sendall(_to_bytes(msg, encoding='utf-8')) header = sock.recv(6).decode() if '\n' not in header: @@ -133,4 +134,3 @@ def from_json(cls, json_data): def serialization_data(self): raise NotImplementedError( 'Cannot serialize Developer Shell credentials.') - diff --git a/oauth2client/service_account.py b/oauth2client/service_account.py index ef75a00f3..1d6495a75 100644 --- a/oauth2client/service_account.py +++ b/oauth2client/service_account.py @@ -28,6 +28,7 @@ from oauth2client import GOOGLE_REVOKE_URI from oauth2client import GOOGLE_TOKEN_URI from oauth2client._helpers import _json_encode +from oauth2client._helpers import _to_bytes from oauth2client._helpers import _urlsafe_b64encode from oauth2client import util from oauth2client.client import AssertionCredentials @@ -76,8 +77,8 @@ def _generate_assertion(self): } payload.update(self._kwargs) - first_segment = _urlsafe_b64encode(_json_encode(header)).encode('utf-8') - second_segment = _urlsafe_b64encode(_json_encode(payload)).encode('utf-8') + first_segment = _urlsafe_b64encode(_json_encode(header)) + second_segment = _urlsafe_b64encode(_json_encode(payload)) assertion_input = first_segment + b'.' + second_segment # Sign the assertion. @@ -88,10 +89,7 @@ def _generate_assertion(self): def sign_blob(self, blob): # Ensure that it is bytes - try: - blob = blob.encode('utf-8') - except AttributeError: - pass + blob = _to_bytes(blob, encoding='utf-8') return (self._private_key_id, rsa.pkcs1.sign(blob, self._private_key, 'SHA-256')) @@ -126,9 +124,7 @@ def create_scoped(self, scopes): def _get_private_key(private_key_pkcs8_text): """Get an RSA private key object from a pkcs8 representation.""" - - if not isinstance(private_key_pkcs8_text, six.binary_type): - private_key_pkcs8_text = private_key_pkcs8_text.encode('ascii') + private_key_pkcs8_text = _to_bytes(private_key_pkcs8_text) der = rsa.pem.load_pem(private_key_pkcs8_text, 'PRIVATE KEY') asn1_private_key, _ = decoder.decode(der, asn1Spec=PrivateKeyInfo()) return rsa.PrivateKey.load_pkcs1( diff --git a/tests/test__helpers.py b/tests/test__helpers.py index f476a1310..f86a845c2 100644 --- a/tests/test__helpers.py +++ b/tests/test__helpers.py @@ -17,6 +17,7 @@ from oauth2client._helpers import _json_encode from oauth2client._helpers import _parse_pem_key +from oauth2client._helpers import _to_bytes from oauth2client._helpers import _urlsafe_b64decode from oauth2client._helpers import _urlsafe_b64encode @@ -49,17 +50,35 @@ def test_list_input(self): self.assertEqual(result, """[42,1337]""") +class Test__to_bytes(unittest.TestCase): + + def test_with_bytes(self): + value = b'bytes-val' + self.assertEqual(_to_bytes(value), value) + + def test_with_unicode(self): + value = u'string-val' + encoded_value = b'string-val' + self.assertEqual(_to_bytes(value), encoded_value) + + def test_with_nonstring_type(self): + value = object() + self.assertRaises(ValueError, _to_bytes, value) + + class Test__urlsafe_b64encode(unittest.TestCase): + DEADBEEF_ENCODED = b'ZGVhZGJlZWY' + def test_valid_input_bytes(self): test_string = b'deadbeef' result = _urlsafe_b64encode(test_string) - self.assertEqual(result, u'ZGVhZGJlZWY') + self.assertEqual(result, self.DEADBEEF_ENCODED) def test_valid_input_unicode(self): test_string = u'deadbeef' result = _urlsafe_b64encode(test_string) - self.assertEqual(result, u'ZGVhZGJlZWY') + self.assertEqual(result, self.DEADBEEF_ENCODED) class Test__urlsafe_b64decode(unittest.TestCase): diff --git a/tests/test_jwt.py b/tests/test_jwt.py index 5f2cc90a2..5dd594f90 100644 --- a/tests/test_jwt.py +++ b/tests/test_jwt.py @@ -147,7 +147,7 @@ def test_verify_id_token_bad_tokens(self): self._check_jwt_failure('foo.bar.baz', 'Can\'t parse token') # Bad signature - jwt = 'foo.%s.baz' % crypt._urlsafe_b64encode('{"a":"b"}') + jwt = b'.'.join([b'foo', crypt._urlsafe_b64encode('{"a":"b"}'), b'baz']) self._check_jwt_failure(jwt, 'Invalid token signature') # No expiration