Skip to content

Commit

Permalink
Use yarl in client API (#1217)
Browse files Browse the repository at this point in the history
* Use yarl in client API

* Update doc and CHANGES

* Add a test
  • Loading branch information
asvetlov authored Sep 27, 2016
1 parent 0d41462 commit e81a478
Show file tree
Hide file tree
Showing 22 changed files with 287 additions and 358 deletions.
3 changes: 2 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ CHANGES

- Drop deprecated `WSClientDisconnectedError` (BACKWARD INCOMPATIBLE)

-
- Use `yarl.URL` in client API. The change is 99% backward compatible
but `ClientResponse.url` is an `yarl.URL` instance now.

-

Expand Down
4 changes: 4 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,19 @@ Requirements
------------

- Python >= 3.4.2
- asyncio-timeout_
- chardet_
- multidict_
- yarl_

Optionally you may install the cChardet_ and aiodns_ libraries (highly
recommended for sake of speed).

.. _chardet: https://pypi.python.org/pypi/chardet
.. _aiodns: https://pypi.python.org/pypi/aiodns
.. _multidict: https://pypi.python.org/pypi/multidict
.. _yarl: https://pypi.python.org/pypi/yarl
.. _asyncio-timeout: https://pypi.python.org/pypi/asyncio_timeout
.. _cChardet: https://pypi.python.org/pypi/cchardet

License
Expand Down
15 changes: 9 additions & 6 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import os
import sys
import traceback
import urllib.parse
import warnings

from multidict import CIMultiDict, MultiDict, MultiDictProxy, istr
from yarl import URL

import aiohttp

Expand Down Expand Up @@ -180,8 +180,11 @@ def _request(self, method, url, *,
for i in skip_auto_headers:
skip_headers.add(istr(i))

if isinstance(proxy, str):
proxy = URL(proxy)

while True:
url, _ = urllib.parse.urldefrag(url)
url = URL(url).with_fragment(None)

cookies = self._cookie_jar.filter_cookies(url)

Expand Down Expand Up @@ -237,15 +240,15 @@ def _request(self, method, url, *,
if headers.get(hdrs.CONTENT_LENGTH):
headers.pop(hdrs.CONTENT_LENGTH)

r_url = (resp.headers.get(hdrs.LOCATION) or
resp.headers.get(hdrs.URI))
r_url = URL(resp.headers.get(hdrs.LOCATION) or
resp.headers.get(hdrs.URI))

scheme = urllib.parse.urlsplit(r_url)[0]
scheme = r_url.scheme
if scheme not in ('http', 'https', ''):
resp.close()
raise ValueError('Can redirect only to http or https')
elif not scheme:
r_url = urllib.parse.urljoin(url, r_url)
r_url = url.join(r_url)

url = r_url
params = None
Expand Down
107 changes: 35 additions & 72 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import asyncio
import collections
import http.cookies
import io
import json
import mimetypes
import os
import sys
import traceback
import urllib.parse
import warnings

from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
from yarl import URL

import aiohttp

Expand All @@ -31,9 +30,6 @@

PY_35 = sys.version_info >= (3, 5)

HTTP_PORT = 80
HTTPS_PORT = 443


class ClientRequest:

Expand Down Expand Up @@ -75,7 +71,15 @@ def __init__(self, method, url, *,
if loop is None:
loop = asyncio.get_event_loop()

self.url = url
assert isinstance(url, URL), url
assert isinstance(proxy, (URL, type(None))), proxy

if params:
q = MultiDict(url.query)
url2 = url.with_query(params)
q.extend(url2.query)
url = url.with_query(q)
self.url = url.with_fragment(None)
self.method = method.upper()
self.encoding = encoding
self.chunked = chunked
Expand All @@ -89,7 +93,6 @@ def __init__(self, method, url, *,

self.update_version(version)
self.update_host(url)
self.update_path(params)
self.update_headers(headers)
self.update_auto_headers(skip_auto_headers)
self.update_cookies(cookies)
Expand All @@ -101,59 +104,30 @@ def __init__(self, method, url, *,
self.update_transfer_encoding()
self.update_expect_continue(expect100)

def update_host(self, url):
"""Update destination host, port and connection type (ssl)."""
url_parsed = urllib.parse.urlsplit(url)
@property
def host(self):
return self.url.host

# check for network location part
netloc = url_parsed.netloc
if not netloc:
raise ValueError('Host could not be detected.')
@property
def port(self):
return self.url.port

def update_host(self, url):
"""Update destination host, port and connection type (ssl)."""
# get host/port
host = url_parsed.hostname
if not host:
if not url.host:
raise ValueError('Host could not be detected.')

try:
port = url_parsed.port
except ValueError:
raise ValueError(
'Port number could not be converted.') from None

# check domain idna encoding
try:
host = host.encode('idna').decode('utf-8')
netloc = self.make_netloc(host, url_parsed.port)
except UnicodeError:
raise ValueError('URL has an invalid label.')

# basic auth info
username, password = url_parsed.username, url_parsed.password
username, password = url.user, url.password
if username:
self.auth = helpers.BasicAuth(username, password or '')

# Record entire netloc for usage in host header
self.netloc = netloc

scheme = url_parsed.scheme
scheme = url.scheme
self.ssl = scheme in ('https', 'wss')

# set port number if it isn't already set
if not port:
if self.ssl:
port = HTTPS_PORT
else:
port = HTTP_PORT

self.host, self.port, self.scheme = host, port, scheme

def make_netloc(self, host, port):
ret = host
if port:
ret = ret + ':' + str(port)
return ret

def update_version(self, version):
"""Convert request version to two elements tuple.
Expand All @@ -172,25 +146,8 @@ def update_version(self, version):
def update_path(self, params):
"""Build path."""
# extract path
scheme, netloc, path, query, fragment = urllib.parse.urlsplit(self.url)
if not path:
path = '/'

if isinstance(params, collections.Mapping):
params = list(params.items())

if params:
if not isinstance(params, str):
params = urllib.parse.urlencode(params)
if query:
query = '%s&%s' % (query, params)
else:
query = params

self.path = urllib.parse.urlunsplit(('', '', helpers.requote_uri(path),
query, ''))
self.url = urllib.parse.urlunsplit(
(scheme, netloc, self.path, '', fragment))
self.url = self.url.with_query(params)

def update_headers(self, headers):
"""Update request headers."""
Expand All @@ -214,7 +171,10 @@ def update_auto_headers(self, skip_auto_headers):

# add host
if hdrs.HOST not in used_headers:
self.headers[hdrs.HOST] = self.netloc
netloc = self.url.host
if not self.url.is_default_port():
netloc += ':' + str(self.url.port)
self.headers[hdrs.HOST] = netloc

if hdrs.USER_AGENT not in used_headers:
self.headers[hdrs.USER_AGENT] = self.SERVER_SOFTWARE
Expand Down Expand Up @@ -378,7 +338,7 @@ def update_expect_continue(self, expect=False):
self._continue = helpers.create_future(self.loop)

def update_proxy(self, proxy, proxy_auth):
if proxy and not proxy.startswith('http://'):
if proxy and not proxy.scheme == 'http':
raise ValueError("Only http proxies are supported")
if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth):
raise ValueError("proxy_auth must be None or BasicAuth() tuple")
Expand Down Expand Up @@ -488,7 +448,11 @@ def write_bytes(self, request, reader):

def send(self, writer, reader):
writer.set_tcp_cork(True)
request = aiohttp.Request(writer, self.method, self.path, self.version)
path = self.url.raw_path
if self.url.raw_query_string:
path += '?' + self.url.raw_query_string
request = aiohttp.Request(writer, self.method, path,
self.version)

if self.compress:
request.add_compression_filter(self.compress)
Expand All @@ -511,7 +475,7 @@ def send(self, writer, reader):
self.write_bytes(request, reader), loop=self.loop)

self.response = self.response_class(
self.method, self.url, self.host,
self.method, self.url, self.url.host,
writer=self._writer, continue100=self._continue,
timeout=self._timeout)
self.response._post_init(self.loop)
Expand Down Expand Up @@ -556,7 +520,7 @@ class ClientResponse:

def __init__(self, method, url, host='', *, writer=None, continue100=None,
timeout=5*60):
super().__init__()
assert isinstance(url, URL)

self.method = method
self.url = url
Expand Down Expand Up @@ -591,8 +555,7 @@ def __del__(self, _warnings=warnings):

def __repr__(self):
out = io.StringIO()
ascii_encodable_url = self.url.encode('ascii', 'backslashreplace') \
.decode('ascii')
ascii_encodable_url = str(self.url)
if self.reason:
ascii_encodable_reason = self.reason.encode('ascii',
'backslashreplace') \
Expand Down
7 changes: 4 additions & 3 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from math import ceil
from types import MappingProxyType

from yarl import URL

import aiohttp

from . import hdrs, helpers
Expand Down Expand Up @@ -638,9 +640,7 @@ def _create_proxy_connection(self, req):
raise ProxyConnectionError(*exc.args) from exc

if not req.ssl:
req.path = '{scheme}://{host}{path}'.format(scheme=req.scheme,
host=req.netloc,
path=req.path)
req.path = str(req.url)
if hdrs.AUTHORIZATION in proxy_req.headers:
auth = proxy_req.headers[hdrs.AUTHORIZATION]
del proxy_req.headers[hdrs.AUTHORIZATION]
Expand Down Expand Up @@ -723,6 +723,7 @@ def __init__(self, proxy, *, proxy_auth=None, force_close=True,
conn_timeout=conn_timeout,
keepalive_timeout=keepalive_timeout,
limit=limit, loop=loop)
assert isinstance(proxy, URL)
self._proxy = proxy
self._proxy_auth = proxy_auth

Expand Down
19 changes: 9 additions & 10 deletions aiohttp/cookiejar.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from collections.abc import Mapping
from http.cookies import Morsel, SimpleCookie
from math import ceil
from urllib.parse import urlsplit

from yarl import URL

from .abc import AbstractCookieJar
from .helpers import is_ip_address
Expand Down Expand Up @@ -76,10 +77,9 @@ def _expire_cookie(self, when, domain, name):
self._next_expiration = min(self._next_expiration, when)
self._expirations[(domain, name)] = when

def update_cookies(self, cookies, response_url=None):
def update_cookies(self, cookies, response_url=URL()):
"""Update cookies."""
url_parsed = urlsplit(response_url or "")
hostname = url_parsed.hostname
hostname = response_url.host

if not self._unsafe and is_ip_address(hostname):
# Don't accept cookies from IPs
Expand Down Expand Up @@ -119,7 +119,7 @@ def update_cookies(self, cookies, response_url=None):
path = cookie["path"]
if not path or not path.startswith("/"):
# Set the cookie's path to the response path
path = url_parsed.path
path = response_url.path
if not path.startswith("/"):
path = "/"
else:
Expand Down Expand Up @@ -152,13 +152,12 @@ def update_cookies(self, cookies, response_url=None):

self._do_expiration()

def filter_cookies(self, request_url):
def filter_cookies(self, request_url=URL()):
"""Returns this jar's cookies filtered by their attributes."""
self._do_expiration()
url_parsed = urlsplit(request_url)
filtered = SimpleCookie()
hostname = url_parsed.hostname or ""
is_not_secure = url_parsed.scheme not in ("https", "wss")
hostname = request_url.host or ""
is_not_secure = request_url.scheme not in ("https", "wss")

for cookie in self:
name = cookie.key
Expand All @@ -178,7 +177,7 @@ def filter_cookies(self, request_url):
elif not self._is_domain_match(domain, hostname):
continue

if not self._is_path_match(url_parsed.path, cookie["path"]):
if not self._is_path_match(request_url.path, cookie["path"]):
continue

if is_not_secure and cookie["secure"]:
Expand Down
4 changes: 0 additions & 4 deletions aiohttp/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,3 @@ def __repr__(self):
return '<{} expected={} got={} host={} port={}>'.format(
self.__class__.__name__, self.expected, self.got,
self.host, self.port)


class InvalidURL(Exception):
"""Invalid URL."""
Loading

0 comments on commit e81a478

Please sign in to comment.