From 30c342a437a137c82444364073f4e4afebeb891a Mon Sep 17 00:00:00 2001 From: Danny Hermes Date: Tue, 5 Jan 2016 00:02:20 -0800 Subject: [PATCH] Factor out usage of utcnow() in client. This is to enable better stubs in testing and eliminate two sleep() statements in unit tests. (The philosophy is "unit tests should be fast".) --- oauth2client/client.py | 17 +++--- tests/test_client.py | 102 +++++++++++++++++++++++++--------- tests/test_service_account.py | 92 ++++++++++++++++++++++++------ 3 files changed, 160 insertions(+), 51 deletions(-) diff --git a/oauth2client/client.py b/oauth2client/client.py index 20a59c482..155d3217c 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -117,6 +117,10 @@ _METADATA_FLAVOR_HEADER = 'Metadata-Flavor' _DESIRED_METADATA_FLAVOR = 'Google' +# Expose utcnow() at module level to allow for +# easier testing (by replacing with a stub). +_UTCNOW = datetime.datetime.utcnow + class SETTINGS(object): """Settings namespace for globally defined values.""" @@ -737,7 +741,7 @@ def access_token_expired(self): if not self.token_expiry: return False - now = datetime.datetime.utcnow() + now = _UTCNOW() if now >= self.token_expiry: logger.info('access_token is expired. Now: %s, token_expiry: %s', now, self.token_expiry) @@ -780,7 +784,7 @@ def _expires_in(self): valid; we just don't know anything about it. """ if self.token_expiry: - now = datetime.datetime.utcnow() + now = _UTCNOW() if self.token_expiry > now: time_delta = self.token_expiry - now # TODO(orestica): return time_delta.total_seconds() @@ -881,8 +885,8 @@ def _do_refresh_request(self, http_request): self.access_token = d['access_token'] self.refresh_token = d.get('refresh_token', self.refresh_token) if 'expires_in' in d: - self.token_expiry = datetime.timedelta( - seconds=int(d['expires_in'])) + datetime.datetime.utcnow() + delta = datetime.timedelta(seconds=int(d['expires_in'])) + self.token_expiry = delta + _UTCNOW() else: self.token_expiry = None if 'id_token' in d: @@ -2149,9 +2153,8 @@ def step2_exchange(self, code=None, http=None, device_flow_info=None): "reauthenticating with approval_prompt='force'.") token_expiry = None if 'expires_in' in d: - token_expiry = ( - datetime.datetime.utcnow() + - datetime.timedelta(seconds=int(d['expires_in']))) + delta = datetime.timedelta(seconds=int(d['expires_in'])) + token_expiry = delta + _UTCNOW() extracted_id_token = None if 'id_token' in d: diff --git a/tests/test_client.py b/tests/test_client.py index c3cbdab28..ef9d633cc 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -21,12 +21,12 @@ import base64 import contextlib +import copy import datetime import json import os import socket import sys -import time import mock import six @@ -841,11 +841,27 @@ def test_no_unicode_in_request_params(self): instance = OAuth2Credentials.from_json(self.credentials.to_json()) self.assertEqual('foobar', instance.token_response) - def test_get_access_token(self): - S = 2 # number of seconds in which the token expires - token_response_first = {'access_token': 'first_token', 'expires_in': S} - token_response_second = {'access_token': 'second_token', - 'expires_in': S} + @mock.patch('oauth2client.client._UTCNOW') + def test_get_access_token(self, utcnow): + # Configure the patch. + seconds = 11 + NOW = datetime.datetime(1992, 12, 31, second=seconds) + utcnow.return_value = NOW + + lifetime = 2 # number of seconds in which the token expires + EXPIRY_TIME = datetime.datetime(1992, 12, 31, + second=seconds + lifetime) + + token1 = u'first_token' + token_response_first = { + 'access_token': token1, + 'expires_in': lifetime, + } + token2 = u'second_token' + token_response_second = { + 'access_token': token2, + 'expires_in': lifetime, + } http = HttpMockSequence([ ({'status': '200'}, json.dumps(token_response_first).encode( 'utf-8')), @@ -853,27 +869,61 @@ def test_get_access_token(self): 'utf-8')), ]) - token = self.credentials.get_access_token(http=http) - self.assertEqual('first_token', token.access_token) - self.assertEqual(S - 1, token.expires_in) - self.assertFalse(self.credentials.access_token_expired) - self.assertEqual(token_response_first, self.credentials.token_response) - - token = self.credentials.get_access_token(http=http) - self.assertEqual('first_token', token.access_token) - self.assertEqual(S - 1, token.expires_in) - self.assertFalse(self.credentials.access_token_expired) - self.assertEqual(token_response_first, self.credentials.token_response) - - time.sleep(S + 0.5) # some margin to avoid flakiness - self.assertTrue(self.credentials.access_token_expired) - - token = self.credentials.get_access_token(http=http) - self.assertEqual('second_token', token.access_token) - self.assertEqual(S - 1, token.expires_in) - self.assertFalse(self.credentials.access_token_expired) + # Use the current credentials but unset the expiry and + # the access token. + credentials = copy.deepcopy(self.credentials) + credentials.access_token = None + credentials.token_expiry = None + + # Get Access Token, First attempt. + self.assertEqual(credentials.access_token, None) + self.assertFalse(credentials.access_token_expired) + self.assertEqual(credentials.token_expiry, None) + token = credentials.get_access_token(http=http) + self.assertEqual(credentials.token_expiry, EXPIRY_TIME) + self.assertEqual(token1, token.access_token) + self.assertEqual(lifetime, token.expires_in) + self.assertEqual(token_response_first, credentials.token_response) + # Two utcnow calls are expected: + # - get_access_token() -> _do_refresh_request (setting expires in) + # - get_access_token() -> _expires_in() + expected_utcnow_calls = [mock.call()] * 2 + self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) + + # Get Access Token, Second Attempt (not expired) + self.assertEqual(credentials.access_token, token1) + self.assertFalse(credentials.access_token_expired) + token = credentials.get_access_token(http=http) + # Make sure no refresh occurred since the token was not expired. + self.assertEqual(token1, token.access_token) + self.assertEqual(lifetime, token.expires_in) + self.assertEqual(token_response_first, credentials.token_response) + # Three more utcnow calls are expected: + # - access_token_expired + # - get_access_token() -> access_token_expired + # - get_access_token -> _expires_in + expected_utcnow_calls = [mock.call()] * (2 + 3) + self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) + + # Get Access Token, Third Attempt (force expiration) + self.assertEqual(credentials.access_token, token1) + credentials.token_expiry = NOW # Manually force expiry. + self.assertTrue(credentials.access_token_expired) + token = credentials.get_access_token(http=http) + # Make sure refresh occurred since the token was not expired. + self.assertEqual(token2, token.access_token) + self.assertEqual(lifetime, token.expires_in) + self.assertFalse(credentials.access_token_expired) self.assertEqual(token_response_second, - self.credentials.token_response) + credentials.token_response) + # Five more utcnow calls are expected: + # - access_token_expired + # - get_access_token -> access_token_expired + # - get_access_token -> _do_refresh_request + # - get_access_token -> _expires_in + # - access_token_expired + expected_utcnow_calls = [mock.call()] * (2 + 3 + 5) + self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) def test_has_scopes(self): self.assertTrue(self.credentials.has_scopes('foo')) diff --git a/tests/test_service_account.py b/tests/test_service_account.py index 1ba0cae09..09d6234ef 100644 --- a/tests/test_service_account.py +++ b/tests/test_service_account.py @@ -17,12 +17,14 @@ Unit tests for service account credentials implemented using RSA. """ +import datetime import json import os import rsa -import time import unittest +import mock + from .http_mock import HttpMockSequence from oauth2client.service_account import _ServiceAccountCredentials @@ -88,11 +90,28 @@ def test_create_scoped(self): _ServiceAccountCredentials)) self.assertEqual('dummy_scope', new_credentials._scopes) - def test_access_token(self): - S = 2 # number of seconds in which the token expires - token_response_first = {'access_token': 'first_token', 'expires_in': S} - token_response_second = {'access_token': 'second_token', - 'expires_in': S} + @mock.patch('oauth2client.client._UTCNOW') + @mock.patch('rsa.pkcs1.sign', return_value=b'signed-value') + def test_access_token(self, sign_func, utcnow): + # Configure the patch. + seconds = 11 + NOW = datetime.datetime(1992, 12, 31, second=seconds) + utcnow.return_value = NOW + + lifetime = 2 # number of seconds in which the token expires + EXPIRY_TIME = datetime.datetime(1992, 12, 31, + second=seconds + lifetime) + + token1 = u'first_token' + token_response_first = { + 'access_token': token1, + 'expires_in': lifetime, + } + token2 = u'second_token' + token_response_second = { + 'access_token': token2, + 'expires_in': lifetime, + } http = HttpMockSequence([ ({'status': '200'}, json.dumps(token_response_first).encode('utf-8')), @@ -100,27 +119,64 @@ def test_access_token(self): json.dumps(token_response_second).encode('utf-8')), ]) - token = self.credentials.get_access_token(http=http) - self.assertEqual('first_token', token.access_token) - self.assertEqual(S - 1, token.expires_in) + # Get Access Token, First attempt. + self.assertEqual(self.credentials.access_token, None) self.assertFalse(self.credentials.access_token_expired) - self.assertEqual(token_response_first, self.credentials.token_response) - + self.assertEqual(self.credentials.token_expiry, None) token = self.credentials.get_access_token(http=http) - self.assertEqual('first_token', token.access_token) - self.assertEqual(S - 1, token.expires_in) + self.assertEqual(self.credentials.token_expiry, EXPIRY_TIME) + self.assertEqual(token1, token.access_token) + self.assertEqual(lifetime, token.expires_in) + self.assertEqual(token_response_first, + self.credentials.token_response) + # Two utcnow calls are expected: + # - get_access_token() -> _do_refresh_request (setting expires in) + # - get_access_token() -> _expires_in() + expected_utcnow_calls = [mock.call()] * 2 + self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) + # One rsa.pkcs1.sign expected: Actual refresh was needed. + self.assertEqual(len(sign_func.mock_calls), 1) + + # Get Access Token, Second Attempt (not expired) + self.assertEqual(self.credentials.access_token, token1) self.assertFalse(self.credentials.access_token_expired) + token = self.credentials.get_access_token(http=http) + # Make sure no refresh occurred since the token was not expired. + self.assertEqual(token1, token.access_token) + self.assertEqual(lifetime, token.expires_in) self.assertEqual(token_response_first, self.credentials.token_response) - - time.sleep(S + 0.5) # some margin to avoid flakiness + # Three more utcnow calls are expected: + # - access_token_expired + # - get_access_token() -> access_token_expired + # - get_access_token -> _expires_in + expected_utcnow_calls = [mock.call()] * (2 + 3) + self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) + # No rsa.pkcs1.sign expected: the token was not expired. + self.assertEqual(len(sign_func.mock_calls), 1 + 0) + + # Get Access Token, Third Attempt (force expiration) + self.assertEqual(self.credentials.access_token, token1) + self.credentials.token_expiry = NOW # Manually force expiry. self.assertTrue(self.credentials.access_token_expired) - token = self.credentials.get_access_token(http=http) - self.assertEqual('second_token', token.access_token) - self.assertEqual(S - 1, token.expires_in) + # Make sure refresh occurred since the token was not expired. + self.assertEqual(token2, token.access_token) + self.assertEqual(lifetime, token.expires_in) self.assertFalse(self.credentials.access_token_expired) self.assertEqual(token_response_second, self.credentials.token_response) + # Five more utcnow calls are expected: + # - access_token_expired + # - get_access_token -> access_token_expired + # - get_access_token -> _do_refresh_request + # - get_access_token -> _expires_in + # - access_token_expired + expected_utcnow_calls = [mock.call()] * (2 + 3 + 5) + self.assertEqual(expected_utcnow_calls, utcnow.mock_calls) + # One more rsa.pkcs1.sign expected: Actual refresh was needed. + self.assertEqual(len(sign_func.mock_calls), 1 + 0 + 1) + + self.assertEqual(self.credentials.access_token, token2) if __name__ == '__main__': # pragma: NO COVER