Skip to content
This repository has been archived by the owner on Nov 5, 2019. It is now read-only.

Commit

Permalink
Factor out usage of utcnow() in client.
Browse files Browse the repository at this point in the history
This is to enable better stubs in testing and eliminate
two sleep() statements in unit tests. (The philosophy
is "unit tests should be fast".)
  • Loading branch information
dhermes committed Jan 5, 2016
1 parent 2421420 commit 30c342a
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 51 deletions.
17 changes: 10 additions & 7 deletions oauth2client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@
_METADATA_FLAVOR_HEADER = 'Metadata-Flavor'
_DESIRED_METADATA_FLAVOR = 'Google'

# Expose utcnow() at module level to allow for
# easier testing (by replacing with a stub).
_UTCNOW = datetime.datetime.utcnow


class SETTINGS(object):
"""Settings namespace for globally defined values."""
Expand Down Expand Up @@ -737,7 +741,7 @@ def access_token_expired(self):
if not self.token_expiry:
return False

now = datetime.datetime.utcnow()
now = _UTCNOW()
if now >= self.token_expiry:
logger.info('access_token is expired. Now: %s, token_expiry: %s',
now, self.token_expiry)
Expand Down Expand Up @@ -780,7 +784,7 @@ def _expires_in(self):
valid; we just don't know anything about it.
"""
if self.token_expiry:
now = datetime.datetime.utcnow()
now = _UTCNOW()
if self.token_expiry > now:
time_delta = self.token_expiry - now
# TODO(orestica): return time_delta.total_seconds()
Expand Down Expand Up @@ -881,8 +885,8 @@ def _do_refresh_request(self, http_request):
self.access_token = d['access_token']
self.refresh_token = d.get('refresh_token', self.refresh_token)
if 'expires_in' in d:
self.token_expiry = datetime.timedelta(
seconds=int(d['expires_in'])) + datetime.datetime.utcnow()
delta = datetime.timedelta(seconds=int(d['expires_in']))
self.token_expiry = delta + _UTCNOW()
else:
self.token_expiry = None
if 'id_token' in d:
Expand Down Expand Up @@ -2149,9 +2153,8 @@ def step2_exchange(self, code=None, http=None, device_flow_info=None):
"reauthenticating with approval_prompt='force'.")
token_expiry = None
if 'expires_in' in d:
token_expiry = (
datetime.datetime.utcnow() +
datetime.timedelta(seconds=int(d['expires_in'])))
delta = datetime.timedelta(seconds=int(d['expires_in']))
token_expiry = delta + _UTCNOW()

extracted_id_token = None
if 'id_token' in d:
Expand Down
102 changes: 76 additions & 26 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@

import base64
import contextlib
import copy
import datetime
import json
import os
import socket
import sys
import time

import mock
import six
Expand Down Expand Up @@ -841,39 +841,89 @@ def test_no_unicode_in_request_params(self):
instance = OAuth2Credentials.from_json(self.credentials.to_json())
self.assertEqual('foobar', instance.token_response)

def test_get_access_token(self):
S = 2 # number of seconds in which the token expires
token_response_first = {'access_token': 'first_token', 'expires_in': S}
token_response_second = {'access_token': 'second_token',
'expires_in': S}
@mock.patch('oauth2client.client._UTCNOW')
def test_get_access_token(self, utcnow):
# Configure the patch.
seconds = 11
NOW = datetime.datetime(1992, 12, 31, second=seconds)
utcnow.return_value = NOW

lifetime = 2 # number of seconds in which the token expires
EXPIRY_TIME = datetime.datetime(1992, 12, 31,
second=seconds + lifetime)

token1 = u'first_token'
token_response_first = {
'access_token': token1,
'expires_in': lifetime,
}
token2 = u'second_token'
token_response_second = {
'access_token': token2,
'expires_in': lifetime,
}
http = HttpMockSequence([
({'status': '200'}, json.dumps(token_response_first).encode(
'utf-8')),
({'status': '200'}, json.dumps(token_response_second).encode(
'utf-8')),
])

token = self.credentials.get_access_token(http=http)
self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response)

token = self.credentials.get_access_token(http=http)
self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response)

time.sleep(S + 0.5) # some margin to avoid flakiness
self.assertTrue(self.credentials.access_token_expired)

token = self.credentials.get_access_token(http=http)
self.assertEqual('second_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertFalse(self.credentials.access_token_expired)
# Use the current credentials but unset the expiry and
# the access token.
credentials = copy.deepcopy(self.credentials)
credentials.access_token = None
credentials.token_expiry = None

# Get Access Token, First attempt.
self.assertEqual(credentials.access_token, None)
self.assertFalse(credentials.access_token_expired)
self.assertEqual(credentials.token_expiry, None)
token = credentials.get_access_token(http=http)
self.assertEqual(credentials.token_expiry, EXPIRY_TIME)
self.assertEqual(token1, token.access_token)
self.assertEqual(lifetime, token.expires_in)
self.assertEqual(token_response_first, credentials.token_response)
# Two utcnow calls are expected:
# - get_access_token() -> _do_refresh_request (setting expires in)
# - get_access_token() -> _expires_in()
expected_utcnow_calls = [mock.call()] * 2
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)

# Get Access Token, Second Attempt (not expired)
self.assertEqual(credentials.access_token, token1)
self.assertFalse(credentials.access_token_expired)
token = credentials.get_access_token(http=http)
# Make sure no refresh occurred since the token was not expired.
self.assertEqual(token1, token.access_token)
self.assertEqual(lifetime, token.expires_in)
self.assertEqual(token_response_first, credentials.token_response)
# Three more utcnow calls are expected:
# - access_token_expired
# - get_access_token() -> access_token_expired
# - get_access_token -> _expires_in
expected_utcnow_calls = [mock.call()] * (2 + 3)
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)

# Get Access Token, Third Attempt (force expiration)
self.assertEqual(credentials.access_token, token1)
credentials.token_expiry = NOW # Manually force expiry.
self.assertTrue(credentials.access_token_expired)
token = credentials.get_access_token(http=http)
# Make sure refresh occurred since the token was not expired.
self.assertEqual(token2, token.access_token)
self.assertEqual(lifetime, token.expires_in)
self.assertFalse(credentials.access_token_expired)
self.assertEqual(token_response_second,
self.credentials.token_response)
credentials.token_response)
# Five more utcnow calls are expected:
# - access_token_expired
# - get_access_token -> access_token_expired
# - get_access_token -> _do_refresh_request
# - get_access_token -> _expires_in
# - access_token_expired
expected_utcnow_calls = [mock.call()] * (2 + 3 + 5)
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)

def test_has_scopes(self):
self.assertTrue(self.credentials.has_scopes('foo'))
Expand Down
92 changes: 74 additions & 18 deletions tests/test_service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
Unit tests for service account credentials implemented using RSA.
"""

