Skip to content
This repository has been archived by the owner on Jan 18, 2025. It is now read-only.

Commit

Permalink
Merge pull request #272 from dhermes/add-from-bytes
Browse files Browse the repository at this point in the history
Adding _from_bytes helpers as a foil for _to_bytes.
  • Loading branch information
dhermes committed Aug 19, 2015
2 parents 904d228 + abca14e commit 043e066
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 59 deletions.
21 changes: 21 additions & 0 deletions oauth2client/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,27 @@ def _to_bytes(value, encoding='ascii'):
raise ValueError('%r could not be converted to bytes' % (value,))


def _from_bytes(value):
"""Converts bytes to a string value, if necessary.
Args:
value: The string/bytes value to be converted.
Returns:
The original value converted to unicode (if bytes) or as passed in
if it started out as unicode.
Raises:
ValueError if the value could not be converted to unicode.
"""
result = (value.decode('utf-8')
if isinstance(value, six.binary_type) else value)
if isinstance(result, six.text_type):
return result
else:
raise ValueError('%r could not be converted to unicode' % (value,))


def _urlsafe_b64encode(raw_bytes):
raw_bytes = _to_bytes(raw_bytes, encoding='utf-8')
return base64.urlsafe_b64encode(raw_bytes).rstrip(b'=')
Expand Down
60 changes: 25 additions & 35 deletions oauth2client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from oauth2client import GOOGLE_REVOKE_URI
from oauth2client import GOOGLE_TOKEN_URI
from oauth2client import GOOGLE_TOKEN_INFO_URI
from oauth2client._helpers import _from_bytes
from oauth2client._helpers import _to_bytes
from oauth2client._helpers import _urlsafe_b64decode
from oauth2client import clientsecrets
Expand Down Expand Up @@ -269,32 +270,32 @@ def to_json(self):

@classmethod
def new_from_json(cls, s):
"""Utility class method to instantiate a Credentials subclass from a JSON
representation produced by to_json().
"""Utility class method to instantiate a Credentials subclass from JSON.
Expects the JSON string to have been produced by to_json().
Args:
s: string, JSON from to_json().
s: string or bytes, JSON from to_json().
Returns:
An instance of the subclass of Credentials that was serialized with
to_json().
"""
if isinstance(s, bytes):
s = s.decode('utf-8')
data = json.loads(s)
json_string_as_unicode = _from_bytes(s)
data = json.loads(json_string_as_unicode)
# Find and call the right classmethod from_json() to restore the object.
module = data['_module']
module_name = data['_module']
try:
m = __import__(module)
module_obj = __import__(module_name)
except ImportError:
# In case there's an object from the old package structure, update it
module = module.replace('.googleapiclient', '')
m = __import__(module)
module_name = module_name.replace('.googleapiclient', '')
module_obj = __import__(module_name)

m = __import__(module, fromlist=module.split('.')[:-1])
kls = getattr(m, data['_class'])
module_obj = __import__(module_name, fromlist=module_name.split('.')[:-1])
kls = getattr(module_obj, data['_class'])
from_json = getattr(kls, 'from_json')
return from_json(s)
return from_json(json_string_as_unicode)

@classmethod
def from_json(cls, unused_data):
Expand Down Expand Up @@ -673,8 +674,7 @@ def from_json(cls, s):
Returns:
An instance of a Credentials subclass.
"""
if isinstance(s, bytes):
s = s.decode('utf-8')
s = _from_bytes(s)
data = json.loads(s)
if (data.get('token_expiry') and
not isinstance(data['token_expiry'], datetime.datetime)):
Expand Down Expand Up @@ -845,8 +845,7 @@ def _do_refresh_request(self, http_request):
logger.info('Refreshing access_token')
resp, content = http_request(
self.token_uri, method='POST', body=body, headers=headers)
if isinstance(content, bytes):
content = content.decode('utf-8')
content = _from_bytes(content)
if resp.status == 200:
d = json.loads(content)
self.token_response = d
Expand Down Expand Up @@ -905,16 +904,12 @@ def _do_revoke(self, http_request, token):
query_params = {'token': token}
token_revoke_uri = _update_query_params(self.revoke_uri, query_params)
resp, content = http_request(token_revoke_uri)

if isinstance(content, bytes):
content = content.decode('utf-8')

if resp.status == 200:
self.invalid = True
else:
error_msg = 'Invalid response %s.' % resp.status
try:
d = json.loads(content)
d = json.loads(_from_bytes(content))
if 'error' in d:
error_msg = d['error']
except (TypeError, ValueError):
Expand Down Expand Up @@ -949,10 +944,7 @@ def _do_retrieve_scopes(self, http_request, token):
query_params = {'access_token': token, 'fields': 'scope'}
token_info_uri = _update_query_params(self.token_info_uri, query_params)
resp, content = http_request(token_info_uri)

if six.PY3 and isinstance(content, bytes):
content = content.decode('utf-8')

content = _from_bytes(content)
if resp.status == 200:
d = json.loads(content)
self.scopes = set(util.string_to_scopes(d.get('scope', '')))
Expand Down Expand Up @@ -1018,9 +1010,7 @@ def __init__(self, access_token, user_agent, revoke_uri=None):

