Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding a new transport class to handle Phoenix channels #100

Merged
merged 25 commits into from
Sep 7, 2020
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
5fd3e67
New transport class to handle Phoenix channels
leruaa Jun 6, 2020
f085373
Adding forgotten typing hints
leruaa Jun 7, 2020
5e6018e
Fix _handle_answer args order
leruaa Jun 7, 2020
e1eccbc
Better handle exceptions
leruaa Jun 7, 2020
aeb18ec
Simulate complete messages in _send_stop_message
leruaa Jun 8, 2020
9c87006
Cancel heartbeat task in _close_coro
leruaa Jun 8, 2020
b635138
Remove useless close override
leruaa Jun 8, 2020
b7bd5a3
Set a correct ref in the heartbeat message
leruaa Jun 17, 2020
0cb606f
Merge branch 'master' into feature-phoenix-channel
leruaa Jun 17, 2020
4d6de48
Adding unit tests for PhoenixChannelWebsocketsTransport exceptions
leruaa Jun 17, 2020
76fa98a
Increase PhoenixChannelWebsocketsTransport unit tests coverage
leruaa Jun 18, 2020
dd6101f
Better handle a case when there are multiple errors in a query
leruaa Jun 23, 2020
7047c3e
Merge branch 'master' into feature-phoenix-channel
leruaa Jun 29, 2020
33d51e6
Fix mypy errors
leruaa Jun 29, 2020
a571bd2
Merge branch 'master' into feature-phoenix-channel
leruaa Jun 29, 2020
c742487
Merge branch 'master' into feature-phoenix-channel
leruaa Jul 11, 2020
f0f62be
Update WebSocketServer to WebSocketServerHelper in test_websocket_non…
leruaa Jul 11, 2020
28ad41d
Adding doc on PhoenixChannelWebsocketsTransport
leruaa Jul 11, 2020
c10b851
Merge branch 'master' into feature-phoenix-channel
leruaa Aug 15, 2020
bd72e79
DSL: Fixed bug where a nested GraphQLInputObjectType is causing infin…
JBrVJxsc Aug 17, 2020
1a2dcec
Fix race condition in websocket transport close (#133)
leszekhanusz Aug 27, 2020
29f7f2b
add the data property in TransportQueryError (#136)
leszekhanusz Aug 28, 2020
3c82d1e
Fix running execute and subscribe of client in a Thread (#135)
leszekhanusz Sep 7, 2020
b4ab941
Allow to import PhoenixChannelWebsocketsTransport directly from gql &…
leszekhanusz Sep 7, 2020
bec8e66
Merge branch 'master' into feature-phoenix-channel
leszekhanusz Sep 7, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 244 additions & 0 deletions gql/transport/phoenix_channel_websockets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
import asyncio
import json
from typing import Dict, Optional, Tuple

from graphql import DocumentNode, ExecutionResult, print_ast
from websockets.exceptions import ConnectionClosed

from .exceptions import (
TransportProtocolError,
TransportQueryError,
TransportServerError,
)
from .websockets import WebsocketsTransport


class PhoenixChannelWebsocketsTransport(WebsocketsTransport):
def __init__(
self, channel_name: str, heartbeat_interval: int = 30, *args, **kwargs
) -> None:
self.channel_name = channel_name
self.heartbeat_interval = heartbeat_interval
self.subscription_ids_to_query_ids: Dict[str, int] = {}
super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs)
leruaa marked this conversation as resolved.
Show resolved Hide resolved

async def _send_init_message_and_wait_ack(self) -> None:
"""Join the specified channel and wait for the connection ACK.

If the answer is not a connection_ack message, we will return an Exception.
"""

query_id = self.next_query_id
self.next_query_id += 1

init_message = json.dumps(
{
"topic": self.channel_name,
"event": "phx_join",
"payload": {},
"ref": query_id,
}
)
leszekhanusz marked this conversation as resolved.
Show resolved Hide resolved

await self._send(init_message)

# Wait for the connection_ack message or raise a TimeoutError
init_answer = await asyncio.wait_for(self._receive(), self.ack_timeout)

answer_type, answer_id, execution_result = self._parse_answer(init_answer)

if answer_type != "reply":
raise TransportProtocolError(
"Websocket server did not return a connection ack"
)

async def heartbeat_coro():
while True:
await asyncio.sleep(self.heartbeat_interval)
try:
query_id = self.next_query_id
self.next_query_id += 1

await self._send(
json.dumps(
{
"topic": "phoenix",
"event": "heartbeat",
"payload": {},
"ref": query_id,
}
)
)
except ConnectionClosed:
pass

self.heartbeat_task = asyncio.ensure_future(heartbeat_coro())

async def _send_stop_message(self, query_id: int) -> None:
try:
await self.listeners[query_id].put(("complete", None))
except KeyError:
pass

async def _send_connection_terminate_message(self) -> None:
"""Send a phx_leave message to disconnect from the provided channel.
"""

query_id = self.next_query_id
self.next_query_id += 1

connection_terminate_message = json.dumps(
{
"topic": self.channel_name,
"event": "phx_leave",
"payload": {},
"ref": query_id,
}
)

await self._send(connection_terminate_message)

async def _send_query(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, str]] = None,
operation_name: Optional[str] = None,
) -> int:
"""Send a query to the provided websocket connection.

We use an incremented id to reference the query.

Returns the used id for this query.
"""

