Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding a new transport class to handle Phoenix channels #100

Merged
merged 25 commits into from
Sep 7, 2020
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
5fd3e67
New transport class to handle Phoenix channels
leruaa Jun 6, 2020
f085373
Adding forgotten typing hints
leruaa Jun 7, 2020
5e6018e
Fix _handle_answer args order
leruaa Jun 7, 2020
e1eccbc
Better handle exceptions
leruaa Jun 7, 2020
aeb18ec
Simulate complete messages in _send_stop_message
leruaa Jun 8, 2020
9c87006
Cancel heartbeat task in _close_coro
leruaa Jun 8, 2020
b635138
Remove useless close override
leruaa Jun 8, 2020
b7bd5a3
Set a correct ref in the heartbeat message
leruaa Jun 17, 2020
0cb606f
Merge branch 'master' into feature-phoenix-channel
leruaa Jun 17, 2020
4d6de48
Adding unit tests for PhoenixChannelWebsocketsTransport exceptions
leruaa Jun 17, 2020
76fa98a
Increase PhoenixChannelWebsocketsTransport unit tests coverage
leruaa Jun 18, 2020
dd6101f
Better handle a case when there are multiple errors in a query
leruaa Jun 23, 2020
7047c3e
Merge branch 'master' into feature-phoenix-channel
leruaa Jun 29, 2020
33d51e6
Fix mypy errors
leruaa Jun 29, 2020
a571bd2
Merge branch 'master' into feature-phoenix-channel
leruaa Jun 29, 2020
c742487
Merge branch 'master' into feature-phoenix-channel
leruaa Jul 11, 2020
f0f62be
Update WebSocketServer to WebSocketServerHelper in test_websocket_non…
leruaa Jul 11, 2020
28ad41d
Adding doc on PhoenixChannelWebsocketsTransport
leruaa Jul 11, 2020
c10b851
Merge branch 'master' into feature-phoenix-channel
leruaa Aug 15, 2020
bd72e79
DSL: Fixed bug where a nested GraphQLInputObjectType is causing infin…
JBrVJxsc Aug 17, 2020
1a2dcec
Fix race condition in websocket transport close (#133)
leszekhanusz Aug 27, 2020
29f7f2b
add the data property in TransportQueryError (#136)
leszekhanusz Aug 28, 2020
3c82d1e
Fix running execute and subscribe of client in a Thread (#135)
leszekhanusz Sep 7, 2020
b4ab941
Allow to import PhoenixChannelWebsocketsTransport directly from gql &…
leszekhanusz Sep 7, 2020
bec8e66
Merge branch 'master' into feature-phoenix-channel
leszekhanusz Sep 7, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions gql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from .client import Client
from .gql import gql
from .transport.aiohttp import AIOHTTPTransport
from .transport.phoenix_channel_websockets import PhoenixChannelWebsocketsTransport
from .transport.requests import RequestsHTTPTransport
from .transport.websockets import WebsocketsTransport

__all__ = [
"gql",
"AIOHTTPTransport",
"Client",
"PhoenixChannelWebsocketsTransport",
"RequestsHTTPTransport",
"WebsocketsTransport",
]
30 changes: 24 additions & 6 deletions gql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,13 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:

if isinstance(self.transport, AsyncTransport):

loop = asyncio.get_event_loop()
# Get the current asyncio event loop
# Or create a new event loop if there isn't one (in a new Thread)
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

assert not loop.is_running(), (
"Cannot run client.execute(query) if an asyncio loop is running."
Expand Down Expand Up @@ -146,9 +152,15 @@ def subscribe(
We need an async transport for this functionality.
"""

async_generator = self.subscribe_async(document, *args, **kwargs)
# Get the current asyncio event loop
# Or create a new event loop if there isn't one (in a new Thread)
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

loop = asyncio.get_event_loop()
async_generator = self.subscribe_async(document, *args, **kwargs)

assert not loop.is_running(), (
"Cannot run client.subscribe(query) if an asyncio loop is running."
Expand Down Expand Up @@ -240,7 +252,9 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:

# Raise an error if an error is returned in the ExecutionResult object
if result.errors:
raise TransportQueryError(str(result.errors[0]), errors=result.errors)
raise TransportQueryError(
str(result.errors[0]), errors=result.errors, data=result.data
)

assert (
result.data is not None
Expand Down Expand Up @@ -315,7 +329,9 @@ async def subscribe(

# Raise an error if an error is returned in the ExecutionResult object
if result.errors:
raise TransportQueryError(str(result.errors[0]), errors=result.errors)
raise TransportQueryError(
str(result.errors[0]), errors=result.errors, data=result.data
)

elif result.data is not None:
yield result.data
Expand All @@ -340,7 +356,9 @@ async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:

# Raise an error if an error is returned in the ExecutionResult object
if result.errors:
raise TransportQueryError(str(result.errors[0]), errors=result.errors)
raise TransportQueryError(
str(result.errors[0]), errors=result.errors, data=result.data
)

assert (
result.data is not None
Expand Down
21 changes: 14 additions & 7 deletions gql/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def args(self, **kwargs):
arg = self.field.args.get(name)
if not arg:
raise KeyError(f"Argument {name} does not exist in {self.field}.")
arg_type_serializer = get_arg_serializer(arg.type)
arg_type_serializer = get_arg_serializer(arg.type, known_serializers=dict())
serialized_value = arg_type_serializer(value)
added_args.append(
ArgumentNode(name=NameNode(value=name), value=serialized_value)
Expand Down Expand Up @@ -151,21 +151,28 @@ def serialize_list(serializer, list_values):
return ListValueNode(values=FrozenList(serializer(v) for v in list_values))


def get_arg_serializer(arg_type):
def get_arg_serializer(arg_type, known_serializers):
if isinstance(arg_type, GraphQLNonNull):
return get_arg_serializer(arg_type.of_type)
return get_arg_serializer(arg_type.of_type, known_serializers)
if isinstance(arg_type, GraphQLInputField):
return get_arg_serializer(arg_type.type)
return get_arg_serializer(arg_type.type, known_serializers)
if isinstance(arg_type, GraphQLInputObjectType):
serializers = {k: get_arg_serializer(v) for k, v in arg_type.fields.items()}
return lambda value: ObjectValueNode(
if arg_type in known_serializers:
return known_serializers[arg_type]
known_serializers[arg_type] = None
serializers = {
k: get_arg_serializer(v, known_serializers)
for k, v in arg_type.fields.items()
}
known_serializers[arg_type] = lambda value: ObjectValueNode(
fields=FrozenList(
ObjectFieldNode(name=NameNode(value=k), value=serializers[k](v))
for k, v in value.items()
)
)
return known_serializers[arg_type]
if isinstance(arg_type, GraphQLList):
inner_serializer = get_arg_serializer(arg_type.of_type)
inner_serializer = get_arg_serializer(arg_type.of_type, known_serializers)
return partial(serialize_list, inner_serializer)
if isinstance(arg_type, GraphQLEnumType):
return lambda value: EnumValueNode(value=arg_type.serialize(value))
Expand Down
2 changes: 2 additions & 0 deletions gql/transport/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@ def __init__(
msg: str,
query_id: Optional[int] = None,
errors: Optional[List[Any]] = None,
data: Optional[Any] = None,
):
super().__init__(msg)
self.query_id = query_id
self.errors = errors
self.data = data


class TransportClosed(TransportError):
Expand Down
250 changes: 250 additions & 0 deletions gql/transport/phoenix_channel_websockets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
import asyncio
import json
from typing import Dict, Optional, Tuple

from graphql import DocumentNode, ExecutionResult, print_ast
from websockets.exceptions import ConnectionClosed

from .exceptions import (
TransportProtocolError,
TransportQueryError,
TransportServerError,
)
from .websockets import WebsocketsTransport


class PhoenixChannelWebsocketsTransport(WebsocketsTransport):
def __init__(
self, channel_name: str, heartbeat_interval: float = 30, *args, **kwargs
) -> None:
self.channel_name = channel_name
self.heartbeat_interval = heartbeat_interval
self.subscription_ids_to_query_ids: Dict[str, int] = {}
super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs)
"""Initialize the transport with the given request parameters.

:param channel_name Channel on the server this transport will join
:param heartbeat_interval Interval in second between each heartbeat messages
sent by the client
"""

async def _send_init_message_and_wait_ack(self) -> None:
"""Join the specified channel and wait for the connection ACK.

If the answer is not a connection_ack message, we will return an Exception.
"""

query_id = self.next_query_id
self.next_query_id += 1

init_message = json.dumps(
{
"topic": self.channel_name,
"event": "phx_join",
"payload": {},
"ref": query_id,
}
)
leszekhanusz marked this conversation as resolved.
Show resolved Hide resolved

await self._send(init_message)

# Wait for the connection_ack message or raise a TimeoutError
init_answer = await asyncio.wait_for(self._receive(), self.ack_timeout)

answer_type, answer_id, execution_result = self._parse_answer(init_answer)

if answer_type != "reply":
raise TransportProtocolError(
"Websocket server did not return a connection ack"
)

async def heartbeat_coro():
while True:
await asyncio.sleep(self.heartbeat_interval)
try:
query_id = self.next_query_id
self.next_query_id += 1

await self._send(
json.dumps(
{
"topic": "phoenix",
"event": "heartbeat",
"payload": {},
"ref": query_id,
}
)
)
except ConnectionClosed: # pragma: no cover
return

self.heartbeat_task = asyncio.ensure_future(heartbeat_coro())

async def _send_stop_message(self, query_id: int) -> None:
try:
await self.listeners[query_id].put(("complete", None))
except KeyError: # pragma: no cover
pass

async def _send_connection_terminate_message(self) -> None:
"""Send a phx_leave message to disconnect from the provided channel.
"""

query_id = self.next_query_id
self.next_query_id += 1

connection_terminate_message = json.dumps(
{
"topic": self.channel_name,
"event": "phx_leave",
"payload": {},
"ref": query_id,
}
)

await self._send(connection_terminate_message)

async def _send_query(
self,
document: DocumentNode,
variable_values: Optional[Dict[str, str]] = None,
operation_name: Optional[str] = None,
) -> int:
"""Send a query to the provided websocket connection.

We use an incremented id to reference the query.

Returns the used id for this query.
"""

query_id = self.next_query_id
self.next_query_id += 1

query_str = json.dumps(
{
"topic": self.channel_name,
"event": "doc",
"payload": {
"query": print_ast(document),
"variables": variable_values or {},
},
"ref": query_id,
}
)

await self._send(query_str)

return query_id

def _parse_answer(
self, answer: str
) -> Tuple[str, Optional[int], Optional[ExecutionResult]]:
"""Parse the answer received from the server

Returns a list consisting of:
- the answer_type (between:
'heartbeat', 'data', 'reply', 'error', 'close')
- the answer id (Integer) if received or None
- an execution Result if the answer_type is 'data' or None
"""

event: str = ""
answer_id: Optional[int] = None
answer_type: str = ""
execution_result: Optional[ExecutionResult] = None

try:
json_answer = json.loads(answer)

event = str(json_answer.get("event"))

if event == "subscription:data":
payload = json_answer.get("payload")

if not isinstance(payload, dict):
raise ValueError("payload is not a dict")

subscription_id = str(payload.get("subscriptionId"))
try:
answer_id = self.subscription_ids_to_query_ids[subscription_id]
except KeyError:
raise ValueError(
f"subscription '{subscription_id}' has not been registerd"
)

result = payload.get("result")

if not isinstance(result, dict):
raise ValueError("result is not a dict")

answer_type = "data"

execution_result = ExecutionResult(
errors=payload.get("errors"), data=result.get("data")
)

elif event == "phx_reply":
answer_id = int(json_answer.get("ref"))
payload = json_answer.get("payload")

if not isinstance(payload, dict):
raise ValueError("payload is not a dict")

status = str(payload.get("status"))

if status == "ok":

answer_type = "reply"
response = payload.get("response")

if isinstance(response, dict) and "subscriptionId" in response:
subscription_id = str(response.get("subscriptionId"))
self.subscription_ids_to_query_ids[subscription_id] = answer_id

elif status == "error":
response = payload.get("response")

if isinstance(response, dict):
if "errors" in response:
raise TransportQueryError(
str(response.get("errors")), query_id=answer_id
)
elif "reason" in response:
raise TransportQueryError(
str(response.get("reason")), query_id=answer_id
)
raise ValueError("reply error")

elif status == "timeout":
raise TransportQueryError("reply timeout", query_id=answer_id)

elif event == "phx_error":
raise TransportServerError("Server error")
elif event == "phx_close":
answer_type = "close"
else:
raise ValueError

except ValueError as e:
raise TransportProtocolError(
"Server did not return a GraphQL result"
) from e

return answer_type, answer_id, execution_result

async def _handle_answer(
self,
answer_type: str,
answer_id: Optional[int],
execution_result: Optional[ExecutionResult],
) -> None:
if answer_type == "close":
await self.close()
else:
await super()._handle_answer(answer_type, answer_id, execution_result)

async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
if self.heartbeat_task is not None:
self.heartbeat_task.cancel()

await super()._close_coro(e, clean_close)
Loading