diff --git a/README.rst b/README.rst index 0dce5de7d09..25153c75984 100644 --- a/README.rst +++ b/README.rst @@ -29,13 +29,13 @@ To retrieve something from the web:: def get_body(url): response = yield from aiohttp.request('GET', url) - return (yield from response.read_and_close()) + return (yield from response.read()) You can use the get command like this anywhere in your ``asyncio`` powered program:: response = yield from aiohttp.request('GET', 'http://python.org') - body = yield from response.read_and_close() + body = yield from response.read() print(body) The signature of request is the following:: diff --git a/aiohttp/client.py b/aiohttp/client.py index ee238ef328a..dae2e092396 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -20,6 +20,7 @@ import warnings import aiohttp +from .helpers import parse_mimetype from .log import client_log from .multidict import CaseInsensitiveMultiDict, MultiDict, MutableMultiDict @@ -86,7 +87,7 @@ def request(method, url, *, >>> resp = yield from aiohttp.request('GET', 'http://python.org/') >>> resp - >>> data = yield from resp.read_and_close() + >>> data = yield from resp.read() """ redirects = 0 @@ -674,6 +675,10 @@ def close(self, force=False): self._writer = None self._writer_wr = None + @asyncio.coroutine + def release(self): + yield from self.read() + @asyncio.coroutine def wait_for_close(self): if self._writer is not None: @@ -697,7 +702,10 @@ def read(self, decode=False): buf.append((chunk, size)) total += size except aiohttp.EofStream: - pass + self.close() + except: + self.close(True) + raise self._content = bytearray(total) @@ -710,24 +718,40 @@ def read(self, decode=False): data = self._content if decode: - ct = self.headers.get('CONTENT-TYPE', '').lower() - if ct == 'application/json': - data = json.loads(data.decode('utf-8')) + warnings.warn( + '.read(True) is deprecated. use .json() instead', + DeprecationWarning + ) + return (yield from self.json()) return data @asyncio.coroutine def read_and_close(self, decode=False): """Read response payload and then close response.""" - try: - payload = yield from self.read(decode) - except: - self.close(True) - raise - else: - self.close() + warnings.warn( + 'read_and_close is deprecated, use .read() instead', + DeprecationWarning + ) + return (yield from self.read(decode)) + + @asyncio.coroutine + def json(self, *, encoding=None): + """Reads and decodes JSON response.""" + if self._content is None: + yield from self.read() + + ctype = self.headers.get('CONTENT-TYPE', '').lower() + mtype, stype, _, params = parse_mimetype(ctype) + if not (mtype == 'application' or stype == 'json'): + client_log.warning( + 'Attempt to decode JSON with unexpected mimetype: %s', ctype) + + if not self._content.strip(): + return None - return payload + encoding = encoding or params.get('charset', 'utf-8') + return json.loads(self._content.decode(encoding)) def str_to_bytes(s, encoding='utf-8'): diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py new file mode 100644 index 00000000000..d5c2099f133 --- /dev/null +++ b/aiohttp/helpers.py @@ -0,0 +1,53 @@ +"""Various helper functions""" + + +def parse_mimetype(mimetype): + """Parses a MIME type into it components. + + :param str mimetype: MIME type + + :returns: 4 element tuple for MIME type, subtype, suffix and parameters + :rtype: tuple + + >>> parse_mimetype('*') + ('*', '*', '', {}) + + >>> parse_mimetype('application/json') + ('application', 'json', '', {}) + + >>> parse_mimetype('application/json; charset=utf-8') + ('application', 'json', '', {'charset': 'utf-8'}) + + >>> parse_mimetype('''application/json; + ... charset=utf-8;''') + ('application', 'json', '', {'charset': 'utf-8'}) + + >>> parse_mimetype('ApPlIcAtIoN/JSON;ChaRseT="UTF-8"') + ('application', 'json', '', {'charset': 'UTF-8'}) + + >>> parse_mimetype('application/rss+xml') + ('application', 'rss', 'xml', {}) + + >>> parse_mimetype('text/plain;base64') + ('text', 'plain', '', {'base64': ''}) + + """ + if not mimetype: + return '', '', '', {} + + parts = mimetype.split(';') + params = [] + for item in parts[1:]: + if not item: + continue + key, value = item.split('=', 2) if '=' in item else (item, '') + params.append((key.lower().strip(), value.strip(' "'))) + params = dict(params) + + fulltype = parts[0].strip().lower() + if fulltype == '*': + fulltype = '*/*' + mtype, stype = fulltype.split('/', 2) if '/' in fulltype else (fulltype, '') + stype, suffix = stype.split('+') if '+' in stype else (stype, '') + + return mtype, stype, suffix, params diff --git a/tests/test_client.py b/tests/test_client.py index f14b59501b7..39e4658f42e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -81,28 +81,79 @@ def test_repr(self): repr(self.response)) def test_read_and_close(self): - self.response.read = unittest.mock.Mock() - self.response.read.return_value = asyncio.Future(loop=self.loop) - self.response.read.return_value.set_result(b'payload') + def side_effect(*args, **kwargs): + def second_call(*args, **kwargs): + raise aiohttp.EofStream + fut = asyncio.Future(loop=self.loop) + fut.set_result(b'payload') + content.read.side_effect = second_call + return fut + content = self.response.content = unittest.mock.Mock() + content.read.side_effect = side_effect self.response.close = unittest.mock.Mock() - res = self.loop.run_until_complete(self.response.read_and_close()) + res = self.loop.run_until_complete(self.response.read()) self.assertEqual(res, b'payload') - self.assertTrue(self.response.read.called) self.assertTrue(self.response.close.called) def test_read_and_close_with_error(self): - self.response.read = unittest.mock.Mock() - self.response.read.return_value = asyncio.Future(loop=self.loop) - self.response.read.return_value.set_exception(ValueError) + content = self.response.content = unittest.mock.Mock() + content.read.return_value = asyncio.Future(loop=self.loop) + content.read.return_value.set_exception(ValueError) self.response.close = unittest.mock.Mock() self.assertRaises( ValueError, - self.loop.run_until_complete, self.response.read_and_close()) - self.assertTrue(self.response.read.called) + self.loop.run_until_complete, self.response.read()) self.response.close.assert_called_with(True) + def test_release(self): + fut = asyncio.Future(loop=self.loop) + fut.set_exception(aiohttp.EofStream) + content = self.response.content = unittest.mock.Mock() + content.read.return_value = fut + self.response.close = unittest.mock.Mock() + + self.loop.run_until_complete(self.response.release()) + self.assertTrue(self.response.close.called) + + def test_json(self): + def side_effect(*args, **kwargs): + def second_call(*args, **kwargs): + raise aiohttp.EofStream + fut = asyncio.Future(loop=self.loop) + fut.set_result('{"тест": "пройден"}'.encode('cp1251')) + content.read.side_effect = second_call + return fut + self.response.headers = { + 'CONTENT-TYPE': 'application/json;charset=cp1251'} + content = self.response.content = unittest.mock.Mock() + content.read.side_effect = side_effect + self.response.close = unittest.mock.Mock() + + res = self.loop.run_until_complete(self.response.json()) + self.assertEqual(res, {'тест': 'пройден'}) + self.assertTrue(self.response.close.called) + + def test_json_override_encoding(self): + def side_effect(*args, **kwargs): + def second_call(*args, **kwargs): + raise aiohttp.EofStream + fut = asyncio.Future(loop=self.loop) + fut.set_result('{"тест": "пройден"}'.encode('cp1251')) + content.read.side_effect = second_call + return fut + self.response.headers = { + 'CONTENT-TYPE': 'application/json;charset=utf8'} + content = self.response.content = unittest.mock.Mock() + content.read.side_effect = side_effect + self.response.close = unittest.mock.Mock() + + res = self.loop.run_until_complete( + self.response.json(encoding='cp1251')) + self.assertEqual(res, {'тест': 'пройден'}) + self.assertTrue(self.response.close.called) + class ClientRequestTests(unittest.TestCase): diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index e8ef4d10173..309193ce1f6 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -251,7 +251,7 @@ def test_POST_DATA(self): loop=self.loop)) self.assertEqual(r.status, 200) - content = self.loop.run_until_complete(r.read(True)) + content = self.loop.run_until_complete(r.json()) self.assertEqual({'some': ['data']}, content['form']) self.assertEqual(r.status, 200) r.close() @@ -265,7 +265,7 @@ def test_POST_DATA_DEFLATE(self): loop=self.loop)) self.assertEqual(r.status, 200) - content = self.loop.run_until_complete(r.read(True)) + content = self.loop.run_until_complete(r.json()) self.assertEqual('deflate', content['compression']) self.assertEqual({'some': ['data']}, content['form']) self.assertEqual(r.status, 200) @@ -281,7 +281,7 @@ def test_POST_FILES(self): 'post', url, files={'some': f}, chunked=1024, headers={'Transfer-Encoding': 'chunked'}, loop=self.loop)) - content = self.loop.run_until_complete(r.read(True)) + content = self.loop.run_until_complete(r.json()) f.seek(0) filename = os.path.split(f.name)[-1] @@ -306,7 +306,7 @@ def test_POST_FILES_DEFLATE(self): chunked=1024, compress='deflate', loop=self.loop)) - content = self.loop.run_until_complete(r.read(True)) + content = self.loop.run_until_complete(r.json()) f.seek(0) filename = os.path.split(f.name)[-1] @@ -331,7 +331,7 @@ def test_POST_FILES_STR(self): client.request('post', url, files=[('some', f.read())], loop=self.loop)) - content = self.loop.run_until_complete(r.read(True)) + content = self.loop.run_until_complete(r.json()) f.seek(0) self.assertEqual(1, len(content['multipart-data'])) @@ -353,7 +353,7 @@ def test_POST_FILES_LIST(self): client.request('post', url, files=[('some', f)], loop=self.loop)) - content = self.loop.run_until_complete(r.read(True)) + content = self.loop.run_until_complete(r.json()) f.seek(0) filename = os.path.split(f.name)[-1] @@ -377,7 +377,7 @@ def test_POST_FILES_LIST_CT(self): client.request('post', url, loop=self.loop, files=[('some', f, 'text/plain')])) - content = self.loop.run_until_complete(r.read(True)) + content = self.loop.run_until_complete(r.json()) f.seek(0) filename = os.path.split(f.name)[-1] @@ -402,7 +402,7 @@ def test_POST_FILES_SINGLE(self): r = self.loop.run_until_complete( client.request('post', url, files=[f], loop=self.loop)) - content = self.loop.run_until_complete(r.read(True)) + content = self.loop.run_until_complete(r.json()) f.seek(0) filename = os.path.split(f.name)[-1] @@ -426,7 +426,7 @@ def test_POST_FILES_IO(self): r = self.loop.run_until_complete( client.request('post', url, files=[data], loop=self.loop)) - content = self.loop.run_until_complete(r.read(True)) + content = self.loop.run_until_complete(r.json()) self.assertEqual(1, len(content['multipart-data'])) self.assertEqual( @@ -446,7 +446,7 @@ def test_POST_FILES_WITH_DATA(self): client.request('post', url, loop=self.loop, data={'test': 'true'}, files={'some': f})) - content = self.loop.run_until_complete(r.read(True)) + content = self.loop.run_until_complete(r.json()) self.assertEqual(2, len(content['multipart-data'])) self.assertEqual( @@ -486,7 +486,7 @@ def stream(): 'post', url, data=stream(), headers={'Content-Length': str(len(data))}, loop=self.loop)) - content = self.loop.run_until_complete(r.read(True)) + content = self.loop.run_until_complete(r.json()) r.close() self.assertEqual(str(len(data)), @@ -500,7 +500,7 @@ def test_expect_continue(self): expect100=True, loop=self.loop)) self.assertEqual(r.status, 200) - content = self.loop.run_until_complete(r.read(True)) + content = self.loop.run_until_complete(r.json()) self.assertEqual('100-continue', content['headers']['Expect']) self.assertEqual(r.status, 200) r.close() @@ -554,7 +554,7 @@ def test_chunked(self): client.request('get', httpd.url('chunked'), loop=self.loop)) self.assertEqual(r.status, 200) self.assertEqual(r.headers.getone('TRANSFER-ENCODING'), 'chunked') - content = self.loop.run_until_complete(r.read(True)) + content = self.loop.run_until_complete(r.json()) self.assertEqual(content['path'], '/chunked') r.close() @@ -565,7 +565,7 @@ def test_broken_connection(self): self.assertEqual(r.status, 200) self.assertRaises( aiohttp.IncompleteRead, - self.loop.run_until_complete, r.read(True)) + self.loop.run_until_complete, r.json()) r.close() def test_request_conn_error(self): @@ -593,7 +593,7 @@ def test_keepalive(self): client.request('get', httpd.url('keepalive',), connector=c, loop=self.loop)) self.assertEqual(r.status, 200) - content = self.loop.run_until_complete(r.read(True)) + content = self.loop.run_until_complete(r.json()) self.assertEqual(content['content'], 'requests=1') r.close() @@ -601,7 +601,7 @@ def test_keepalive(self): client.request('get', httpd.url('keepalive'), connector=c, loop=self.loop)) self.assertEqual(r.status, 200) - content = self.loop.run_until_complete(r.read(True)) + content = self.loop.run_until_complete(r.json()) self.assertEqual(content['content'], 'requests=2') r.close() @@ -616,7 +616,7 @@ def test_session_close(self): 'get', httpd.url('keepalive') + '?close=1', connector=conn, loop=self.loop)) self.assertEqual(r.status, 200) - content = self.loop.run_until_complete(r.read(True)) + content = self.loop.run_until_complete(r.json()) self.assertEqual(content['content'], 'requests=1') r.close() @@ -624,7 +624,7 @@ def test_session_close(self): client.request('get', httpd.url('keepalive'), connector=conn, loop=self.loop)) self.assertEqual(r.status, 200) - content = self.loop.run_until_complete(r.read(True)) + content = self.loop.run_until_complete(r.json()) self.assertEqual(content['content'], 'requests=1') r.close() @@ -641,7 +641,7 @@ def test_session_cookies(self, m_log): client.request('get', httpd.url('cookies'), connector=conn, loop=self.loop)) self.assertEqual(r.status, 200) - content = self.loop.run_until_complete(r.read(True)) + content = self.loop.run_until_complete(r.json()) self.assertEqual(content['headers']['Cookie'], 'test=1') r.close() @@ -663,7 +663,7 @@ def test_multidict_headers(self): 'post', url, data=data, headers=MultiDict({'Content-Length': str(len(data))}), loop=self.loop)) - content = self.loop.run_until_complete(r.read(True)) + content = self.loop.run_until_complete(r.json()) r.close() self.assertEqual(str(len(data)), @@ -679,7 +679,7 @@ def go(url): self.assertIsNotNone(connection) connector = connection._connector self.assertIsNotNone(connector) - yield from r.read_and_close() + yield from r.read() self.assertEqual(0, len(connector._conns)) with test_utils.run_server(self.loop, router=Functional) as httpd: @@ -695,7 +695,7 @@ def go(url): r = yield from client.request('GET', url, connector=connector, loop=self.loop) - yield from r.read_and_close() + yield from r.read() self.assertEqual(1, len(connector._conns)) connector.close()