diff --git a/README.md b/README.md index 92981ff1..9bdd7f76 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ The main features of GQL are: * Possibility to [validate the queries locally](https://gql.readthedocs.io/en/latest/usage/validation.html) using a GraphQL schema provided locally or fetched from the backend using an instrospection query * Supports GraphQL queries, mutations and subscriptions * Supports [sync or async usage](https://gql.readthedocs.io/en/latest/async/index.html), [allowing concurrent requests](https://gql.readthedocs.io/en/latest/advanced/async_advanced_usage.html#async-advanced-usage) +* Supports [File uploads](https://gql.readthedocs.io/en/latest/usage/file_upload.html) ## Installation diff --git a/docs/transports/aiohttp.rst b/docs/transports/aiohttp.rst index cdca6f45..a54809cc 100644 --- a/docs/transports/aiohttp.rst +++ b/docs/transports/aiohttp.rst @@ -1,3 +1,5 @@ +.. _aiohttp_transport: + AIOHTTPTransport ================ diff --git a/docs/usage/file_upload.rst b/docs/usage/file_upload.rst new file mode 100644 index 00000000..d900df95 --- /dev/null +++ b/docs/usage/file_upload.rst @@ -0,0 +1,69 @@ +File uploads +============ + +GQL supports file uploads with the :ref:`aiohttp transport ` +using the `GraphQL multipart request spec`_. + +.. _GraphQL multipart request spec: https://github.com/jaydenseric/graphql-multipart-request-spec + +Single File +----------- + +In order to upload a single file, you need to: + +* set the file as a variable value in the mutation +* provide the opened file to the `variable_values` argument of `execute` +* set the `upload_files` argument to True + +.. code-block:: python + + transport = AIOHTTPTransport(url='YOUR_URL') + + client = Client(transport=sample_transport) + + query = gql(''' + mutation($file: Upload!) { + singleUpload(file: $file) { + id + } + } + ''') + + with open("YOUR_FILE_PATH", "rb") as f: + + params = {"file": f} + + result = client.execute( + query, variable_values=params, upload_files=True + ) + +File list +--------- + +It is also possible to upload multiple files using a list. + +.. code-block:: python + + transport = AIOHTTPTransport(url='YOUR_URL') + + client = Client(transport=sample_transport) + + query = gql(''' + mutation($files: [Upload!]!) { + multipleUpload(files: $files) { + id + } + } + ''') + + f1 = open("YOUR_FILE_PATH_1", "rb") + f2 = open("YOUR_FILE_PATH_1", "rb") + + params = {"files": [f1, f2]} + + result = client.execute( + query, variable_values=params, upload_files=True + ) + + f1.close() + f2.close() diff --git a/docs/usage/index.rst b/docs/usage/index.rst index 2d5d5fd3..a7dd4d56 100644 --- a/docs/usage/index.rst +++ b/docs/usage/index.rst @@ -9,3 +9,4 @@ Usage subscriptions variables headers + file_upload diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 2ae83999..f17d3f5b 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -1,3 +1,5 @@ +import json +import logging from ssl import SSLContext from typing import Any, AsyncGenerator, Dict, Optional, Union @@ -8,6 +10,7 @@ from aiohttp.typedefs import LooseCookies, LooseHeaders from graphql import DocumentNode, ExecutionResult, print_ast +from ..utils import extract_files from .async_transport import AsyncTransport from .exceptions import ( TransportAlreadyConnected, @@ -16,6 +19,8 @@ TransportServerError, ) +log = logging.getLogger(__name__) + class AIOHTTPTransport(AsyncTransport): """:ref:`Async Transport ` to execute GraphQL queries @@ -32,7 +37,7 @@ def __init__( auth: Optional[BasicAuth] = None, ssl: Union[SSLContext, bool, Fingerprint] = False, timeout: Optional[int] = None, - client_session_args: Dict[str, Any] = {}, + client_session_args: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the transport with the given aiohttp parameters. @@ -54,7 +59,6 @@ def __init__( self.ssl: Union[SSLContext, bool, Fingerprint] = ssl self.timeout: Optional[int] = timeout self.client_session_args = client_session_args - self.session: Optional[aiohttp.ClientSession] = None async def connect(self) -> None: @@ -81,7 +85,8 @@ async def connect(self) -> None: ) # Adding custom parameters passed from init - client_session_args.update(self.client_session_args) + if self.client_session_args: + client_session_args.update(self.client_session_args) # type: ignore self.session = aiohttp.ClientSession(**client_session_args) @@ -104,7 +109,8 @@ async def execute( document: DocumentNode, variable_values: Optional[Dict[str, str]] = None, operation_name: Optional[str] = None, - extra_args: Dict[str, Any] = {}, + extra_args: Dict[str, Any] = None, + upload_files: bool = False, ) -> ExecutionResult: """Execute the provided document AST against the configured remote server using the current session. @@ -118,25 +124,70 @@ async def execute( :param variables_values: An optional Dict of variable values :param operation_name: An optional Operation name for the request :param extra_args: additional arguments to send to the aiohttp post method + :param upload_files: Set to True if you want to put files in the variable values :returns: an ExecutionResult object. """ query_str = print_ast(document) + payload: Dict[str, Any] = { "query": query_str, } - if variable_values: - payload["variables"] = variable_values if operation_name: payload["operationName"] = operation_name - post_args = { - "json": payload, - } + if upload_files: + + # If the upload_files flag is set, then we need variable_values + assert variable_values is not None + + # If we upload files, we will extract the files present in the + # variable_values dict and replace them by null values + nulled_variable_values, files = extract_files(variable_values) + + # Save the nulled variable values in the payload + payload["variables"] = nulled_variable_values + + # Prepare aiohttp to send multipart-encoded data + data = aiohttp.FormData() + + # Generate the file map + # path is nested in a list because the spec allows multiple pointers + # to the same file. But we don't support that. + # Will generate something like {"0": ["variables.file"]} + file_map = {str(i): [path] for i, path in enumerate(files)} + + # Enumerate the file streams + # Will generate something like {'0': <_io.BufferedReader ...>} + file_streams = {str(i): files[path] for i, path in enumerate(files)} + + # Add the payload to the operations field + operations_str = json.dumps(payload) + log.debug("operations %s", operations_str) + data.add_field( + "operations", operations_str, content_type="application/json" + ) + + # Add the file map field + file_map_str = json.dumps(file_map) + log.debug("file_map %s", file_map_str) + data.add_field("map", file_map_str, content_type="application/json") + + # Add the extracted files as remaining fields + data.add_fields(*file_streams.items()) + + post_args: Dict[str, Any] = {"data": data} + + else: + if variable_values: + payload["variables"] = variable_values + + post_args = {"json": payload} # Pass post_args to aiohttp post method - post_args.update(extra_args) + if extra_args: + post_args.update(extra_args) if self.session is None: raise TransportClosed("Transport is not connected") diff --git a/gql/utils.py b/gql/utils.py index 8f47d97d..ce0318b0 100644 --- a/gql/utils.py +++ b/gql/utils.py @@ -1,5 +1,8 @@ """Utilities to manipulate several python objects.""" +import io +from typing import Any, Dict, Tuple + # From this response in Stackoverflow # http://stackoverflow.com/a/19053800/1072990 @@ -8,3 +11,43 @@ def to_camel_case(snake_str): # We capitalize the first letter of each component except the first one # with the 'title' method and join them together. return components[0] + "".join(x.title() if x else "_" for x in components[1:]) + + +def is_file_like(value: Any) -> bool: + """Check if a value represents a file like object""" + return isinstance(value, io.IOBase) + + +def extract_files(variables: Dict) -> Tuple[Dict, Dict]: + files = {} + + def recurse_extract(path, obj): + """ + recursively traverse obj, doing a deepcopy, but + replacing any file-like objects with nulls and + shunting the originals off to the side. + """ + nonlocal files + if isinstance(obj, list): + nulled_obj = [] + for key, value in enumerate(obj): + value = recurse_extract(f"{path}.{key}", value) + nulled_obj.append(value) + return nulled_obj + elif isinstance(obj, dict): + nulled_obj = {} + for key, value in obj.items(): + value = recurse_extract(f"{path}.{key}", value) + nulled_obj[key] = value + return nulled_obj + elif is_file_like(obj): + # extract obj from its parent and put it into files instead. + files[path] = obj + return None + else: + # base case: pass through unchanged + return obj + + nulled_variables = recurse_extract("variables", variables) + + return nulled_variables, files diff --git a/tests/conftest.py b/tests/conftest.py index c2edc236..c2a15605 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,8 +4,10 @@ import os import pathlib import ssl +import tempfile import types from concurrent.futures import ThreadPoolExecutor +from typing import Union import pytest import websockets @@ -187,6 +189,35 @@ async def send_connection_ack(ws): await ws.send('{"event":"phx_reply", "payload": {"status": "ok"}, "ref": 1}') +class TemporaryFile: + """Class used to generate temporary files for the tests""" + + def __init__(self, content: Union[str, bytearray]): + + mode = "w" if isinstance(content, str) else "wb" + + # We need to set the newline to '' so that the line returns + # are not replaced by '\r\n' on windows + newline = "" if isinstance(content, str) else None + + self.file = tempfile.NamedTemporaryFile( + mode=mode, newline=newline, delete=False + ) + + with self.file as f: + f.write(content) + + @property + def filename(self): + return self.file.name + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + os.unlink(self.filename) + + def get_server_handler(request): """Get the server handler. diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 0e97655f..8f39319f 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -11,6 +11,8 @@ TransportServerError, ) +from .conftest import TemporaryFile + query1_str = """ query getContinents { continents { @@ -321,3 +323,360 @@ def test_code(): pass await run_sync_test(event_loop, server, test_code) + + +file_upload_server_answer = '{"data":{"success":true}}' + +file_upload_mutation_1 = """ + mutation($file: Upload!) { + uploadFile(input:{other_var:$other_var, file:$file}) { + success + } + } +""" + +file_upload_mutation_1_operations = ( + '{"query": "mutation ($file: Upload!) {\\n uploadFile(input: {other_var: ' + '$other_var, file: $file}) {\\n success\\n }\\n}\\n", "variables": ' + '{"file": null, "other_var": 42}}' +) + +file_upload_mutation_1_map = '{"0": ["variables.file"]}' + +file_1_content = """ +This is a test file +This file will be sent in the GraphQL mutation +""" + + +async def single_upload_handler(request): + + reader = await request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_1_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_1_map + + field_2 = await reader.next() + assert field_2.name == "0" + field_2_text = await field_2.text() + assert field_2_text == file_1_content + + field_3 = await reader.next() + assert field_3 is None + + return web.Response(text=file_upload_server_answer, content_type="application/json") + + +@pytest.mark.asyncio +async def test_aiohttp_file_upload(event_loop, aiohttp_server): + app = web.Application() + app.router.add_route("POST", "/", single_upload_handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + sample_transport = AIOHTTPTransport(url=url, timeout=10) + + with TemporaryFile(file_1_content) as test_file: + + async with Client(transport=sample_transport,) as session: + + query = gql(file_upload_mutation_1) + + file_path = test_file.filename + + with open(file_path, "rb") as f: + + params = {"file": f, "other_var": 42} + + # Execute query asynchronously + result = await session.execute( + query, variable_values=params, upload_files=True + ) + + success = result["success"] + + assert success + + +@pytest.mark.asyncio +async def test_aiohttp_file_upload_without_session( + event_loop, aiohttp_server, run_sync_test +): + + app = web.Application() + app.router.add_route("POST", "/", single_upload_handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + sample_transport = AIOHTTPTransport(url=url, timeout=10) + + with TemporaryFile(file_1_content) as test_file: + + client = Client(transport=sample_transport,) + + query = gql(file_upload_mutation_1) + + file_path = test_file.filename + + with open(file_path, "rb") as f: + + params = {"file": f, "other_var": 42} + + result = client.execute( + query, variable_values=params, upload_files=True + ) + + success = result["success"] + + assert success + + await run_sync_test(event_loop, server, test_code) + + +# This is a sample binary file content containing all possible byte values +binary_file_content = bytes(range(0, 256)) + + +async def binary_upload_handler(request): + + reader = await request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_1_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_1_map + + field_2 = await reader.next() + assert field_2.name == "0" + field_2_binary = await field_2.read() + assert field_2_binary == binary_file_content + + field_3 = await reader.next() + assert field_3 is None + + return web.Response(text=file_upload_server_answer, content_type="application/json") + + +@pytest.mark.asyncio +async def test_aiohttp_binary_file_upload(event_loop, aiohttp_server): + app = web.Application() + app.router.add_route("POST", "/", binary_upload_handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + sample_transport = AIOHTTPTransport(url=url, timeout=10) + + with TemporaryFile(binary_file_content) as test_file: + + async with Client(transport=sample_transport,) as session: + + query = gql(file_upload_mutation_1) + + file_path = test_file.filename + + with open(file_path, "rb") as f: + + params = {"file": f, "other_var": 42} + + # Execute query asynchronously + result = await session.execute( + query, variable_values=params, upload_files=True + ) + + success = result["success"] + + assert success + + +file_upload_mutation_2 = """ + mutation($file1: Upload!, $file2: Upload!) { + uploadFile(input:{file1:$file, file2:$file}) { + success + } + } +""" + +file_upload_mutation_2_operations = ( + '{"query": "mutation ($file1: Upload!, $file2: Upload!) {\\n ' + 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}\\n", ' + '"variables": {"file1": null, "file2": null}}' +) + +file_upload_mutation_2_map = '{"0": ["variables.file1"], "1": ["variables.file2"]}' + +file_2_content = """ +This is a second test file +This file will also be sent in the GraphQL mutation +""" + + +@pytest.mark.asyncio +async def test_aiohttp_file_upload_two_files(event_loop, aiohttp_server): + async def handler(request): + + reader = await request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_2_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_2_map + + field_2 = await reader.next() + assert field_2.name == "0" + field_2_text = await field_2.text() + assert field_2_text == file_1_content + + field_3 = await reader.next() + assert field_3.name == "1" + field_3_text = await field_3.text() + assert field_3_text == file_2_content + + field_4 = await reader.next() + assert field_4 is None + + return web.Response( + text=file_upload_server_answer, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + sample_transport = AIOHTTPTransport(url=url, timeout=10) + + with TemporaryFile(file_1_content) as test_file_1: + with TemporaryFile(file_2_content) as test_file_2: + + async with Client(transport=sample_transport,) as session: + + query = gql(file_upload_mutation_2) + + file_path_1 = test_file_1.filename + file_path_2 = test_file_2.filename + + f1 = open(file_path_1, "rb") + f2 = open(file_path_2, "rb") + + params = { + "file1": f1, + "file2": f2, + } + + result = await session.execute( + query, variable_values=params, upload_files=True + ) + + f1.close() + f2.close() + + success = result["success"] + + assert success + + +file_upload_mutation_3 = """ + mutation($files: [Upload!]!) { + uploadFiles(input:{files:$files}) { + success + } + } +""" + +file_upload_mutation_3_operations = ( + '{"query": "mutation ($files: [Upload!]!) {\\n uploadFiles(input: {files: $files})' + ' {\\n success\\n }\\n}\\n", "variables": {"files": [null, null]}}' +) + +file_upload_mutation_3_map = '{"0": ["variables.files.0"], "1": ["variables.files.1"]}' + + +@pytest.mark.asyncio +async def test_aiohttp_file_upload_list_of_two_files(event_loop, aiohttp_server): + async def handler(request): + + reader = await request.multipart() + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert field_0_text == file_upload_mutation_3_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == file_upload_mutation_3_map + + field_2 = await reader.next() + assert field_2.name == "0" + field_2_text = await field_2.text() + assert field_2_text == file_1_content + + field_3 = await reader.next() + assert field_3.name == "1" + field_3_text = await field_3.text() + assert field_3_text == file_2_content + + field_4 = await reader.next() + assert field_4 is None + + return web.Response( + text=file_upload_server_answer, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + sample_transport = AIOHTTPTransport(url=url, timeout=10) + + with TemporaryFile(file_1_content) as test_file_1: + with TemporaryFile(file_2_content) as test_file_2: + + async with Client(transport=sample_transport,) as session: + + query = gql(file_upload_mutation_3) + + file_path_1 = test_file_1.filename + file_path_2 = test_file_2.filename + + f1 = open(file_path_1, "rb") + f2 = open(file_path_2, "rb") + + params = {"files": [f1, f2]} + + # Execute query asynchronously + result = await session.execute( + query, variable_values=params, upload_files=True + ) + + f1.close() + f2.close() + + success = result["success"] + + assert success