Skip to content

Commit

Permalink
Fix redirect to /ui behind proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
findepi committed Feb 24, 2020
1 parent 0e6980b commit 4ab76c2
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,8 @@

import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URISyntaxException;

import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.Strings.emptyToNull;
import static com.google.common.net.HttpHeaders.LOCATION;
import static com.google.common.net.HttpHeaders.X_FORWARDED_PROTO;
import static javax.servlet.http.HttpServletResponse.SC_SEE_OTHER;
import static javax.servlet.http.HttpServletResponse.SC_UNAUTHORIZED;

Expand Down Expand Up @@ -78,13 +73,6 @@ private static boolean isPublic(HttpServletRequest request)

private static String getRedirectLocation(HttpServletRequest request, String path)
{
try {
String proto = firstNonNull(emptyToNull(request.getHeader(X_FORWARDED_PROTO)), request.getScheme());
URI baseServerLocation = new URI(proto, null, request.getServerName(), request.getServerPort(), null, null, null);
return baseServerLocation.toASCIIString() + path;
}
catch (URISyntaxException e) {
throw new RuntimeException(e);
}
return FormWebUiAuthenticationManager.getRedirectLocation(request, path, null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,12 @@
import static com.google.common.base.Strings.isNullOrEmpty;
import static com.google.common.net.HttpHeaders.LOCATION;
import static com.google.common.net.HttpHeaders.WWW_AUTHENTICATE;
import static com.google.common.net.HttpHeaders.X_FORWARDED_HOST;
import static com.google.common.net.HttpHeaders.X_FORWARDED_PORT;
import static com.google.common.net.HttpHeaders.X_FORWARDED_PROTO;
import static io.airlift.http.client.HttpUriBuilder.uriBuilder;
import static io.prestosql.server.HttpRequestSessionContext.AUTHENTICATED_IDENTITY;
import static java.lang.Integer.parseInt;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Arrays.stream;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -313,14 +316,25 @@ private static String getRedirectLocation(HttpServletRequest request, String pat
return getRedirectLocation(request, path, null);
}

private static String getRedirectLocation(HttpServletRequest request, String path, String queryParameter)
static String getRedirectLocation(HttpServletRequest request, String path, String queryParameter)
{
String proto = firstNonNull(emptyToNull(request.getHeader(X_FORWARDED_PROTO)), request.getScheme());
HttpUriBuilder builder = uriBuilder()
.scheme(proto)
.host(request.getServerName())
.port(request.getServerPort())
.replacePath(path);
HttpUriBuilder builder;
if (isNullOrEmpty(request.getHeader(X_FORWARDED_HOST))) {
// not forwarded
builder = uriBuilder()
.scheme(request.getScheme())
.host(request.getServerName())
.port(request.getServerPort());
}
else {
// forwarded
builder = uriBuilder()
.scheme(firstNonNull(emptyToNull(request.getHeader(X_FORWARDED_PROTO)), request.getScheme()))
.host(request.getHeader(X_FORWARDED_HOST));
getForwarderPort(request).ifPresent(builder::port);
}

builder.replacePath(path);
if (queryParameter != null) {
builder.addParameter(queryParameter);
}
Expand Down Expand Up @@ -356,4 +370,16 @@ private static String parseJwt(byte[] hmac, String jwt)
.getBody()
.getSubject();
}

private static Optional<Integer> getForwarderPort(HttpServletRequest request)
{
if (!isNullOrEmpty(request.getHeader(X_FORWARDED_PORT))) {
try {
return Optional.of(parseInt(request.getHeader(X_FORWARDED_PORT)));
}
catch (ArithmeticException ignore) {
}
}
return Optional.empty();
}
}
32 changes: 30 additions & 2 deletions presto-main/src/test/java/io/prestosql/server/TestServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import io.airlift.http.client.HttpClient;
import io.airlift.http.client.HttpUriBuilder;
import io.airlift.http.client.Request;
import io.airlift.http.client.StatusResponseHandler;
import io.airlift.http.client.StatusResponseHandler.StatusResponse;
import io.airlift.http.client.jetty.JettyHttpClient;
import io.airlift.json.JsonCodec;
import io.prestosql.client.QueryError;
Expand All @@ -43,6 +43,9 @@
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.net.HttpHeaders.X_FORWARDED_HOST;
import static com.google.common.net.HttpHeaders.X_FORWARDED_PORT;
import static com.google.common.net.HttpHeaders.X_FORWARDED_PROTO;
import static io.airlift.http.client.FullJsonResponseHandler.createFullJsonResponseHandler;
import static io.airlift.http.client.Request.Builder.prepareGet;
import static io.airlift.http.client.Request.Builder.preparePost;
Expand All @@ -68,6 +71,7 @@
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;
import static javax.ws.rs.core.Response.Status.OK;
import static javax.ws.rs.core.Response.Status.SEE_OTHER;
import static org.assertj.core.api.Assertions.assertThat;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
Expand Down Expand Up @@ -118,7 +122,7 @@ public void testInvalidSessionError()
@Test
public void testServerStarts()
{
StatusResponseHandler.StatusResponse response = client.execute(
StatusResponse response = client.execute(
prepareGet().setUri(server.resolve("/v1/info")).build(),
createStatusResponseHandler());

Expand Down Expand Up @@ -221,6 +225,30 @@ public Object[][] testVersionOnErrorDataProvider()
};
}

@Test
public void testRedirectToUi()
{
Request request = prepareGet()
.setUri(uriFor("/"))
.setFollowRedirects(false)
.build();
StatusResponse response = client.execute(request, createStatusResponseHandler());
assertEquals(response.getStatusCode(), SEE_OTHER.getStatusCode(), "Status code");
assertEquals(response.getHeader("Location"), server.getBaseUrl() + "/ui/", "Location");

// behind a proxy
request = prepareGet()
.setUri(uriFor("/"))
.setHeader(X_FORWARDED_PROTO, "https")
.setHeader(X_FORWARDED_HOST, "my-load-balancer.local")
.setHeader(X_FORWARDED_PORT, "443")
.setFollowRedirects(false)
.build();
response = client.execute(request, createStatusResponseHandler());
assertEquals(response.getStatusCode(), SEE_OTHER.getStatusCode(), "Status code");
assertEquals(response.getHeader("Location"), "https://my-load-balancer.local:443/ui/", "Location");
}

private Stream<JsonResponse<QueryResults>> postQuery(Function<Request.Builder, Request.Builder> requestConfigurer)
{
Request.Builder request = preparePost()
Expand Down
33 changes: 30 additions & 3 deletions presto-main/src/test/java/io/prestosql/server/ui/TestWebUi.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@

import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.net.HttpHeaders.LOCATION;
import static com.google.common.net.HttpHeaders.X_FORWARDED_HOST;
import static com.google.common.net.HttpHeaders.X_FORWARDED_PORT;
import static com.google.common.net.HttpHeaders.X_FORWARDED_PROTO;
import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom;
import static io.airlift.testing.Closeables.closeQuietly;
import static io.prestosql.client.OkHttpUtil.setupSsl;
Expand Down Expand Up @@ -168,7 +171,7 @@ private void testLogIn(URI baseUri)
assertResponseCode(client, getLocation(baseUri, "/ui/unknown"), SC_NOT_FOUND);

assertResponseCode(client, getLocation(baseUri, "/ui/api/unknown"), SC_NOT_FOUND);
assertRedirect(client, getLogoutLocation(baseUri), getLoginHtmlLocation(baseUri));
assertRedirect(client, getLogoutLocation(baseUri), getLoginHtmlLocation(baseUri), false);
assertThat(cookieManager.getCookieStore().getCookies()).isEmpty();
}

Expand Down Expand Up @@ -323,8 +326,32 @@ private static Response assertOk(OkHttpClient client, String url)
private static void assertRedirect(OkHttpClient client, String url, String redirectLocation)
throws IOException
{
Response response = assertResponseCode(client, url, SC_SEE_OTHER);
assertEquals(response.header(LOCATION), redirectLocation);
assertRedirect(client, url, redirectLocation, true);
}

private static void assertRedirect(OkHttpClient client, String url, String redirectLocation, boolean testProxy)
throws IOException
{
Request request = new Request.Builder()
.url(url)
.build();
try (Response response = client.newCall(request).execute()) {
assertEquals(response.code(), SC_SEE_OTHER);
assertEquals(response.header(LOCATION), redirectLocation);
}

if (testProxy) {
request = new Request.Builder()
.url(url)
.header(X_FORWARDED_PROTO, "https")
.header(X_FORWARDED_HOST, "my-load-balancer.local")
.header(X_FORWARDED_PORT, "443")
.build();
try (Response response = client.newCall(request).execute()) {
assertEquals(response.code(), SC_SEE_OTHER);
assertEquals(response.header(LOCATION), "https://my-load-balancer.local:443/" + redirectLocation.replaceFirst("^([^/]*/){3}", ""));
}
}
}

private static Response assertResponseCode(OkHttpClient client, String url, int expectedCode)
Expand Down

0 comments on commit 4ab76c2

Please sign in to comment.