diff --git a/oauth2client/service_account.py b/oauth2client/service_account.py index f7c21eaed..57cfc0955 100644 --- a/oauth2client/service_account.py +++ b/oauth2client/service_account.py @@ -20,11 +20,10 @@ import json import time -import httplib2 - from oauth2client import crypt from oauth2client import GOOGLE_REVOKE_URI from oauth2client import GOOGLE_TOKEN_URI +from oauth2client import transport from oauth2client import util from oauth2client._helpers import _from_bytes from oauth2client.client import _UTCNOW @@ -32,9 +31,6 @@ from oauth2client.client import AssertionCredentials from oauth2client.client import EXPIRY_FORMAT from oauth2client.client import SERVICE_ACCOUNT -from oauth2client.transport import _apply_user_agent -from oauth2client.transport import _initialize_headers -from oauth2client.transport import clean_headers _PASSWORD_DEFAULT = 'notasecret' @@ -604,37 +600,7 @@ def authorize(self, http): h = httplib2.Http() h = credentials.authorize(h) """ - request_orig = http.request - request_auth = super( - _JWTAccessCredentials, self).authorize(http).request - - # The closure that will replace 'httplib2.Http.request'. - def new_request(uri, method='GET', body=None, headers=None, - redirections=httplib2.DEFAULT_MAX_REDIRECTS, - connection_type=None): - if 'aud' in self._kwargs: - # Preemptively refresh token, this is not done for OAuth2 - if self.access_token is None or self.access_token_expired: - self.refresh(None) - return request_auth(uri, method, body, - headers, redirections, - connection_type) - else: - # If we don't have an 'aud' (audience) claim, - # create a 1-time token with the uri root as the audience - headers = _initialize_headers(headers) - _apply_user_agent(headers, self.user_agent) - uri_root = uri.split('?', 1)[0] - token, unused_expiry = self._create_token({'aud': uri_root}) - - headers['Authorization'] = 'Bearer ' + token - return request_orig(uri, method, body, - clean_headers(headers), - redirections, connection_type) - - # Replace the request method with our own closure. - http.request = new_request - + transport.wrap_http_for_jwt_access(self, http) return http def get_access_token(self, http=None, additional_claims=None): diff --git a/oauth2client/transport.py b/oauth2client/transport.py index f160662db..8dbc60d83 100644 --- a/oauth2client/transport.py +++ b/oauth2client/transport.py @@ -195,4 +195,51 @@ def new_request(uri, method='GET', body=None, headers=None, setattr(http.request, 'credentials', credentials) +def wrap_http_for_jwt_access(credentials, http): + """Prepares an HTTP object's request method for JWT access. + + Wraps HTTP requests with logic to catch auth failures (typically + identified via a 401 status code). In the event of failure, tries + to refresh the token used and then retry the original request. + + Args: + credentials: _JWTAccessCredentials, the credentials used to identify + a service account that uses JWT access tokens. + http: httplib2.Http, an http object to be used to make + auth requests. + """ + orig_request_method = http.request + wrap_http_for_auth(credentials, http) + # The new value of ``http.request`` set by ``wrap_http_for_auth``. + authenticated_request_method = http.request + + # The closure that will replace 'httplib2.Http.request'. + def new_request(uri, method='GET', body=None, headers=None, + redirections=httplib2.DEFAULT_MAX_REDIRECTS, + connection_type=None): + if 'aud' in credentials._kwargs: + # Preemptively refresh token, this is not done for OAuth2 + if (credentials.access_token is None or + credentials.access_token_expired): + credentials.refresh(None) + return authenticated_request_method(uri, method, body, + headers, redirections, + connection_type) + else: + # If we don't have an 'aud' (audience) claim, + # create a 1-time token with the uri root as the audience + headers = _initialize_headers(headers) + _apply_user_agent(headers, credentials.user_agent) + uri_root = uri.split('?', 1)[0] + token, unused_expiry = credentials._create_token({'aud': uri_root}) + + headers['Authorization'] = 'Bearer ' + token + return orig_request_method(uri, method, body, + clean_headers(headers), + redirections, connection_type) + + # Replace the request method with our own closure. + http.request = new_request + + _CACHED_HTTP = httplib2.Http(MemoryCache()) diff --git a/tests/test_service_account.py b/tests/test_service_account.py index 7bca56b85..3bdfc0da4 100644 --- a/tests/test_service_account.py +++ b/tests/test_service_account.py @@ -527,7 +527,7 @@ def mock_request(uri, method='GET', body=None, headers=None, self.assertEqual(payload['exp'], T1_EXPIRY) self.assertEqual(uri, self.url) self.assertEqual(bearer, b'Bearer') - return (httplib2.Response({'status': '200'}), b'') + return httplib2.Response({'status': '200'}), b'' h = httplib2.Http() h.request = mock_request