Skip to content

Commit

Permalink
Adding a new transport class to handle Phoenix channels (#100)
Browse files Browse the repository at this point in the history
  • Loading branch information
leruaa authored Sep 7, 2020
1 parent 0acea14 commit 706f789
Show file tree
Hide file tree
Showing 11 changed files with 768 additions and 46 deletions.
2 changes: 2 additions & 0 deletions gql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from .client import Client
from .gql import gql
from .transport.aiohttp import AIOHTTPTransport
from .transport.phoenix_channel_websockets import PhoenixChannelWebsocketsTransport
from .transport.requests import RequestsHTTPTransport
from .transport.websockets import WebsocketsTransport

__all__ = [
"gql",
"AIOHTTPTransport",
"Client",
"PhoenixChannelWebsocketsTransport",
"RequestsHTTPTransport",
"WebsocketsTransport",
]
250 changes: 250 additions & 0 deletions gql/transport/phoenix_channel_websockets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
import asyncio
import json
from typing import Dict, Optional, Tuple

from graphql import DocumentNode, ExecutionResult, print_ast
from websockets.exceptions import ConnectionClosed

from .exceptions import (
TransportProtocolError,
TransportQueryError,
TransportServerError,
)
from .websockets import WebsocketsTransport


class PhoenixChannelWebsocketsTransport(WebsocketsTransport):
def __init__(
self, channel_name: str, heartbeat_interval: float = 30, *args, **kwargs
) -> None:
self.channel_name = channel_name
self.heartbeat_interval = heartbeat_interval
self.subscription_ids_to_query_ids: Dict[str, int] = {}
super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs)
"""Initialize the transport with the given request parameters.
:param channel_name Channel on the server this transport will join
:param heartbeat_interval Interval in second between each heartbeat messages
sent by the client
"""

async def _send_init_message_and_wait_ack(self) -> None:
"""Join the specified channel and wait for the connection ACK.
If the answer is not a connection_ack message, we will return an Exception.
"""

query_id = self.next_query_id
self.next_query_id += 1

init_message = json.dumps(
{
"topic": self.channel_name,
"event": "phx_join",
"payload": {},
"ref": query_id,
}
)

await self._send(init_message)

# Wait for the connection_ack message or raise a TimeoutError
init_answer = await asyncio.wait_for(self._receive(), self.ack_timeout)

answer_type, answer_id, execution_result = self._parse_answer(init_answer)

if answer_type != "reply":
raise TransportProtocolError(
"Websocket server did not return a connection ack"
)

async def heartbeat_coro():
while True:
await asyncio.sleep(self.heartbeat_interval)
try:
query_id = self.next_query_id
self.next_query_id += 1

await self._send(
json.dumps(
{
"topic": "phoenix",
"event": "heartbeat",
"payload": {},
"ref": query_id,
}
)
)
except ConnectionClosed: # pragma: no cover
return

self.heartbeat_task = asyncio.ensure_future(heartbeat_coro())

async def _send_stop_message(self, query_id: int) -> None:
try:
await self.listeners[query_id].put(("complete", None))
except KeyError: # pragma: no cover
pass

async def _send_connection_terminate_message(self) -> None:
"""Send a phx_leave message to disconnect from the provided channel.
"""

query_id = self.next_query_id
self.next_query_id += 1

connection_terminate_message = json.dumps(
{
"topic": self.channel_name,
"event": "phx_leave",
"payload": {},
"ref": query_id,
}
)

await self._send(connection_terminate_message)

async def _send_query(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, str]] = None,
operation_name: Optional[str] = None,
) -> int:
"""Send a query to the provided websocket connection.
We use an incremented id to reference the query.
Returns the used id for this query.
"""

query_id = self.next_query_id
self.next_query_id += 1

query_str = json.dumps(
{
"topic": self.channel_name,
"event": "doc",
"payload": {
"query": print_ast(document),
"variables": variable_values or {},
},
"ref": query_id,
}
)

await self._send(query_str)

return query_id