query_id = self.next_query_id
self.next_query_id += 1

query_str = json.dumps(
{
"topic": self.channel_name,
"event": "doc",
"payload": {
"query": print_ast(document),
"variables": variable_values or {},
},
"ref": query_id,
}
)

await self._send(query_str)

return query_id

def _parse_answer(
self, answer: str
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
"""Parse the answer received from the server

Returns a list consisting of:
- the answer_type (between:
'heartbeat', 'data', 'reply', 'error', 'close')
- the answer id (Integer) if received or None
- an execution Result if the answer_type is 'data' or None
"""

event: str = ""
answer_id: Optional[int] = None
answer_type: str = ""
execution_result: Optional[ExecutionResult] = None

try:
json_answer = json.loads(answer)

event = str(json_answer.get("event"))

if event == "subscription:data":
payload = json_answer.get("payload")

if not isinstance(payload, dict):
raise ValueError("payload is not a dict")

subscription_id = str(payload.get("subscriptionId"))
try:
answer_id = self.subscription_ids_to_query_ids[subscription_id]
except KeyError:
raise ValueError(
f"subscription '{subscription_id}' has not been registerd"
)

result = payload.get("result")

if not isinstance(result, dict):
raise ValueError("result is not a dict")

answer_type = "data"

execution_result = ExecutionResult(
errors=payload.get("errors"), data=result.get("data")
)

elif event == "phx_reply":
answer_id = int(json_answer.get("ref"))
payload = json_answer.get("payload")

if not isinstance(payload, dict):
raise ValueError("payload is not a dict")

status = str(payload.get("status"))

if status == "ok":

answer_type = "reply"
response = payload.get("response")

if isinstance(response, dict) and "subscriptionId" in response:
subscription_id = str(response.get("subscriptionId"))
self.subscription_ids_to_query_ids[subscription_id] = answer_id

elif status == "error":
response = payload.get("response")

if isinstance(response, dict):
if "errors" in response:
raise TransportQueryError(
str(response.get("errors")), query_id=answer_id
)
elif "reason" in response:
raise TransportQueryError(
str(response.get("reason")), query_id=answer_id
)
raise ValueError("reply error")

elif status == "timeout":
raise TransportQueryError("reply timeout", query_id=answer_id)

elif event == "phx_error":
raise TransportServerError("Server error")
elif event == "phx_close":
answer_type = "close"
else:
raise ValueError

except ValueError as e:
raise TransportProtocolError(
"Server did not return a GraphQL result"
) from e

return answer_type, answer_id, execution_result

async def _handle_answer(
self,
answer_type: str,
answer_id: Optional[int],
execution_result: Optional[ExecutionResult],
) -> None:
if answer_type == "close":
await self.close()
else:
await super()._handle_answer(answer_type, answer_id, execution_result)

async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
if self.heartbeat_task is not None:
self.heartbeat_task.cancel()

await super()._close_coro(e, clean_close)
24 changes: 15 additions & 9 deletions gql/transport/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,19 +355,25 @@ async def _receive_data_loop(self) -> None:
await self._fail(e, clean_close=False)
break

try:
# Put the answer in the queue
if answer_id is not None:
await self.listeners[answer_id].put(
(answer_type, execution_result)
)
except KeyError:
# Do nothing if no one is listening to this query_id.
pass
await self._handle_answer(answer_type, answer_id, execution_result)

finally:
log.debug("Exiting _receive_data_loop()")

async def _handle_answer(
self,
answer_type: str,
answer_id: Optional[int],
execution_result: Optional[ExecutionResult],
) -> None:
try:
# Put the answer in the queue
if answer_id is not None:
await self.listeners[answer_id].put((answer_type, execution_result))
except KeyError:
# Do nothing if no one is listening to this query_id.
pass

async def subscribe(
self,
document: DocumentNode,
Expand Down
28 changes: 25 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ async def stop(self):

print("Server stopped\n\n\n")


class WebSocketServerHelper:
@staticmethod
async def send_complete(ws, query_id):
await ws.send(f'{{"type":"complete","id":"{query_id}","payload":null}}')
Expand Down Expand Up @@ -164,6 +166,26 @@ async def wait_connection_terminate(ws):
assert json_result["type"] == "connection_terminate"


class PhoenixChannelServerHelper:
@staticmethod
async def send_close(ws):
await ws.send('{"event":"phx_close"}')

@staticmethod
async def send_connection_ack(ws):

# Line return for easy debugging
print("")

# Wait for init
result = await ws.recv()
json_result = json.loads(result)
assert json_result["event"] == "phx_join"

# Send ack
await ws.send('{"event":"phx_reply", "payload": {"status": "ok"}, "ref": 1}')


def get_server_handler(request):
"""Get the server handler.

Expand All @@ -180,7 +202,7 @@ def get_server_handler(request):
async def default_server_handler(ws, path):

try:
await WebSocketServer.send_connection_ack(ws)
await WebSocketServerHelper.send_connection_ack(ws)
query_id = 1

for answer in answers:
Expand All @@ -194,10 +216,10 @@ async def default_server_handler(ws, path):
formatted_answer = answer

await ws.send(formatted_answer)
await WebSocketServer.send_complete(ws, query_id)
await WebSocketServerHelper.send_complete(ws, query_id)
query_id += 1

await WebSocketServer.wait_connection_terminate(ws)
await WebSocketServerHelper.wait_connection_terminate(ws)
await ws.wait_closed()
except ConnectionClosed:
pass
Expand Down
8 changes: 4 additions & 4 deletions tests/test_async_client_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from gql import Client, gql
from gql.transport.websockets import WebsocketsTransport

from .conftest import MS, WebSocketServer
from .conftest import MS, WebSocketServerHelper
from .starwars.schema import StarWarsIntrospection, StarWarsSchema, StarWarsTypeDef

starwars_expected_one = {
Expand All @@ -25,7 +25,7 @@


async def server_starwars(ws, path):
await WebSocketServer.send_connection_ack(ws)
await WebSocketServerHelper.send_connection_ack(ws)

try:
await ws.recv()
Expand All @@ -42,8 +42,8 @@ async def server_starwars(ws, path):
await ws.send(data)
await asyncio.sleep(2 * MS)

await WebSocketServer.send_complete(ws, 1)
await WebSocketServer.wait_connection_terminate(ws)
await WebSocketServerHelper.send_complete(ws, 1)
await WebSocketServerHelper.wait_connection_terminate(ws)

except websockets.exceptions.ConnectionClosedOK:
pass
Expand Down
Loading