@classmethod
def from_json(cls, s):
if isinstance(s, bytes):
s = s.decode('utf-8')
data = json.loads(s)
data = json.loads(_from_bytes(s))
retval = AccessTokenCredentials(
data['access_token'],
data['user_agent'])
Expand Down Expand Up @@ -1612,7 +1602,7 @@ def __init__(self,

@classmethod
def from_json(cls, s):
data = json.loads(s)
data = json.loads(_from_bytes(s))
retval = SignedJwtAssertionCredentials(
data['service_account_name'],
base64.b64decode(data['private_key']),
Expand Down Expand Up @@ -1675,9 +1665,8 @@ def verify_id_token(id_token, audience, http=None,
http = _cached_http

resp, content = http.request(cert_uri)

if resp.status == 200:
certs = json.loads(content.decode('utf-8'))
certs = json.loads(_from_bytes(content))
return crypt.verify_signed_jwt_with_certs(id_token, certs, audience)
else:
raise VerifyJwtTokenError('Status code: %d' % resp.status)
Expand All @@ -1703,7 +1692,7 @@ def _extract_id_token(id_token):
raise VerifyJwtTokenError(
'Wrong number of segments in token: %s' % id_token)

return json.loads(_urlsafe_b64decode(segments[1]).decode('utf-8'))
return json.loads(_from_bytes(_urlsafe_b64decode(segments[1])))


def _parse_exchange_token_response(content):
Expand All @@ -1720,12 +1709,12 @@ def _parse_exchange_token_response(content):
i.e. {}. That basically indicates a failure.
"""
resp = {}
content = _from_bytes(content)
try:
resp = json.loads(content.decode('utf-8'))
resp = json.loads(content)
except Exception:
# different JSON libs raise different exceptions,
# so we just do a catch-all here
content = content.decode('utf-8')
resp = dict(urllib.parse.parse_qsl(content))

# some providers respond with 'expires', others with 'expires_in'
Expand Down Expand Up @@ -2000,6 +1989,7 @@ def step1_get_device_and_user_codes(self, http=None):

resp, content = http.request(self.device_uri, method='POST', body=body,
headers=headers)
content = _from_bytes(content)
if resp.status == 200:
try:
flow_info = json.loads(content)
Expand Down
3 changes: 2 additions & 1 deletion oauth2client/crypt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
import time

from oauth2client._helpers import _from_bytes
from oauth2client._helpers import _json_encode
from oauth2client._helpers import _to_bytes
from oauth2client._helpers import _urlsafe_b64decode
Expand Down Expand Up @@ -124,7 +125,7 @@ def verify_signed_jwt_with_certs(jwt, certs, audience):
# Parse token.
json_body = _urlsafe_b64decode(segments[1])
try:
parsed = json.loads(json_body.decode('utf-8'))
parsed = json.loads(_from_bytes(json_body))
except:
raise AppIdentityError('Can\'t parse token: %s' % json_body)

Expand Down
4 changes: 3 additions & 1 deletion oauth2client/gce.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import logging
from six.moves import urllib

from oauth2client._helpers import _from_bytes
from oauth2client import util
from oauth2client.client import AccessTokenRefreshError
from oauth2client.client import AssertionCredentials
Expand Down Expand Up @@ -63,7 +64,7 @@ def __init__(self, scope, **kwargs):

@classmethod
def from_json(cls, json_data):
data = json.loads(json_data)
data = json.loads(_from_bytes(json_data))
return AppAssertionCredentials(data['scope'])

def _refresh(self, http_request):
Expand All @@ -81,6 +82,7 @@ def _refresh(self, http_request):
query = '?scope=%s' % urllib.parse.quote(self.scope, '')
uri = META.replace('{?scope}', query)
response, content = http_request(uri)
content = _from_bytes(content)
if response.status == 200:
try:
d = json.loads(content)
Expand Down
1 change: 0 additions & 1 deletion oauth2client/service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"""

import base64
import six
import time

from pyasn1.codec.ber import decoder
Expand Down
17 changes: 17 additions & 0 deletions tests/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import unittest

from oauth2client._helpers import _from_bytes
from oauth2client._helpers import _json_encode
from oauth2client._helpers import _parse_pem_key
from oauth2client._helpers import _to_bytes
Expand Down Expand Up @@ -66,6 +67,22 @@ def test_with_nonstring_type(self):
self.assertRaises(ValueError, _to_bytes, value)


class Test__from_bytes(unittest.TestCase):

def test_with_unicode(self):
value = u'bytes-val'
self.assertEqual(_from_bytes(value), value)

def test_with_bytes(self):
value = b'string-val'
decoded_value = u'string-val'
self.assertEqual(_from_bytes(value), decoded_value)

def test_with_nonstring_type(self):
value = object()
self.assertRaises(ValueError, _from_bytes, value)


class Test__urlsafe_b64encode(unittest.TestCase):

DEADBEEF_ENCODED = b'ZGVhZGJlZWY'
Expand Down
14 changes: 7 additions & 7 deletions tests/test_appengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
import tempfile
import time
import unittest
import urllib
import urlparse

from six.moves import urllib

import dev_appserver
dev_appserver.fix_sys_path()
Expand Down Expand Up @@ -554,7 +554,7 @@ def test_required(self):
self.assertEqual(self.decorator.credentials, None)
response = self.app.get('http://localhost/foo_path')
self.assertTrue(response.status.startswith('302'))
q = urlparse.parse_qs(response.headers['Location'].split('?', 1)[1])
q = urllib.parse.parse_qs(response.headers['Location'].split('?', 1)[1])
self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
self.assertEqual('foo_client_id', q['client_id'][0])
self.assertEqual('foo_scope bar_scope', q['scope'][0])
Expand All @@ -575,10 +575,10 @@ def test_required(self):
self.assertEqual('http://localhost/foo_path', parts[0])
self.assertEqual(None, self.decorator.credentials)
if self.decorator._token_response_param:
response_query = urlparse.parse_qs(parts[1])
response_query = urllib.parse.parse_qs(parts[1])
response = response_query[self.decorator._token_response_param][0]
self.assertEqual(Http2Mock.content,
json.loads(urllib.unquote(response)))
json.loads(urllib.parse.unquote(response)))
self.assertEqual(self.decorator.flow, self.decorator._tls.flow)
self.assertEqual(self.decorator.credentials,
self.decorator._tls.credentials)
Expand Down Expand Up @@ -609,7 +609,7 @@ def test_required(self):
# Invalid Credentials should start the OAuth dance again.
response = self.app.get('/foo_path')
self.assertTrue(response.status.startswith('302'))
q = urlparse.parse_qs(response.headers['Location'].split('?', 1)[1])
q = urllib.parse.parse_qs(response.headers['Location'].split('?', 1)[1])
self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])

