diff --git a/CHANGES/1039.bugfix.rst b/CHANGES/1039.bugfix.rst new file mode 100644 index 000000000..f1e61309a --- /dev/null +++ b/CHANGES/1039.bugfix.rst @@ -0,0 +1,27 @@ +:meth:`URL.join() <yarl.URL.join>` has been changed to match +:rfc:`3986` and align with +:meth:`/ operation <yarl.URL.__truediv__>` and :meth:`URL.joinpath() <yarl.URL.joinpath>` +when joining URLs with empty segments. +Previously :py:func:`urllib.parse.urljoin` was used, +which has known issues with empty segments +(`python/cpython#84774 <https://github.com/python/cpython/issues/84774>`_). + +Due to the semantics of :meth:`URL.join() <yarl.URL.join>`, joining a +URL that has a scheme requires making it relative by prefixing it with ``./``. + +.. code-block:: pycon + + >>> URL("https://web.archive.org/web/").join(URL("./https://github.com/aio-libs/yarl")) + URL('https://web.archive.org/web/https://github.com/aio-libs/yarl') + + +Empty segments are honored in the base as well as the joined part. + +.. code-block:: pycon + + >>> URL("https://web.archive.org/web/https://").join(URL("github.com/aio-libs/yarl")) + URL('https://web.archive.org/web/https://github.com/aio-libs/yarl') + + + +-- by :user:`commonism` diff --git a/tests/test_url.py b/tests/test_url.py index d62cd81af..10c83173f 100644 --- a/tests/test_url.py +++ b/tests/test_url.py @@ -1664,6 +1664,69 @@ def test_join_from_rfc_3986_abnormal(url, expected): assert base.join(url) == expected +EMPTY_SEGMENTS = [ + ( + "https://web.archive.org/web/", + "./https://github.com/aio-libs/yarl", + "https://web.archive.org/web/https://github.com/aio-libs/yarl", + ), + ( + "https://web.archive.org/web/https://github.com/", + "aio-libs/yarl", + "https://web.archive.org/web/https://github.com/aio-libs/yarl", + ), +] + + +@pytest.mark.parametrize("base,url,expected", EMPTY_SEGMENTS) +def test_join_empty_segments(base, url, expected): + base = URL(base) + url = URL(url) + expected = URL(expected) + joined = base.join(url) + assert joined == expected + + +SIMPLE_BASE = "http://a/b/c/d" +URLLIB_URLJOIN = [ + ("", "http://a/b/c/g?y/./x", "http://a/b/c/g?y/./x"), + ("", "http://a/./g", 
"http://a/./g"), + ("svn://pathtorepo/dir1", "dir2", "svn://pathtorepo/dir2"), + ("svn+ssh://pathtorepo/dir1", "dir2", "svn+ssh://pathtorepo/dir2"), + ("ws://a/b", "g", "ws://a/g"), + ("wss://a/b", "g", "wss://a/g"), + # test for issue22118 duplicate slashes + (SIMPLE_BASE + "/", "foo", SIMPLE_BASE + "/foo"), + # Non-RFC-defined tests, covering variations of base and trailing + # slashes + ("http://a/b/c/d/e/", "../../f/g/", "http://a/b/c/f/g/"), + ("http://a/b/c/d/e", "../../f/g/", "http://a/b/f/g/"), + ("http://a/b/c/d/e/", "/../../f/g/", "http://a/f/g/"), + ("http://a/b/c/d/e", "/../../f/g/", "http://a/f/g/"), + ("http://a/b/c/d/e/", "../../f/g", "http://a/b/c/f/g"), + ("http://a/b/", "../../f/g/", "http://a/f/g/"), + ("a", "b", "b"), +] + + +@pytest.mark.parametrize("base,url,expected", URLLIB_URLJOIN) +def test_join_cpython_urljoin(base, url, expected): + # tests from cpython urljoin + base = URL(base) + url = URL(url) + expected = URL(expected) + joined = base.join(url) + assert joined == expected + + +def test_join_cpython_urljoin_fail(): + with pytest.raises( + TypeError, match=r"unsupported operand type\(s\) for \+: 'NoneType' and 'str'" + ): + URL("http:///").join(URL("..")) + pytest.xfail("Shouldn't raise TypeError on empty host name") + + def test_split_result_non_decoded(): with pytest.raises(ValueError): URL(SplitResult("http", "example.com", "path", "qs", "frag")) diff --git a/yarl/_url.py b/yarl/_url.py index 78cc5a827..36da82ba1 100644 --- a/yarl/_url.py +++ b/yarl/_url.py @@ -6,7 +6,9 @@ from contextlib import suppress from ipaddress import ip_address from typing import Union -from urllib.parse import SplitResult, parse_qsl, quote, urljoin, urlsplit, urlunsplit +from urllib.parse import SplitResult, parse_qsl, quote, urlsplit, urlunsplit +from urllib.parse import uses_netloc as uses_authority +from urllib.parse import uses_relative import idna from multidict import MultiDict, MultiDictProxy @@ -259,12 +261,14 @@ def build( netloc = authority else: 
tmp = SplitResult("", authority, "", "", "") + port = None if tmp.port == cls._default_port(scheme) else tmp.port netloc = cls._make_netloc( - tmp.username, tmp.password, tmp.hostname, tmp.port, encode=True + tmp.username, tmp.password, tmp.hostname, port, encode=True ) elif not user and not password and not host and not port: netloc = "" else: + port = None if port == cls._default_port(scheme) else port netloc = cls._make_netloc( user, password, host, port, encode=not encoded, encode_host=not encoded ) @@ -455,12 +459,15 @@ def raw_authority(self): def _get_default_port(self) -> Union[int, None]: if not self.scheme: return None + return self._default_port(self.scheme) + @staticmethod + def _default_port(scheme: str) -> Union[int, None]: with suppress(KeyError): - return DEFAULT_PORTS[self.scheme] + return DEFAULT_PORTS[scheme] with suppress(OSError): - return socket.getservbyname(self.scheme) + return socket.getservbyname(scheme) return None @@ -762,11 +769,7 @@ def _make_child(self, paths, encoded=False): f"Appending path {path!r} starting from slash is forbidden" ) path = path if encoded else self._PATH_QUOTER(path) - segments = [ - segment for segment in reversed(path.split("/")) if segment != "." 
- ] - if not segments: - continue + segments = list(reversed(path.split("/"))) # remove trailing empty segment for all but the last path segment_slice_start = int(not last and segments[0] == "") parsed += segments[segment_slice_start:] @@ -1153,7 +1156,52 @@ def join(self, url): # See docs for urllib.parse.urljoin if not isinstance(url, URL): raise TypeError("url should be URL") - return URL(urljoin(str(self), str(url)), encoded=True) + other: URL = url + scheme = other.scheme or self.scheme + parts = { + k: getattr(self, k) or "" + for k in ("authority", "path", "query_string", "fragment") + } + parts["scheme"] = scheme + + if scheme != self.scheme or scheme not in uses_relative: + return URL(str(other)) + + # scheme is in uses_authority as uses_authority is a superset of uses_relative + if scheme in uses_authority and other.authority: + parts.update( + { + k: getattr(other, k) or "" + for k in ("authority", "path", "query_string", "fragment") + } + ) + return URL.build(**parts) + + if other.path or other.fragment: + parts["fragment"] = other.fragment or "" + if other.path or other.query: + parts["query_string"] = other.query_string or "" + + if not other.path: + return URL.build(**parts) + + if other.path[0] == "/": + parts["path"] = other.path + return URL.build(**parts) + + if self.path[-1] == "/": + # using an intermediate to avoid URL.joinpath dropping query & fragment + parts["path"] = URL(self.path).joinpath(other.path).path + else: + # … + # and relativizing ".." + path = URL("/".join([*self.parts[1:-1], ""])).joinpath(other.path).path + if parts["authority"]: + parts["path"] = "/" + path + else: + parts["path"] = path + + return URL.build(**parts) def joinpath(self, *other, encoded=False): """Return a new URL with the elements in other appended to the path."""