Skip to content

Commit

Permalink
Fix running execute and subscribe of client in a Thread (#135)
Browse files Browse the repository at this point in the history
  • Loading branch information
leszekhanusz authored Sep 7, 2020
1 parent 38d6c87 commit 0acea14
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 25 deletions.
18 changes: 15 additions & 3 deletions gql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,13 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:

if isinstance(self.transport, AsyncTransport):

loop = asyncio.get_event_loop()
# Get the current asyncio event loop
# Or create a new event loop if there isn't one (in a new Thread)
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

assert not loop.is_running(), (
"Cannot run client.execute(query) if an asyncio loop is running."
Expand Down Expand Up @@ -146,9 +152,15 @@ def subscribe(
We need an async transport for this functionality.
"""

async_generator = self.subscribe_async(document, *args, **kwargs)
# Get the current asyncio event loop
# Or create a new event loop if there isn't one (in a new Thread)
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

loop = asyncio.get_event_loop()
async_generator = self.subscribe_async(document, *args, **kwargs)

assert not loop.is_running(), (
"Cannot run client.subscribe(query) if an asyncio loop is running."
Expand Down
8 changes: 8 additions & 0 deletions gql/transport/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ def __init__(
self.receive_data_task: Optional[asyncio.Future] = None
self.close_task: Optional[asyncio.Future] = None

# We need to set an event loop here if there is none
# Or else we will not be able to create an asyncio.Event()
try:
self._loop = asyncio.get_event_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)

self._wait_closed: asyncio.Event = asyncio.Event()
self._wait_closed.set()

Expand Down
19 changes: 19 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pathlib
import ssl
import types
from concurrent.futures import ThreadPoolExecutor

import pytest
import websockets
Expand Down Expand Up @@ -266,3 +267,21 @@ async def client_and_server(server):

# Yield both client session and server
yield session, server


@pytest.fixture
async def run_sync_test():
async def run_sync_test_inner(event_loop, server, test_function):
"""This function will run the test in a different Thread.
This allows us to run sync code while aiohttp server can still run.
"""
executor = ThreadPoolExecutor(max_workers=2)
test_task = event_loop.run_in_executor(executor, test_function)

await test_task

if hasattr(server, "close"):
await server.close()

return run_sync_test_inner
59 changes: 59 additions & 0 deletions tests/test_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,62 @@ async def handler(request):
continent = result["continent"]

assert continent["name"] == "Europe"


@pytest.mark.asyncio
async def test_aiohttp_execute_running_in_thread(
event_loop, aiohttp_server, run_sync_test
):
async def handler(request):
return web.Response(text=query1_server_answer, content_type="application/json")

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await aiohttp_server(app)

url = server.make_url("/")

def test_code():
sample_transport = AIOHTTPTransport(url=url)

client = Client(transport=sample_transport)

query = gql(query1_str)

client.execute(query)

await run_sync_test(event_loop, server, test_code)


@pytest.mark.asyncio
async def test_aiohttp_subscribe_running_in_thread(
event_loop, aiohttp_server, run_sync_test
):
async def handler(request):
return web.Response(text=query1_server_answer, content_type="application/json")

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await aiohttp_server(app)

url = server.make_url("/")

def test_code():
sample_transport = AIOHTTPTransport(url=url)

client = Client(transport=sample_transport)

query = gql(query1_str)

# Note: subscriptions are not supported on the aiohttp transport
# But we add this test in order to have 100% code coverage
# It is to check that we will correctly set an event loop
# in the subscribe function if there is none (in a Thread for example)
# We cannot test this with the websockets transport because
# the websockets transport will set an event loop in its init

with pytest.raises(NotImplementedError):
for result in client.subscribe(query):
pass

await run_sync_test(event_loop, server, test_code)
33 changes: 11 additions & 22 deletions tests/test_requests.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
from concurrent.futures import ThreadPoolExecutor

import pytest
from aiohttp import web

from gql import Client, gql
from gql import Client, RequestsHTTPTransport, gql
from gql.transport.exceptions import (
TransportAlreadyConnected,
TransportClosed,
TransportProtocolError,
TransportQueryError,
TransportServerError,
)
from gql.transport.requests import RequestsHTTPTransport

query1_str = """
query getContinents {
Expand All @@ -31,20 +28,8 @@
)


async def run_sync_test(event_loop, server, test_function):
"""This function will run the test in a different Thread.
This allows us to run sync code while aiohttp server can still run.
"""
executor = ThreadPoolExecutor(max_workers=2)
test_task = event_loop.run_in_executor(executor, test_function)

await test_task
await server.close()


@pytest.mark.asyncio
async def test_requests_query(event_loop, aiohttp_server):
async def test_requests_query(event_loop, aiohttp_server, run_sync_test):
async def handler(request):
return web.Response(text=query1_server_answer, content_type="application/json")

Expand Down Expand Up @@ -74,7 +59,7 @@ def test_code():


@pytest.mark.asyncio
async def test_requests_error_code_500(event_loop, aiohttp_server):
async def test_requests_error_code_500(event_loop, aiohttp_server, run_sync_test):
async def handler(request):
# Will generate http error code 500
raise Exception("Server error")
Expand Down Expand Up @@ -102,7 +87,7 @@ def test_code():


@pytest.mark.asyncio
async def test_requests_error_code(event_loop, aiohttp_server):
async def test_requests_error_code(event_loop, aiohttp_server, run_sync_test):
async def handler(request):
return web.Response(
text=query1_server_error_answer, content_type="application/json"
Expand Down Expand Up @@ -136,7 +121,9 @@ def test_code():

@pytest.mark.asyncio
@pytest.mark.parametrize("response", invalid_protocol_responses)
async def test_requests_invalid_protocol(event_loop, aiohttp_server, response):
async def test_requests_invalid_protocol(
event_loop, aiohttp_server, response, run_sync_test
):
async def handler(request):
return web.Response(text=response, content_type="application/json")

Expand All @@ -160,7 +147,7 @@ def test_code():


@pytest.mark.asyncio
async def test_requests_cannot_connect_twice(event_loop, aiohttp_server):
async def test_requests_cannot_connect_twice(event_loop, aiohttp_server, run_sync_test):
async def handler(request):
return web.Response(text=query1_server_answer, content_type="application/json")

Expand All @@ -182,7 +169,9 @@ def test_code():


@pytest.mark.asyncio
async def test_requests_cannot_execute_if_not_connected(event_loop, aiohttp_server):
async def test_requests_cannot_execute_if_not_connected(
event_loop, aiohttp_server, run_sync_test
):
async def handler(request):
return web.Response(text=query1_server_answer, content_type="application/json")

Expand Down
29 changes: 29 additions & 0 deletions tests/test_websocket_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,3 +446,32 @@ def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str)

# Check that the server received a connection_terminate message last
assert logged_messages.pop() == '{"type": "connection_terminate"}'


@pytest.mark.asyncio
@pytest.mark.parametrize("server", [server_countdown], indirect=True)
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
async def test_websocket_subscription_running_in_thread(
event_loop, server, subscription_str, run_sync_test
):
def test_code():
path = "/graphql"
url = f"ws://{server.hostname}:{server.port}{path}"
sample_transport = WebsocketsTransport(url=url)

client = Client(transport=sample_transport)

count = 10
subscription = gql(subscription_str.format(count=count))

for result in client.subscribe(subscription):

number = result["number"]
print(f"Number received: {number}")

assert number == count
count -= 1

assert count == -1

await run_sync_test(event_loop, server, test_code)

0 comments on commit 0acea14

Please sign in to comment.