Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

File support #126

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 37 additions & 10 deletions gql/transport/aiohttp.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from ssl import SSLContext
from typing import Any, AsyncGenerator, Dict, Optional, Union

Expand All @@ -8,6 +9,7 @@
from aiohttp.typedefs import LooseCookies, LooseHeaders
from graphql import DocumentNode, ExecutionResult, print_ast

from ..utils import extract_files
from .async_transport import AsyncTransport
from .exceptions import (
TransportAlreadyConnected,
Expand All @@ -33,7 +35,7 @@ def __init__(
auth: Optional[BasicAuth] = None,
ssl: Union[SSLContext, bool, Fingerprint] = False,
timeout: Optional[int] = None,
client_session_args: Dict[str, Any] = {},
client_session_args: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize the transport with the given aiohttp parameters.

Expand All @@ -51,7 +53,6 @@ def __init__(
self.ssl: Union[SSLContext, bool, Fingerprint] = ssl
self.timeout: Optional[int] = timeout
self.client_session_args = client_session_args

self.session: Optional[aiohttp.ClientSession] = None

async def connect(self) -> None:
Expand All @@ -76,7 +77,8 @@ async def connect(self) -> None:
)

# Adding custom parameters passed from init
client_session_args.update(self.client_session_args)
if self.client_session_args:
client_session_args.update(self.client_session_args) # type: ignore

self.session = aiohttp.ClientSession(**client_session_args)

Expand All @@ -93,7 +95,7 @@ async def execute(
document: DocumentNode,
variable_values: Optional[Dict[str, str]] = None,
operation_name: Optional[str] = None,
extra_args: Dict[str, Any] = {},
extra_args: Dict[str, Any] = None,
) -> ExecutionResult:
"""Execute the provided document AST against the configured remote server.
This uses the aiohttp library to perform a HTTP POST request asynchronously
Expand All @@ -103,21 +105,46 @@ async def execute(
"""

query_str = print_ast(document)

nulled_variable_values = None
files = None
if variable_values:
nulled_variable_values, files = extract_files(variable_values)

payload: Dict[str, Any] = {
"query": query_str,
}

if variable_values:
payload["variables"] = variable_values
if nulled_variable_values:
payload["variables"] = nulled_variable_values
if operation_name:
payload["operationName"] = operation_name

post_args = {
"json": payload,
}
if files:
data = aiohttp.FormData()

# header
file_map = {str(i): [path] for i, path in enumerate(files)}
# path is nested in a list because the spec allows multiple pointers
# to the same file. But we don't use that.
file_streams = {
str(i): files[path] for i, path in enumerate(files)
} # payload

data.add_field(
"operations", json.dumps(payload), content_type="application/json"
)
data.add_field("map", json.dumps(file_map), content_type="application/json")
data.add_fields(*file_streams.items())

post_args = {"data": data}

else:
post_args = {"json": payload} # type: ignore

# Pass post_args to aiohttp post method
post_args.update(extra_args)
if extra_args:
post_args.update(extra_args) # type: ignore

if self.session is None:
raise TransportClosed("Transport is not connected")
Expand Down
43 changes: 43 additions & 0 deletions gql/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Utilities to manipulate several python objects."""

import io
from typing import Any, Dict, Tuple


# From this response in Stackoverflow
# http://stackoverflow.com/a/19053800/1072990
Expand All @@ -8,3 +11,43 @@ def to_camel_case(snake_str):
# We capitalize the first letter of each component except the first one
# with the 'title' method and join them together.
return components[0] + "".join(x.title() if x else "_" for x in components[1:])


def is_file_like(value: Any) -> bool:
"""Check if a value represents a file like object"""
return isinstance(value, io.IOBase)


def extract_files(variables: Dict) -> Tuple[Dict, Dict]:
files = {}

def recurse_extract(path, obj):
"""
recursively traverse obj, doing a deepcopy, but
replacing any file-like objects with nulls and
shunting the originals off to the side.
"""
nonlocal files
if isinstance(obj, list):
nulled_obj = []
for key, value in enumerate(obj):
value = recurse_extract(f"{path}.{key}", value)
nulled_obj.append(value)
return nulled_obj
elif isinstance(obj, dict):
nulled_obj = {}
for key, value in obj.items():
value = recurse_extract(f"{path}.{key}", value)
nulled_obj[key] = value
return nulled_obj
elif is_file_like(obj):
# extract obj from its parent and put it into files instead.
files[path] = obj
return None
else:
# base case: pass through unchanged
return obj

nulled_variables = recurse_extract("variables", variables)

return nulled_variables, files