diff --git a/aioes/transport.py b/aioes/transport.py index 81ae3f5..660b206 100644 --- a/aioes/transport.py +++ b/aioes/transport.py @@ -46,8 +46,9 @@ class Transport: def __init__(self, endpoints, *, sniffer_interval=None, sniffer_timeout=0.1, max_retries=3, - loop, verify_ssl=True): + loop, verify_ssl=True, connector_factory=lambda: None): self._loop = loop + self._connector_factory = connector_factory self._endpoints = self._convert_endpoints(endpoints) self._pool = ConnectionPool([], loop=loop) self._verify_ssl = verify_ssl @@ -148,7 +149,8 @@ def _reinitialize_endpoints(self): connections.append(Connection( endpoint, loop=self._loop, - verify_ssl=self._verify_ssl)) + verify_ssl=self._verify_ssl, + connector=self._connector_factory())) self._pool.close() random.shuffle(connections) self._pool = ConnectionPool(connections, loop=self._loop) diff --git a/tests/test_transport.py b/tests/test_transport.py index 020326e..c2b1df3 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -1,3 +1,4 @@ +import aiohttp import asyncio import time import urllib.parse @@ -30,6 +31,27 @@ def test_ctor(make_transport, es_params): assert 1 == len(tr._pool.connections) +@asyncio.coroutine +def test_connector_factory(es_params, loop): + + class TCPConnector(aiohttp.TCPConnector): + used = False + + def __init__(self, *args, **kwargs): + TCPConnector.used = True + super(TCPConnector, self).__init__(*args, **kwargs) + + tr = Transport( + endpoints=[{'host': es_params['host']}], + sniffer_interval=None, + loop=loop, + connector_factory=lambda: TCPConnector(loop=loop) + ) + assert 1 == len(tr._pool.connections) + assert TCPConnector.used + tr.close() + + @asyncio.coroutine def test_simple(make_transport): tr = make_transport()