Skip to content

Commit

Permalink
Include 'root_path' when returning URLs from request.url_for (#699)
Browse files Browse the repository at this point in the history
* Include 'root_path' when returning URLs from request.url_for

* Preserve root_path for mounted apps
  • Loading branch information
tomchristie authored Nov 1, 2019
1 parent 1ea45ad commit 2a8c045
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 8 deletions.
5 changes: 3 additions & 2 deletions starlette/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(
scheme = scope.get("scheme", "http")
server = scope.get("server", None)
path = scope.get("root_path", "") + scope["path"]
query_string = scope["query_string"]
query_string = scope.get("query_string", b"")

host_header = None
for key, value in scope["headers"]:
Expand Down Expand Up @@ -185,7 +185,8 @@ def make_absolute_url(self, base_url: typing.Union[str, URL]) -> str:
else:
netloc = base_url.netloc

return str(URL(scheme=scheme, netloc=netloc, path=str(self)))
path = base_url.path.rstrip("/") + str(self)
return str(URL(scheme=scheme, netloc=netloc, path=path))


class Secret:
Expand Down
14 changes: 13 additions & 1 deletion starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ def url(self) -> URL:
self._url = URL(scope=self.scope)
return self._url

@property
def base_url(self) -> URL:
if not hasattr(self, "_base_url"):
base_url_scope = dict(self.scope)
base_url_scope["path"] = "/"
base_url_scope["query_string"] = b""
base_url_scope["root_path"] = base_url_scope.get(
"app_root_path", base_url_scope.get("root_path", "")
)
self._base_url = URL(scope=base_url_scope)
return self._base_url

@property
def headers(self) -> Headers:
if not hasattr(self, "_headers"):
Expand Down Expand Up @@ -123,7 +135,7 @@ def state(self) -> State:
def url_for(self, name: str, **path_params: typing.Any) -> str:
router = self.scope["router"]
url_path = router.url_path_for(name, **path_params)
return url_path.make_absolute_url(base_url=self.url)
return url_path.make_absolute_url(base_url=self.base_url)


async def empty_receive() -> Message:
Expand Down
4 changes: 3 additions & 1 deletion starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,11 @@ def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
matched_path = path[: -len(remaining_path)]
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
root_path = scope.get("root_path", "")
child_scope = {
"path_params": path_params,
"root_path": scope.get("root_path", "") + matched_path,
"app_root_path": scope.get("app_root_path", root_path),
"root_path": root_path + matched_path,
"path": remaining_path,
"endpoint": self.app,
}
Expand Down
14 changes: 10 additions & 4 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,12 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:


class _ASGIAdapter(requests.adapters.HTTPAdapter):
def __init__(self, app: ASGI3App, raise_server_exceptions: bool = True) -> None:
def __init__(
self, app: ASGI3App, raise_server_exceptions: bool = True, root_path: str = ""
) -> None:
self.app = app
self.raise_server_exceptions = raise_server_exceptions
self.root_path = root_path

def send( # type: ignore
self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any
Expand Down Expand Up @@ -131,7 +134,7 @@ def send( # type: ignore
scope = {
"type": "websocket",
"path": unquote(path),
"root_path": "",
"root_path": self.root_path,
"scheme": scheme,
"query_string": query.encode(),
"headers": headers,
Expand All @@ -147,7 +150,7 @@ def send( # type: ignore
"http_version": "1.1",
"method": request.method,
"path": unquote(path),
"root_path": "",
"root_path": self.root_path,
"scheme": scheme,
"query_string": query.encode(),
"headers": headers,
Expand Down Expand Up @@ -365,6 +368,7 @@ def __init__(
app: typing.Union[ASGI2App, ASGI3App],
base_url: str = "http://testserver",
raise_server_exceptions: bool = True,
root_path: str = "",
) -> None:
super(TestClient, self).__init__()
if _is_asgi3(app):
Expand All @@ -374,7 +378,9 @@ def __init__(
app = typing.cast(ASGI2App, app)
asgi_app = _WrapASGI2(app) #  type: ignore
adapter = _ASGIAdapter(
asgi_app, raise_server_exceptions=raise_server_exceptions
asgi_app,
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
)
self.mount("http://", adapter)
self.mount("https://", adapter)
Expand Down
47 changes: 47 additions & 0 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from starlette.applications import Starlette
from starlette.responses import JSONResponse, PlainTextResponse, Response
from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute
from starlette.testclient import TestClient
Expand Down Expand Up @@ -164,12 +165,24 @@ def test_url_for():
app.url_path_for("homepage").make_absolute_url(base_url="https://example.org")
== "https://example.org/"
)
assert (
app.url_path_for("homepage").make_absolute_url(
base_url="https://example.org/root_path/"
)
== "https://example.org/root_path/"
)
assert (
app.url_path_for("user", username="tomchristie").make_absolute_url(
base_url="https://example.org"
)
== "https://example.org/users/tomchristie"
)
assert (
app.url_path_for("user", username="tomchristie").make_absolute_url(
base_url="https://example.org/root_path/"
)
== "https://example.org/root_path/users/tomchristie"
)
assert (
app.url_path_for("websocket_endpoint").make_absolute_url(
base_url="https://example.org"
Expand Down Expand Up @@ -353,3 +366,37 @@ def test_subdomain_reverse_urls():
).make_absolute_url("https://whatever")
== "https://foo.example.org/homepage"
)


async def echo_urls(request):
return JSONResponse(
{
"index": request.url_for("index"),
"submount": request.url_for("mount:submount"),
}
)


echo_url_routes = [
Route("/", echo_urls, name="index", methods=["GET"]),
Mount(
"/submount",
name="mount",
routes=[Route("/", echo_urls, name="submount", methods=["GET"])],
),
]


def test_url_for_with_root_path():
app = Starlette(routes=echo_url_routes)
client = TestClient(app, base_url="https://www.example.org/", root_path="/sub_path")
response = client.get("/")
assert response.json() == {
"index": "https://www.example.org/sub_path/",
"submount": "https://www.example.org/sub_path/submount/",
}
response = client.get("/submount/")
assert response.json() == {
"index": "https://www.example.org/sub_path/",
"submount": "https://www.example.org/sub_path/submount/",
}

0 comments on commit 2a8c045

Please sign in to comment.