import datetime
import json
import os
import rsa
import time
import unittest

import mock

from .http_mock import HttpMockSequence
from oauth2client.service_account import _ServiceAccountCredentials

Expand Down Expand Up @@ -88,39 +90,93 @@ def test_create_scoped(self):
_ServiceAccountCredentials))
self.assertEqual('dummy_scope', new_credentials._scopes)

def test_access_token(self):
S = 2 # number of seconds in which the token expires
token_response_first = {'access_token': 'first_token', 'expires_in': S}
token_response_second = {'access_token': 'second_token',
'expires_in': S}
@mock.patch('oauth2client.client._UTCNOW')
@mock.patch('rsa.pkcs1.sign', return_value=b'signed-value')
def test_access_token(self, sign_func, utcnow):
# Configure the patch.
seconds = 11
NOW = datetime.datetime(1992, 12, 31, second=seconds)
utcnow.return_value = NOW

lifetime = 2 # number of seconds in which the token expires
EXPIRY_TIME = datetime.datetime(1992, 12, 31,
second=seconds + lifetime)

token1 = u'first_token'
token_response_first = {
'access_token': token1,
'expires_in': lifetime,
}
token2 = u'second_token'
token_response_second = {
'access_token': token2,
'expires_in': lifetime,
}
http = HttpMockSequence([
({'status': '200'},
json.dumps(token_response_first).encode('utf-8')),
({'status': '200'},
json.dumps(token_response_second).encode('utf-8')),
])

token = self.credentials.get_access_token(http=http)
self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
# Get Access Token, First attempt.
self.assertEqual(self.credentials.access_token, None)
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_first, self.credentials.token_response)

self.assertEqual(self.credentials.token_expiry, None)
token = self.credentials.get_access_token(http=http)
self.assertEqual('first_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
self.assertEqual(self.credentials.token_expiry, EXPIRY_TIME)
self.assertEqual(token1, token.access_token)
self.assertEqual(lifetime, token.expires_in)
self.assertEqual(token_response_first,
self.credentials.token_response)
# Two utcnow calls are expected:
# - get_access_token() -> _do_refresh_request (setting expires in)
# - get_access_token() -> _expires_in()
expected_utcnow_calls = [mock.call()] * 2
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
# One rsa.pkcs1.sign expected: Actual refresh was needed.
self.assertEqual(len(sign_func.mock_calls), 1)

# Get Access Token, Second Attempt (not expired)
self.assertEqual(self.credentials.access_token, token1)
self.assertFalse(self.credentials.access_token_expired)
token = self.credentials.get_access_token(http=http)
# Make sure no refresh occurred since the token was not expired.
self.assertEqual(token1, token.access_token)
self.assertEqual(lifetime, token.expires_in)
self.assertEqual(token_response_first, self.credentials.token_response)

time.sleep(S + 0.5) # some margin to avoid flakiness
# Three more utcnow calls are expected:
# - access_token_expired
# - get_access_token() -> access_token_expired
# - get_access_token -> _expires_in
expected_utcnow_calls = [mock.call()] * (2 + 3)
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
# No rsa.pkcs1.sign expected: the token was not expired.
self.assertEqual(len(sign_func.mock_calls), 1 + 0)

# Get Access Token, Third Attempt (force expiration)
self.assertEqual(self.credentials.access_token, token1)
self.credentials.token_expiry = NOW # Manually force expiry.
self.assertTrue(self.credentials.access_token_expired)

token = self.credentials.get_access_token(http=http)
self.assertEqual('second_token', token.access_token)
self.assertEqual(S - 1, token.expires_in)
# Make sure refresh occurred since the token was not expired.
self.assertEqual(token2, token.access_token)
self.assertEqual(lifetime, token.expires_in)
self.assertFalse(self.credentials.access_token_expired)
self.assertEqual(token_response_second,
self.credentials.token_response)
# Five more utcnow calls are expected:
# - access_token_expired
# - get_access_token -> access_token_expired
# - get_access_token -> _do_refresh_request
# - get_access_token -> _expires_in
# - access_token_expired
expected_utcnow_calls = [mock.call()] * (2 + 3 + 5)
self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
# One more rsa.pkcs1.sign expected: Actual refresh was needed.
self.assertEqual(len(sign_func.mock_calls), 1 + 0 + 1)

self.assertEqual(self.credentials.access_token, token2)


if __name__ == '__main__': # pragma: NO COVER
Expand Down

0 comments on commit 30c342a

Please sign in to comment.