def test_storage_delete(self):
Expand Down Expand Up @@ -654,7 +654,7 @@ def test_aware(self):
self.assertEqual('200 OK', response.status)
self.assertEqual(False, self.decorator.has_credentials())
url = self.decorator.authorize_url()
q = urlparse.parse_qs(url.split('?', 1)[1])
q = urllib.parse.parse_qs(url.split('?', 1)[1])
self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
self.assertEqual('foo_client_id', q['client_id'][0])
self.assertEqual('foo_scope bar_scope', q['scope'][0])
Expand Down
37 changes: 25 additions & 12 deletions tests/test_gce.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@

__author__ = '[email protected] (Joe Gregorio)'

import json
from six.moves import urllib
import unittest

import httplib2
import mock

from oauth2client._helpers import _to_bytes
from oauth2client.client import AccessTokenRefreshError
from oauth2client.client import Credentials
from oauth2client.client import save_to_well_known_file
Expand All @@ -33,22 +36,32 @@

class AssertionCredentialsTests(unittest.TestCase):

def test_good_refresh(self):
def _refresh_success_helper(self, bytes_response=False):
access_token = u'this-is-a-token'
return_val = json.dumps({u'accessToken': access_token})
if bytes_response:
return_val = _to_bytes(return_val)
http = mock.MagicMock()
http.request = mock.MagicMock(
return_value=(mock.Mock(status=200),
'{"accessToken": "this-is-a-token"}'))
return_value=(mock.Mock(status=200), return_val))

c = AppAssertionCredentials(scope=['http://example.com/a',
'http://example.com/b'])
self.assertEquals(None, c.access_token)
c.refresh(http)
self.assertEquals('this-is-a-token', c.access_token)
scopes = ['http://example.com/a', 'http://example.com/b']
credentials = AppAssertionCredentials(scope=scopes)
self.assertEquals(None, credentials.access_token)
credentials.refresh(http)
self.assertEquals(access_token, credentials.access_token)

http.request.assert_called_once_with(
'http://metadata.google.internal/0.1/meta-data/service-accounts/'
'default/acquire'
'?scope=http%3A%2F%2Fexample.com%2Fa%20http%3A%2F%2Fexample.com%2Fb')
base_metadata_uri = ('http://metadata.google.internal/0.1/meta-data/'
'service-accounts/default/acquire')
escaped_scopes = urllib.parse.quote(' '.join(scopes), safe='')
request_uri = base_metadata_uri + '?scope=' + escaped_scopes
http.request.assert_called_once_with(request_uri)

def test_refresh_success(self):
self._refresh_success_helper(bytes_response=False)

def test_refresh_success_bytes(self):
self._refresh_success_helper(bytes_response=True)

def test_fail_refresh(self):
http = mock.MagicMock()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_oauth2client.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,8 +801,8 @@ def test_no_unicode_in_request_params(self):
http = credentials.authorize(http)
http.request(u'http://example.com', method=u'GET', headers={u'foo': u'bar'})
for k, v in six.iteritems(http.headers):
self.assertEqual(six.binary_type, type(k))
self.assertEqual(six.binary_type, type(v))
self.assertTrue(isinstance(k, six.binary_type))
self.assertTrue(isinstance(v, six.binary_type))

# Test again with unicode strings that can't simply be converted to ASCII.
try:
Expand Down

0 comments on commit 043e066

Please sign in to comment.