diff --git a/gcloud/connection.py b/gcloud/connection.py index b7321134680f..00505eb11063 100644 --- a/gcloud/connection.py +++ b/gcloud/connection.py @@ -15,12 +15,13 @@ """Shared implementation of connections to API servers.""" import json +import threading + +import httplib2 from pkg_resources import get_distribution import six from six.moves.urllib.parse import urlencode # pylint: disable=F0401 -import httplib2 - from gcloud.exceptions import make_exception @@ -55,6 +56,8 @@ class Connection(object): object will also need to be able to add a bearer token to API requests and handle token refresh on 401 errors. + A custom ``http`` object will also need to ensure its own thread safety. + :type credentials: :class:`oauth2client.client.OAuth2Credentials` or :class:`NoneType` :param credentials: The OAuth2 Credentials to use for this connection. @@ -73,6 +76,7 @@ class Connection(object): """ def __init__(self, credentials=None, http=None): + self._local = threading.local() self._http = http self._credentials = self._create_scoped_credentials( credentials, self.SCOPE) @@ -91,14 +95,32 @@ def credentials(self): def http(self): """A getter for the HTTP transport used in talking to the API. - :rtype: :class:`httplib2.Http` - :returns: A Http object used to transport data. + This will return a thread-local :class:`httplib2.Http` instance unless + a custom transport has been provided to the :class:`Connection` + constructor. + + :rtype: :class:`httplib2.Http` or the custom HTTP transport specifed + to the connection constructor. + :returns: An ``Http`` object used to transport data. """ - if self._http is None: - self._http = httplib2.Http() + if self._http is not None: + return self._http + + if not hasattr(self._local, 'http'): + self._local.http = httplib2.Http() + + # NOTE: Because this checks the existance of the credentials before + # using it, this is not thread safe. Another thread could change + # self._credentials between the if and the call to authorize. + # However, self._credentials is read-only and set at connection + # creation time, so it should never be change during normal + # circumstances. A lock or EAFP could mitigate this, but we don't + # see it necessary right now. if self._credentials: - self._http = self._credentials.authorize(self._http) - return self._http + self._local.http = self._credentials.authorize( + self._local.http) + + return self._local.http @staticmethod def _create_scoped_credentials(credentials, scope): diff --git a/gcloud/test_connection.py b/gcloud/test_connection.py index 0af98a821f52..9e0832ca108c 100644 --- a/gcloud/test_connection.py +++ b/gcloud/test_connection.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import threading import unittest2 @@ -66,6 +67,36 @@ def test_user_agent_format(self): conn = self._makeOne() self.assertEqual(conn.USER_AGENT, expected_ua) + def test_thread_local_http(self): + credentials = _Credentials(lambda http: object()) + conn = self._makeOne(credentials) + + self.assertTrue(conn.http is not None) + + # Should return the same instance when called again. + self.assertTrue(conn.http is conn.http) + + # Should return a different instance from a different thread. + http_main = conn.http + http_objs = [] + + def test_thread(): + self.assertTrue(conn.http is not None) + self.assertTrue(conn.http is not http_main) + http_objs.append(conn.http) + + thread1 = threading.Thread(target=test_thread) + thread2 = threading.Thread(target=test_thread) + thread1.start() + thread2.start() + thread1.join() + thread2.join() + + self.assertEqual(len(http_objs), 2) + self.assertTrue(http_objs[0] is not http_objs[1]) + self.assertTrue(http_objs[0] is not http_main) + self.assertTrue(http_objs[1] is not http_main) + class TestJSONConnection(unittest2.TestCase): @@ -374,7 +405,11 @@ def __init__(self, authorized=None): def authorize(self, http): self._called_with = http - return self._authorized + + if callable(self._authorized): + return self._authorized(http) + else: + return self._authorized @staticmethod def create_scoped_required():