diff --git a/aiohttp/pytest_plugins.py b/aiohttp/pytest_plugins.py new file mode 100644 index 00000000000..830e4d7e7cd --- /dev/null +++ b/aiohttp/pytest_plugins.py @@ -0,0 +1,61 @@ +import asyncio +import contextlib + +import pytest + +from .test_utils import TestClient, loop_context, setup_test_loop, teardown_test_loop + + +@contextlib.contextmanager +def _passthrough_loop_context(loop): + if loop: + # loop already exists, pass it straight through + yield loop + else: + # this shadows loop_context's standard behavior + loop = setup_test_loop() + yield loop + teardown_test_loop(loop) + + +def pytest_pycollect_makeitem(collector, name, obj): + """ + Fix pytest collecting for coroutines. + """ + if collector.funcnamefilter(name) and asyncio.iscoroutinefunction(obj): + return list(collector._genfunctions(name, obj)) + + +def pytest_pyfunc_call(pyfuncitem): + """ + Run coroutines in an event loop instead of a normal function call. + """ + if asyncio.iscoroutinefunction(pyfuncitem.function): + with _passthrough_loop_context(pyfuncitem.funcargs.get('loop')) as _loop: + testargs = {arg: pyfuncitem.funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames} + _loop.run_until_complete(_loop.create_task(pyfuncitem.obj(**testargs))) + + return True + + +@pytest.yield_fixture +def loop(): + with loop_context() as _loop: + yield _loop + + +@pytest.yield_fixture +def test_client(loop): + client = None + + async def _create_from_app_factory(app_factory): + nonlocal client + app = app_factory(loop) + client = TestClient(app) + await client.start_server() + return client + + yield _create_from_app_factory + + if client: + client.close() diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index 3d6b21cdcad..ceee2a2341e 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -333,21 +333,26 @@ class TestClient: TestClient can also be used as a contextmanager, returning the instance of itself instantiated. """ + _address = '127.0.0.1' def __init__(self, app, protocol="http"): self._app = app self._loop = loop = app.loop self.port = unused_port() - self._handler = handler = app.make_handler() - self._server = loop.run_until_complete(loop.create_server( - handler, '127.0.0.1', self.port - )) + self._handler = app.make_handler() + self._server = None + if not loop.is_running(): + loop.run_until_complete(self.start_server()) self._session = ClientSession(loop=self._loop) - self._root = "{}://127.0.0.1:{}".format( - protocol, self.port - ) + self._root = '{}://{}:{}'.format(protocol, self._address, self.port) self._closed = False + @asyncio.coroutine + def start_server(self): + self._server = yield from self._loop.create_server( + self._handler, self._address, self.port + ) + @property def session(self): """a raw handler to the aiohttp.ClientSession. unlike the methods on