def _parse_answer(
self, answer: str
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
"""Parse the answer received from the server
Returns a list consisting of:
- the answer_type (between:
'heartbeat', 'data', 'reply', 'error', 'close')
- the answer id (Integer) if received or None
- an execution Result if the answer_type is 'data' or None
"""

event: str = ""
answer_id: Optional[int] = None
answer_type: str = ""
execution_result: Optional[ExecutionResult] = None

try:
json_answer = json.loads(answer)

event = str(json_answer.get("event"))

if event == "subscription:data":
payload = json_answer.get("payload")

if not isinstance(payload, dict):
raise ValueError("payload is not a dict")

subscription_id = str(payload.get("subscriptionId"))
try:
answer_id = self.subscription_ids_to_query_ids[subscription_id]
except KeyError:
raise ValueError(
f"subscription '{subscription_id}' has not been registerd"
)

result = payload.get("result")

if not isinstance(result, dict):
raise ValueError("result is not a dict")

answer_type = "data"

execution_result = ExecutionResult(
errors=payload.get("errors"), data=result.get("data")
)

elif event == "phx_reply":
answer_id = int(json_answer.get("ref"))
payload = json_answer.get("payload")

if not isinstance(payload, dict):
raise ValueError("payload is not a dict")

status = str(payload.get("status"))

if status == "ok":

answer_type = "reply"
response = payload.get("response")

if isinstance(response, dict) and "subscriptionId" in response:
subscription_id = str(response.get("subscriptionId"))
self.subscription_ids_to_query_ids[subscription_id] = answer_id

elif status == "error":
response = payload.get("response")

if isinstance(response, dict):
if "errors" in response:
raise TransportQueryError(
str(response.get("errors")), query_id=answer_id
)
elif "reason" in response:
raise TransportQueryError(
str(response.get("reason")), query_id=answer_id
)
raise ValueError("reply error")

elif status == "timeout":
raise TransportQueryError("reply timeout", query_id=answer_id)

elif event == "phx_error":
raise TransportServerError("Server error")
elif event == "phx_close":
answer_type = "close"
else:
raise ValueError

except ValueError as e:
raise TransportProtocolError(
"Server did not return a GraphQL result"
) from e

return answer_type, answer_id, execution_result

async def _handle_answer(
self,
answer_type: str,
answer_id: Optional[int],
execution_result: Optional[ExecutionResult],
) -> None:
if answer_type == "close":
await self.close()
else:
await super()._handle_answer(answer_type, answer_id, execution_result)

async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
if self.heartbeat_task is not None:
self.heartbeat_task.cancel()

await super()._close_coro(e, clean_close)
24 changes: 15 additions & 9 deletions gql/transport/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,19 +371,25 @@ async def _receive_data_loop(self) -> None:
await self._fail(e, clean_close=False)
break

try:
# Put the answer in the queue
if answer_id is not None:
await self.listeners[answer_id].put(
(answer_type, execution_result)
)
except KeyError:
# Do nothing if no one is listening to this query_id.
pass
await self._handle_answer(answer_type, answer_id, execution_result)

finally:
log.debug("Exiting _receive_data_loop()")

async def _handle_answer(
self,
answer_type: str,
answer_id: Optional[int],
execution_result: Optional[ExecutionResult],
) -> None:
try:
# Put the answer in the queue
if answer_id is not None:
await self.listeners[answer_id].put((answer_type, execution_result))
except KeyError:
# Do nothing if no one is listening to this query_id.
pass

async def subscribe(
self,
document: DocumentNode,
Expand Down
28 changes: 25 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ async def stop(self):

print("Server stopped\n\n\n")


class WebSocketServerHelper:
@staticmethod
async def send_complete(ws, query_id):
await ws.send(f'{{"type":"complete","id":"{query_id}","payload":null}}')
Expand Down Expand Up @@ -165,6 +167,26 @@ async def wait_connection_terminate(ws):
assert json_result["type"] == "connection_terminate"


class PhoenixChannelServerHelper:
@staticmethod
async def send_close(ws):
await ws.send('{"event":"phx_close"}')

@staticmethod
async def send_connection_ack(ws):

# Line return for easy debugging
print("")

# Wait for init
result = await ws.recv()
json_result = json.loads(result)
assert json_result["event"] == "phx_join"

# Send ack
await ws.send('{"event":"phx_reply", "payload": {"status": "ok"}, "ref": 1}')


def get_server_handler(request):
"""Get the server handler.
Expand All @@ -181,7 +203,7 @@ def get_server_handler(request):
async def default_server_handler(ws, path):

try:
await WebSocketServer.send_connection_ack(ws)
await WebSocketServerHelper.send_connection_ack(ws)
query_id = 1

for answer in answers:
Expand All @@ -195,10 +217,10 @@ async def default_server_handler(ws, path):
formatted_answer = answer

await ws.send(formatted_answer)
await WebSocketServer.send_complete(ws, query_id)
await WebSocketServerHelper.send_complete(ws, query_id)
query_id += 1

await WebSocketServer.wait_connection_terminate(ws)
await WebSocketServerHelper.wait_connection_terminate(ws)
await ws.wait_closed()
except ConnectionClosed:
pass
Expand Down
8 changes: 4 additions & 4 deletions tests/test_async_client_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from gql import Client, gql
from gql.transport.websockets import WebsocketsTransport

from .conftest import MS, WebSocketServer
from .conftest import MS, WebSocketServerHelper
from .starwars.schema import StarWarsIntrospection, StarWarsSchema, StarWarsTypeDef

starwars_expected_one = {
Expand All @@ -25,7 +25,7 @@


async def server_starwars(ws, path):
await WebSocketServer.send_connection_ack(ws)
await WebSocketServerHelper.send_connection_ack(ws)

try:
await ws.recv()
Expand All @@ -42,8 +42,8 @@ async def server_starwars(ws, path):
await ws.send(data)
await asyncio.sleep(2 * MS)

await WebSocketServer.send_complete(ws, 1)
await WebSocketServer.wait_connection_terminate(ws)
await WebSocketServerHelper.send_complete(ws, 1)
await WebSocketServerHelper.wait_connection_terminate(ws)

except websockets.exceptions.ConnectionClosedOK:
pass
Expand Down
Loading

0 comments on commit 706f789

Please sign in to comment.