diff --git a/docs/source/oauth2client.rst b/docs/source/oauth2client.rst index 25d3ce05c..65de8ac41 100644 --- a/docs/source/oauth2client.rst +++ b/docs/source/oauth2client.rst @@ -19,6 +19,7 @@ Submodules oauth2client.file oauth2client.service_account oauth2client.tools + oauth2client.transport oauth2client.util Module contents diff --git a/docs/source/oauth2client.transport.rst b/docs/source/oauth2client.transport.rst new file mode 100644 index 000000000..1c6dbb002 --- /dev/null +++ b/docs/source/oauth2client.transport.rst @@ -0,0 +1,7 @@ +oauth2client.transport module +============================= + +.. automodule:: oauth2client.transport + :members: + :undoc-members: + :show-inheritance: diff --git a/oauth2client/client.py b/oauth2client/client.py index 1d89a10f5..1ae0e88bd 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -28,7 +28,6 @@ import sys import tempfile -import httplib2 import six from six.moves import http_client from six.moves import urllib @@ -39,9 +38,9 @@ from oauth2client import GOOGLE_REVOKE_URI from oauth2client import GOOGLE_TOKEN_INFO_URI from oauth2client import GOOGLE_TOKEN_URI +from oauth2client import transport from oauth2client import util from oauth2client._helpers import _from_bytes -from oauth2client._helpers import _to_bytes from oauth2client._helpers import _urlsafe_b64decode @@ -71,9 +70,6 @@ # Constant to use for the out of band OAuth 2.0 flow. OOB_CALLBACK_URN = 'urn:ietf:wg:oauth:2.0:oob' -# Google Data client libraries may need to set this to [401, 403]. -REFRESH_STATUS_CODES = (http_client.UNAUTHORIZED,) - # The value representing user credentials. AUTHORIZED_USER = 'authorized_user' @@ -120,6 +116,12 @@ # easier testing (by replacing with a stub). _UTCNOW = datetime.datetime.utcnow +# NOTE: These names were previously defined in this module but have been +# moved into `oauth2client.transport`, +clean_headers = transport.clean_headers +MemoryCache = transport.MemoryCache +REFRESH_STATUS_CODES = transport.REFRESH_STATUS_CODES + class SETTINGS(object): """Settings namespace for globally defined values.""" @@ -177,22 +179,6 @@ class CryptoUnavailableError(Error, NotImplementedError): """Raised when a crypto library is required, but none is available.""" -class MemoryCache(object): - """httplib2 Cache implementation which only caches locally.""" - - def __init__(self): - self.cache = {} - - def get(self, key): - return self.cache.get(key) - - def set(self, key, value): - self.cache[key] = value - - def delete(self, key): - self.cache.pop(key, None) - - def _parse_expiry(expiry): if expiry and isinstance(expiry, datetime.datetime): return expiry.strftime(EXPIRY_FORMAT) @@ -451,32 +437,6 @@ def delete(self): self.release_lock() -def clean_headers(headers): - """Forces header keys and values to be strings, i.e not unicode. - - The httplib module just concats the header keys and values in a way that - may make the message header a unicode string, which, if it then tries to - contatenate to a binary request body may result in a unicode decode error. - - Args: - headers: dict, A dictionary of headers. - - Returns: - The same dictionary but with all the keys converted to strings. - """ - clean = {} - try: - for k, v in six.iteritems(headers): - if not isinstance(k, six.binary_type): - k = str(k) - if not isinstance(v, six.binary_type): - v = str(v) - clean[_to_bytes(k)] = _to_bytes(v) - except UnicodeEncodeError: - raise NonAsciiHeaderError(k, ': ', v) - return clean - - def _update_query_params(uri, params): """Updates a URI with new query parameters. @@ -494,26 +454,6 @@ def _update_query_params(uri, params): return urllib.parse.urlunparse(new_parts) -def _initialize_headers(headers): - """Creates a copy of the headers.""" - if headers is None: - headers = {} - else: - headers = dict(headers) - return headers - - -def _apply_user_agent(headers, user_agent): - """Adds a user-agent to the headers.""" - if user_agent is not None: - if 'user-agent' in headers: - headers['user-agent'] = (user_agent + ' ' + headers['user-agent']) - else: - headers['user-agent'] = user_agent - - return headers - - class OAuth2Credentials(Credentials): """Credentials object for OAuth 2.0. @@ -604,58 +544,7 @@ def authorize(self, http): that adds in the Authorization header and then calls the original version of 'request()'. """ - request_orig = http.request - - # The closure that will replace 'httplib2.Http.request'. - def new_request(uri, method='GET', body=None, headers=None, - redirections=httplib2.DEFAULT_MAX_REDIRECTS, - connection_type=None): - if not self.access_token: - logger.info('Attempting refresh to obtain ' - 'initial access_token') - self._refresh(request_orig) - - # Clone and modify the request headers to add the appropriate - # Authorization header. - headers = _initialize_headers(headers) - self.apply(headers) - _apply_user_agent(headers, self.user_agent) - - body_stream_position = None - if all(getattr(body, stream_prop, None) for stream_prop in - ('read', 'seek', 'tell')): - body_stream_position = body.tell() - - resp, content = request_orig(uri, method, body, - clean_headers(headers), - redirections, connection_type) - - # A stored token may expire between the time it is retrieved and - # the time the request is made, so we may need to try twice. - max_refresh_attempts = 2 - for refresh_attempt in range(max_refresh_attempts): - if resp.status not in REFRESH_STATUS_CODES: - break - logger.info('Refreshing due to a %s (attempt %s/%s)', - resp.status, refresh_attempt + 1, - max_refresh_attempts) - self._refresh(request_orig) - self.apply(headers) - if body_stream_position is not None: - body.seek(body_stream_position) - - resp, content = request_orig(uri, method, body, - clean_headers(headers), - redirections, connection_type) - - return (resp, content) - - # Replace the request method with our own closure. - http.request = new_request - - # Set credentials as a property of the request method. - setattr(http.request, 'credentials', self) - + transport.wrap_http_for_auth(self, http) return http def refresh(self, http): @@ -781,7 +670,7 @@ def get_access_token(self, http=None): """ if not self.access_token or self.access_token_expired: if not http: - http = httplib2.Http() + http = transport.get_http_object() self.refresh(http) return AccessTokenInfo(access_token=self.access_token, expires_in=self._expires_in()) @@ -1654,11 +1543,6 @@ def _require_crypto_or_die(): raise CryptoUnavailableError('No crypto library available') -# Only used in verify_id_token(), which is always calling to the same URI -# for the certs. -_cached_http = httplib2.Http(MemoryCache()) - - @util.positional(2) def verify_id_token(id_token, audience, http=None, cert_uri=ID_TOKEN_VERIFICATION_CERTS): @@ -1684,7 +1568,7 @@ def verify_id_token(id_token, audience, http=None, """ _require_crypto_or_die() if http is None: - http = _cached_http + http = transport.get_cached_http() resp, content = http.request(cert_uri) if resp.status == http_client.OK: @@ -2027,7 +1911,7 @@ def step1_get_device_and_user_codes(self, http=None): headers['user-agent'] = self.user_agent if http is None: - http = httplib2.Http() + http = transport.get_http_object() resp, content = http.request(self.device_uri, method='POST', body=body, headers=headers) @@ -2110,7 +1994,7 @@ def step2_exchange(self, code=None, http=None, device_flow_info=None): headers['user-agent'] = self.user_agent if http is None: - http = httplib2.Http() + http = transport.get_http_object() resp, content = http.request(self.token_uri, method='POST', body=body, headers=headers) diff --git a/oauth2client/service_account.py b/oauth2client/service_account.py index d2b2b0b21..f7c21eaed 100644 --- a/oauth2client/service_account.py +++ b/oauth2client/service_account.py @@ -27,14 +27,14 @@ from oauth2client import GOOGLE_TOKEN_URI from oauth2client import util from oauth2client._helpers import _from_bytes -from oauth2client.client import _apply_user_agent -from oauth2client.client import _initialize_headers from oauth2client.client import _UTCNOW from oauth2client.client import AccessTokenInfo from oauth2client.client import AssertionCredentials -from oauth2client.client import clean_headers from oauth2client.client import EXPIRY_FORMAT from oauth2client.client import SERVICE_ACCOUNT +from oauth2client.transport import _apply_user_agent +from oauth2client.transport import _initialize_headers +from oauth2client.transport import clean_headers _PASSWORD_DEFAULT = 'notasecret' diff --git a/oauth2client/transport.py b/oauth2client/transport.py new file mode 100644 index 000000000..f160662db --- /dev/null +++ b/oauth2client/transport.py @@ -0,0 +1,198 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import httplib2 +import six +from six.moves import http_client + +from oauth2client._helpers import _to_bytes + + +_LOGGER = logging.getLogger(__name__) +# Properties present in file-like streams / buffers. +_STREAM_PROPERTIES = ('read', 'seek', 'tell') + +# Google Data client libraries may need to set this to [401, 403]. +REFRESH_STATUS_CODES = (http_client.UNAUTHORIZED,) + + +class MemoryCache(object): + """httplib2 Cache implementation which only caches locally.""" + + def __init__(self): + self.cache = {} + + def get(self, key): + return self.cache.get(key) + + def set(self, key, value): + self.cache[key] = value + + def delete(self, key): + self.cache.pop(key, None) + + +def get_cached_http(): + """Return an HTTP object which caches results returned. + + This is intended to be used in methods like + oauth2client.client.verify_id_token(), which calls to the same URI + to retrieve certs. + + Returns: + httplib2.Http, an HTTP object with a MemoryCache + """ + return _CACHED_HTTP + + +def get_http_object(): + """Return a new HTTP object. + + Returns: + httplib2.Http, an HTTP object. + """ + return httplib2.Http() + + +def _initialize_headers(headers): + """Creates a copy of the headers. + + Args: + headers: dict, request headers to copy. + + Returns: + dict, the copied headers or a new dictionary if the headers + were None. + """ + return {} if headers is None else dict(headers) + + +def _apply_user_agent(headers, user_agent): + """Adds a user-agent to the headers. + + Args: + headers: dict, request headers to add / modify user + agent within. + user_agent: str, the user agent to add. + + Returns: + dict, the original headers passed in, but modified if the + user agent is not None. + """ + if user_agent is not None: + if 'user-agent' in headers: + headers['user-agent'] = (user_agent + ' ' + headers['user-agent']) + else: + headers['user-agent'] = user_agent + + return headers + + +def clean_headers(headers): + """Forces header keys and values to be strings, i.e not unicode. + + The httplib module just concats the header keys and values in a way that + may make the message header a unicode string, which, if it then tries to + contatenate to a binary request body may result in a unicode decode error. + + Args: + headers: dict, A dictionary of headers. + + Returns: + The same dictionary but with all the keys converted to strings. + """ + clean = {} + try: + for k, v in six.iteritems(headers): + if not isinstance(k, six.binary_type): + k = str(k) + if not isinstance(v, six.binary_type): + v = str(v) + clean[_to_bytes(k)] = _to_bytes(v) + except UnicodeEncodeError: + from oauth2client.client import NonAsciiHeaderError + raise NonAsciiHeaderError(k, ': ', v) + return clean + + +def wrap_http_for_auth(credentials, http): + """Prepares an HTTP object's request method for auth. + + Wraps HTTP requests with logic to catch auth failures (typically + identified via a 401 status code). In the event of failure, tries + to refresh the token used and then retry the original request. + + Args: + credentials: Credentials, the credentials used to identify + the authenticated user. + http: httplib2.Http, an http object to be used to make + auth requests. + """ + orig_request_method = http.request + + # The closure that will replace 'httplib2.Http.request'. + def new_request(uri, method='GET', body=None, headers=None, + redirections=httplib2.DEFAULT_MAX_REDIRECTS, + connection_type=None): + if not credentials.access_token: + _LOGGER.info('Attempting refresh to obtain ' + 'initial access_token') + credentials._refresh(orig_request_method) + + # Clone and modify the request headers to add the appropriate + # Authorization header. + headers = _initialize_headers(headers) + credentials.apply(headers) + _apply_user_agent(headers, credentials.user_agent) + + body_stream_position = None + # Check if the body is a file-like stream. + if all(getattr(body, stream_prop, None) for stream_prop in + _STREAM_PROPERTIES): + body_stream_position = body.tell() + + resp, content = orig_request_method(uri, method, body, + clean_headers(headers), + redirections, connection_type) + + # A stored token may expire between the time it is retrieved and + # the time the request is made, so we may need to try twice. + max_refresh_attempts = 2 + for refresh_attempt in range(max_refresh_attempts): + if resp.status not in REFRESH_STATUS_CODES: + break + _LOGGER.info('Refreshing due to a %s (attempt %s/%s)', + resp.status, refresh_attempt + 1, + max_refresh_attempts) + credentials._refresh(orig_request_method) + credentials.apply(headers) + if body_stream_position is not None: + body.seek(body_stream_position) + + resp, content = orig_request_method(uri, method, body, + clean_headers(headers), + redirections, connection_type) + + return resp, content + + # Replace the request method with our own closure. + http.request = new_request + + # Set credentials as a property of the request method. + setattr(http.request, 'credentials', credentials) + + +_CACHED_HTTP = httplib2.Http(MemoryCache()) diff --git a/tests/test_client.py b/tests/test_client.py index 338c46eb1..a261d3abe 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -67,7 +67,6 @@ from oauth2client.client import GOOGLE_APPLICATION_CREDENTIALS from oauth2client.client import GoogleCredentials from oauth2client.client import HttpAccessTokenRefreshError -from oauth2client.client import MemoryCache from oauth2client.client import NonAsciiHeaderError from oauth2client.client import OAuth2Credentials from oauth2client.client import OAuth2WebServerFlow @@ -2242,18 +2241,6 @@ def test_exchange_code_and_file_for_token_fail(self): self.code, http=http) -class MemoryCacheTests(unittest2.TestCase): - - def test_get_set_delete(self): - m = MemoryCache() - self.assertEqual(None, m.get('foo')) - self.assertEqual(None, m.delete('foo')) - m.set('foo', 'bar') - self.assertEqual('bar', m.get('foo')) - m.delete('foo') - self.assertEqual(None, m.get('foo')) - - class Test__save_private_file(unittest2.TestCase): def _save_helper(self, filename): diff --git a/tests/test_jwt.py b/tests/test_jwt.py index bbcdd3c99..9c415942d 100644 --- a/tests/test_jwt.py +++ b/tests/test_jwt.py @@ -138,7 +138,7 @@ def test_verify_id_token_with_certs_uri_default_http(self): ({'status': '200'}, datafile('certs.json')), ]) - with mock.patch('oauth2client.client._cached_http', new=http): + with mock.patch('oauth2client.transport._CACHED_HTTP', new=http): contents = verify_id_token( jwt, 'some_audience_address@testing.gserviceaccount.com') diff --git a/tests/test_transport.py b/tests/test_transport.py new file mode 100644 index 000000000..e9782a864 --- /dev/null +++ b/tests/test_transport.py @@ -0,0 +1,131 @@ +# Copyright 2016 Google Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import httplib2 +import mock +import unittest2 + +from oauth2client import client +from oauth2client import transport + + +class TestMemoryCache(unittest2.TestCase): + + def test_get_set_delete(self): + cache = transport.MemoryCache() + self.assertIsNone(cache.get('foo')) + self.assertIsNone(cache.delete('foo')) + cache.set('foo', 'bar') + self.assertEqual('bar', cache.get('foo')) + cache.delete('foo') + self.assertIsNone(cache.get('foo')) + + +class Test_get_cached_http(unittest2.TestCase): + + def test_global(self): + cached_http = transport.get_cached_http() + self.assertIsInstance(cached_http, httplib2.Http) + self.assertIsInstance(cached_http.cache, transport.MemoryCache) + + def test_value(self): + cache = object() + with mock.patch('oauth2client.transport._CACHED_HTTP', new=cache): + result = transport.get_cached_http() + self.assertIs(result, cache) + + +class Test_get_http_object(unittest2.TestCase): + + @mock.patch.object(httplib2, 'Http', return_value=object()) + def test_it(self, http_klass): + result = transport.get_http_object() + self.assertEqual(result, http_klass.return_value) + + +class Test__initialize_headers(unittest2.TestCase): + + def test_null(self): + result = transport._initialize_headers(None) + self.assertEqual(result, {}) + + def test_copy(self): + headers = {'a': 1, 'b': 2} + result = transport._initialize_headers(headers) + self.assertEqual(result, headers) + self.assertIsNot(result, headers) + + +class Test__apply_user_agent(unittest2.TestCase): + + def test_null(self): + headers = object() + result = transport._apply_user_agent(headers, None) + self.assertIs(result, headers) + + def test_new_agent(self): + headers = {} + user_agent = 'foo' + result = transport._apply_user_agent(headers, user_agent) + self.assertIs(result, headers) + self.assertEqual(result, {'user-agent': user_agent}) + + def test_append(self): + orig_agent = 'bar' + headers = {'user-agent': orig_agent} + user_agent = 'baz' + result = transport._apply_user_agent(headers, user_agent) + self.assertIs(result, headers) + final_agent = user_agent + ' ' + orig_agent + self.assertEqual(result, {'user-agent': final_agent}) + + +class Test_clean_headers(unittest2.TestCase): + + def test_no_modify(self): + headers = {b'key': b'val'} + result = transport.clean_headers(headers) + self.assertIsNot(result, headers) + self.assertEqual(result, headers) + + def test_cast_unicode(self): + headers = {u'key': u'val'} + header_bytes = {b'key': b'val'} + result = transport.clean_headers(headers) + self.assertIsNot(result, headers) + self.assertEqual(result, header_bytes) + + def test_unicode_failure(self): + headers = {u'key': u'\u2603'} + with self.assertRaises(client.NonAsciiHeaderError): + transport.clean_headers(headers) + + def test_cast_object(self): + headers = {b'key': True} + header_str = {b'key': b'True'} + result = transport.clean_headers(headers) + self.assertIsNot(result, headers) + self.assertEqual(result, header_str) + + +class Test_wrap_http_for_auth(unittest2.TestCase): + + def test_wrap(self): + credentials = object() + http = mock.Mock() + http.request = orig_req_method = object() + result = transport.wrap_http_for_auth(credentials, http) + self.assertIsNone(result) + self.assertNotEqual(http.request, orig_req_method) + self.assertIs(http.request.credentials, credentials)