diff --git a/aries_cloudagent/utils/http.py b/aries_cloudagent/utils/http.py index 8e780338c3..d17fe234f2 100644 --- a/aries_cloudagent/utils/http.py +++ b/aries_cloudagent/utils/http.py @@ -1,8 +1,16 @@ """HTTP utility methods.""" import asyncio - -from aiohttp import BaseConnector, ClientError, ClientResponse, ClientSession +import logging +import urllib.parse + +from aiohttp import ( + BaseConnector, + ClientError, + ClientResponse, + ClientSession, + FormData, +) from aiohttp.web import HTTPConflict from ..core.error import BaseError @@ -10,6 +18,9 @@ from .repeat import RepeatSequence +LOGGER = logging.getLogger(__name__) + + class FetchError(BaseError): """Error raised when an HTTP fetch fails.""" @@ -147,7 +158,6 @@ async def put_file( """ (data_key, file_path) = [k for k in file_data.items()][0] - data = {**extra_data} limit = max_attempts if retry else 1 if not session: @@ -158,17 +168,51 @@ async def put_file( async for attempt in RepeatSequence(limit, interval, backoff): try: async with attempt.timeout(request_timeout): - with open(file_path, "rb") as f: - data[data_key] = f - response: ClientResponse = await session.put(url, data=data) - if (response.status < 200 or response.status >= 300) and ( - response.status != HTTPConflict.status_code - ): - raise ClientError( - f"Bad response from server: {response.status}, " - f"{response.reason}" - ) + formdata = FormData() + try: + fp = open(file_path, "rb") + except OSError as e: + raise PutError("Error opening file for upload") from e + if extra_data: + for k, v in extra_data.items(): + formdata.add_field(k, v) + formdata.add_field( + data_key, fp, content_type="application/octet-stream" + ) + response: ClientResponse = await session.put( + url, data=formdata, allow_redirects=False + ) + if ( + # redirect codes + response.status in (301, 302, 303, 307, 308) + and not attempt.final + ): + # NOTE: a redirect counts as another upload attempt + to_url = response.headers.get("Location") + if not to_url: + raise PutError("Redirect missing target URL") + try: + parsed_to = urllib.parse.urlsplit(to_url) + parsed_from = urllib.parse.urlsplit(url) + except ValueError: + raise PutError("Invalid redirect URL") + if parsed_to.hostname != parsed_from.hostname: + raise PutError("Redirect denied: hostname mismatch") + url = to_url + LOGGER.info("Upload redirect: %s", to_url) + elif (response.status < 200 or response.status >= 300) and ( + response.status != HTTPConflict.status_code + ): + raise ClientError( + f"Bad response from server: {response.status}, " + f"{response.reason}" + ) + else: return await (response.json() if json else response.text()) except (ClientError, asyncio.TimeoutError) as e: + if isinstance(e, ClientError): + LOGGER.warning("Upload error: %s", e) + else: + LOGGER.warning("Upload error: request timed out") if attempt.final: - raise PutError("Exceeded maximum put attempts") from e + raise PutError("Exceeded maximum upload attempts") from e diff --git a/aries_cloudagent/utils/tests/test_http.py b/aries_cloudagent/utils/tests/test_http.py index e760769958..6e2ff84adb 100644 --- a/aries_cloudagent/utils/tests/test_http.py +++ b/aries_cloudagent/utils/tests/test_http.py @@ -1,14 +1,33 @@ +import os +import tempfile + from aiohttp import web -from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop -from asynctest import mock as async_mock, mock_open +from aiohttp.test_utils import AioHTTPTestCase from ..http import fetch, fetch_stream, FetchError, put_file, PutError +class TempFile: + def __init__(self): + self.name = None + + def __enter__(self): + file = tempfile.NamedTemporaryFile(delete=False) + file.write(b"test") + file.close() + self.name = file.name + return self.name + + def __exit__(self, *args): + if self.name: + os.unlink(self.name) + + class TestTransportUtils(AioHTTPTestCase): async def setUpAsync(self): self.fail_calls = 0 self.succeed_calls = 0 + self.redirects = 0 await super().setUpAsync() async def get_application(self): @@ -19,12 +38,15 @@ async def get_application(self): web.get("/succeed", self.succeed_route), web.put("/fail", self.fail_route), web.put("/succeed", self.succeed_route), + web.put("/redirect", self.redirect_route), ] ) return app async def fail_route(self, request): self.fail_calls += 1 + # avoid aiohttp test server issue: https://github.com/aio-libs/aiohttp/issues/3968 + await request.read() raise web.HTTPForbidden() async def succeed_route(self, request): @@ -32,6 +54,14 @@ async def succeed_route(self, request): ret = web.json_response([True]) return ret + async def redirect_route(self, request): + if self.redirects > 0: + self.redirects -= 1 + # avoid aiohttp test server issue: https://github.com/aio-libs/aiohttp/issues/3968 + await request.read() + raise web.HTTPRedirection(f"http://localhost:{self.server.port}/success") + return await self.succeed_route(request) + async def test_fetch_stream(self): server_addr = f"http://localhost:{self.server.port}" stream = await fetch_stream( @@ -84,40 +114,55 @@ async def test_fetch_fail(self): ) assert self.fail_calls == 2 - async def test_put_file(self): + async def test_put_file_with_session(self): server_addr = f"http://localhost:{self.server.port}" - with async_mock.patch("builtins.open", mock_open(read_data="data")): + with TempFile() as tails: result = await put_file( f"{server_addr}/succeed", - {"tails": "/tmp/dummy/path"}, + {"tails": tails}, {"genesis": "..."}, session=self.client.session, json=True, ) - assert result == [1] + assert result == [True] assert self.succeed_calls == 1 async def test_put_file_default_client(self): server_addr = f"http://localhost:{self.server.port}" - with async_mock.patch("builtins.open", mock_open(read_data="data")): + with TempFile() as tails: result = await put_file( f"{server_addr}/succeed", - {"tails": "/tmp/dummy/path"}, + {"tails": tails}, {"genesis": "..."}, json=True, ) - assert result == [1] + assert result == [True] assert self.succeed_calls == 1 async def test_put_file_fail(self): server_addr = f"http://localhost:{self.server.port}" - with async_mock.patch("builtins.open", mock_open(read_data="data")): + with TempFile() as tails: with self.assertRaises(PutError): - result = await put_file( + _ = await put_file( f"{server_addr}/fail", - {"tails": "/tmp/dummy/path"}, + {"tails": tails}, {"genesis": "..."}, max_attempts=2, json=True, ) assert self.fail_calls == 2 + + async def test_put_file_redirect(self): + server_addr = f"http://localhost:{self.server.port}" + self.redirects = 1 + with TempFile() as tails: + result = await put_file( + f"{server_addr}/redirect", + {"tails": tails}, + {"genesis": "..."}, + max_attempts=2, + json=True, + ) + assert result == [True] + assert self.succeed_calls == 1 + assert self.redirects == 0