From e332a51cc8cd8eb98ebd54a81a93e573f51ef6d2 Mon Sep 17 00:00:00 2001 From: Danny Hermes Date: Wed, 10 Aug 2016 18:43:38 -0700 Subject: [PATCH] Use transport module for GCE environment check. Fixes #599. --- oauth2client/client.py | 22 +++++------- tests/test_client.py | 79 +++++++++++++++++------------------------- 2 files changed, 41 insertions(+), 60 deletions(-) diff --git a/oauth2client/client.py b/oauth2client/client.py index 77106db11..0497d074f 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -108,9 +108,10 @@ GCE_METADATA_TIMEOUT = 3 _SERVER_SOFTWARE = 'SERVER_SOFTWARE' -_GCE_METADATA_HOST = '169.254.169.254' -_METADATA_FLAVOR_HEADER = 'Metadata-Flavor' +_GCE_METADATA_URI = 'http://169.254.169.254' +_METADATA_FLAVOR_HEADER = 'metadata-flavor' # lowercase header _DESIRED_METADATA_FLAVOR = 'Google' +_GCE_HEADERS = {_METADATA_FLAVOR_HEADER: _DESIRED_METADATA_FLAVOR} # Expose utcnow() at module level to allow for # easier testing (by replacing with a stub). @@ -997,21 +998,16 @@ def _detect_gce_environment(): # could lead to false negatives in the event that we are on GCE, but # the metadata resolution was particularly slow. The latter case is # "unlikely". - connection = six.moves.http_client.HTTPConnection( - _GCE_METADATA_HOST, timeout=GCE_METADATA_TIMEOUT) - + http = transport.get_http_object(timeout=GCE_METADATA_TIMEOUT) try: - headers = {_METADATA_FLAVOR_HEADER: _DESIRED_METADATA_FLAVOR} - connection.request('GET', '/', headers=headers) - response = connection.getresponse() - if response.status == http_client.OK: - return (response.getheader(_METADATA_FLAVOR_HEADER) == - _DESIRED_METADATA_FLAVOR) + response, _ = transport.request( + http, _GCE_METADATA_URI, headers=_GCE_HEADERS) + return ( + response.status == http_client.OK and + response.get(_METADATA_FLAVOR_HEADER) == _DESIRED_METADATA_FLAVOR) except socket.error: # socket.timeout or socket.error(64, 'Host is down') logger.info('Timeout attempting to reach GCE metadata service.') return False - finally: - connection.close() def _in_gae_environment(): diff --git a/tests/test_client.py b/tests/test_client.py index 4666ced62..dbe11ebaf 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -42,8 +42,6 @@ DATA_DIR = os.path.join(os.path.dirname(__file__), 'data') -# TODO(craigcitro): This is duplicated from -# googleapiclient.test_discovery; consolidate these definitions. def assertUrisEqual(testcase, expected, actual): """Test that URIs are the same, up to reordering of query parameters.""" expected = urllib.parse.urlparse(expected) @@ -357,67 +355,41 @@ def test_environment_caching(self): # is cached. self.assertTrue(client._in_gae_environment()) - def _environment_check_gce_helper(self, status_ok=True, socket_error=False, + def _environment_check_gce_helper(self, status_ok=True, server_software=''): - response = mock.Mock() if status_ok: - response.status = http_client.OK - response.getheader = mock.Mock( - name='getheader', - return_value=client._DESIRED_METADATA_FLAVOR) + headers = {'status': http_client.OK} + headers.update(client._GCE_HEADERS) else: - response.status = http_client.NOT_FOUND - - connection = mock.Mock() - connection.getresponse = mock.Mock(name='getresponse', - return_value=response) - if socket_error: - connection.getresponse.side_effect = socket.error() + headers = {'status': http_client.NOT_FOUND} + http = http_mock.HttpMock(headers=headers) with mock.patch('oauth2client.client.os') as os_module: os_module.environ = {client._SERVER_SOFTWARE: server_software} - with mock.patch('oauth2client.client.six') as six_module: - http_client_module = six_module.moves.http_client - http_client_module.HTTPConnection = mock.Mock( - name='HTTPConnection', return_value=connection) - + with mock.patch('oauth2client.transport.get_http_object', + return_value=http) as new_http: if server_software == '': self.assertFalse(client._in_gae_environment()) else: self.assertTrue(client._in_gae_environment()) - if status_ok and not socket_error and server_software == '': + if status_ok and server_software == '': self.assertTrue(client._in_gce_environment()) else: self.assertFalse(client._in_gce_environment()) + # Verify mocks. if server_software == '': - http_client_module.HTTPConnection.assert_called_once_with( - client._GCE_METADATA_HOST, + new_http.assert_called_once_with( timeout=client.GCE_METADATA_TIMEOUT) - connection.getresponse.assert_called_once_with() - # Remaining calls are not "getresponse" - headers = { - client._METADATA_FLAVOR_HEADER: ( - client._DESIRED_METADATA_FLAVOR), - } - self.assertEqual(connection.method_calls, [ - mock.call.request('GET', '/', - headers=headers), - mock.call.close(), - ]) - self.assertEqual(response.method_calls, []) - if status_ok and not socket_error: - response.getheader.assert_called_once_with( - client._METADATA_FLAVOR_HEADER) + self.assertEqual(http.requests, 1) + self.assertEqual(http.uri, client._GCE_METADATA_URI) + self.assertEqual(http.method, 'GET') + self.assertIsNone(http.body) + self.assertEqual(http.headers, client._GCE_HEADERS) else: - self.assertEqual( - http_client_module.HTTPConnection.mock_calls, []) - self.assertEqual(connection.getresponse.mock_calls, []) - # Remaining calls are not "getresponse" - self.assertEqual(connection.method_calls, []) - self.assertEqual(response.method_calls, []) - self.assertEqual(response.getheader.mock_calls, []) + new_http.assert_not_called() + self.assertEqual(http.requests, 0) def test_environment_check_gce_production(self): self._environment_check_gce_helper(status_ok=True) @@ -426,8 +398,21 @@ def test_environment_check_gce_prod_with_working_gae_imports(self): with mock_module_import('google.appengine'): self._environment_check_gce_helper(status_ok=True) - def test_environment_check_gce_timeout(self): - self._environment_check_gce_helper(socket_error=True) + @mock.patch('oauth2client.client.os.environ', + new={client._SERVER_SOFTWARE: ''}) + @mock.patch('oauth2client.transport.get_http_object', + return_value=object()) + @mock.patch('oauth2client.transport.request', + side_effect=socket.timeout()) + def test_environment_check_gce_timeout(self, mock_request, new_http): + self.assertFalse(client._in_gae_environment()) + self.assertFalse(client._in_gce_environment()) + + # Verify mocks. + new_http.assert_called_once_with(timeout=client.GCE_METADATA_TIMEOUT) + mock_request.assert_called_once_with( + new_http.return_value, client._GCE_METADATA_URI, + headers=client._GCE_HEADERS) def test_environ_check_gae_module_unknown(self): with mock_module_import('google.appengine'):